Add AccoutnType to Device, so it can be verified on requests

This commit is contained in:
Till Faelligen 2022-02-10 14:28:23 +01:00
parent c8e81a3674
commit 5c5a216011
5 changed files with 21 additions and 7 deletions

View file

@ -47,8 +47,7 @@ func GetAdminWhois(
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string,
) util.JSONResponse {
if userID != device.UserID {
// TODO: Still allow if user is admin
if device.AccountType != api.AccountTypeAdmin && userID != device.UserID {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("userID does not match the current user"),

View file

@ -354,6 +354,7 @@ type Device struct {
// If the device is for an appservice user,
// this is the appservice ID.
AppserviceID string
AccountType AccountType
}
// Account represents a Matrix account on this home server.
@ -362,7 +363,7 @@ type Account struct {
Localpart string
ServerName gomatrixserverlib.ServerName
AppServiceID string
// TODO: Other flags like IsAdmin, IsGuest
AccountType AccountType
// TODO: Associations (e.g. with application services)
}

View file

@ -373,6 +373,15 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
}
return err
}
localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return err
}
acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
}
device.AccountType = acc.AccountType
res.Device = device
return nil
}
@ -399,6 +408,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// AS dummy device has AS's token.
AccessToken: token,
AppserviceID: appService.ID,
AccountType: api.AccountTypeAppService,
}
localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName)

View file

@ -60,7 +60,7 @@ const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
@ -148,9 +148,10 @@ func (s *accountsStatements) selectAccountByLocalpart(
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
var accType api.AccountType
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &accType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@ -163,6 +164,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
acc.AccountType = accType
return &acc, nil
}

View file

@ -58,7 +58,7 @@ const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
@ -148,9 +148,10 @@ func (s *accountsStatements) selectAccountByLocalpart(
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
var accType api.AccountType
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &accType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@ -163,6 +164,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
acc.AccountType = accType
return &acc, nil
}