diff --git a/clientapi/routing/admin_whois.go b/clientapi/routing/admin_whois.go index b448791c3..d7f505fd3 100644 --- a/clientapi/routing/admin_whois.go +++ b/clientapi/routing/admin_whois.go @@ -47,8 +47,7 @@ func GetAdminWhois( req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, ) util.JSONResponse { - if userID != device.UserID { - // TODO: Still allow if user is admin + if device.AccountType != api.AccountTypeAdmin && userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("userID does not match the current user"), diff --git a/userapi/api/api.go b/userapi/api/api.go index ddf4ddf64..2be662e55 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -354,6 +354,7 @@ type Device struct { // If the device is for an appservice user, // this is the appservice ID. AppserviceID string + AccountType AccountType } // Account represents a Matrix account on this home server. @@ -362,7 +363,7 @@ type Account struct { Localpart string ServerName gomatrixserverlib.ServerName AppServiceID string - // TODO: Other flags like IsAdmin, IsGuest + AccountType AccountType // TODO: Associations (e.g. with application services) } diff --git a/userapi/internal/api.go b/userapi/internal/api.go index c83491135..46e5ff200 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -373,6 +373,15 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc } return err } + localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return err + } + acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart) + if err != nil { + return err + } + device.AccountType = acc.AccountType res.Device = device return nil } @@ -399,6 +408,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // AS dummy device has AS's token. AccessToken: token, AppserviceID: appService.ID, + AccountType: api.AccountTypeAppService, } localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 30a77bd13..636451b5f 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -60,7 +60,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" @@ -148,9 +148,10 @@ func (s *accountsStatements) selectAccountByLocalpart( ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account + var accType api.AccountType stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &accType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -163,6 +164,7 @@ func (s *accountsStatements) selectAccountByLocalpart( acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.ServerName = s.serverName + acc.AccountType = accType return &acc, nil } diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 229bdf28a..6f63a8023 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -58,7 +58,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" @@ -148,9 +148,10 @@ func (s *accountsStatements) selectAccountByLocalpart( ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account + var accType api.AccountType stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &accType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -163,6 +164,7 @@ func (s *accountsStatements) selectAccountByLocalpart( acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.ServerName = s.serverName + acc.AccountType = accType return &acc, nil }