mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-31 10:43:10 -06:00
Add AccoutnType to Device, so it can be verified on requests
This commit is contained in:
parent
c8e81a3674
commit
5c5a216011
|
|
@ -47,8 +47,7 @@ func GetAdminWhois(
|
||||||
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
|
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
|
||||||
userID string,
|
userID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if userID != device.UserID {
|
if device.AccountType != api.AccountTypeAdmin && userID != device.UserID {
|
||||||
// TODO: Still allow if user is admin
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusForbidden,
|
Code: http.StatusForbidden,
|
||||||
JSON: jsonerror.Forbidden("userID does not match the current user"),
|
JSON: jsonerror.Forbidden("userID does not match the current user"),
|
||||||
|
|
|
||||||
|
|
@ -354,6 +354,7 @@ type Device struct {
|
||||||
// If the device is for an appservice user,
|
// If the device is for an appservice user,
|
||||||
// this is the appservice ID.
|
// this is the appservice ID.
|
||||||
AppserviceID string
|
AppserviceID string
|
||||||
|
AccountType AccountType
|
||||||
}
|
}
|
||||||
|
|
||||||
// Account represents a Matrix account on this home server.
|
// Account represents a Matrix account on this home server.
|
||||||
|
|
@ -362,7 +363,7 @@ type Account struct {
|
||||||
Localpart string
|
Localpart string
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
AppServiceID string
|
AppServiceID string
|
||||||
// TODO: Other flags like IsAdmin, IsGuest
|
AccountType AccountType
|
||||||
// TODO: Associations (e.g. with application services)
|
// TODO: Associations (e.g. with application services)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -373,6 +373,15 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
device.AccountType = acc.AccountType
|
||||||
res.Device = device
|
res.Device = device
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -399,6 +408,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
||||||
// AS dummy device has AS's token.
|
// AS dummy device has AS's token.
|
||||||
AccessToken: token,
|
AccessToken: token,
|
||||||
AppserviceID: appService.ID,
|
AppserviceID: appService.ID,
|
||||||
|
AccountType: api.AccountTypeAppService,
|
||||||
}
|
}
|
||||||
|
|
||||||
localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName)
|
localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName)
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ const deactivateAccountSQL = "" +
|
||||||
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
||||||
|
|
||||||
const selectAccountByLocalpartSQL = "" +
|
const selectAccountByLocalpartSQL = "" +
|
||||||
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
const selectPasswordHashSQL = "" +
|
const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||||
|
|
@ -148,9 +148,10 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var appserviceIDPtr sql.NullString
|
var appserviceIDPtr sql.NullString
|
||||||
var acc api.Account
|
var acc api.Account
|
||||||
|
var accType api.AccountType
|
||||||
|
|
||||||
stmt := s.selectAccountByLocalpartStmt
|
stmt := s.selectAccountByLocalpartStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
|
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &accType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||||
|
|
@ -163,6 +164,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
|
|
||||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
acc.ServerName = s.serverName
|
acc.ServerName = s.serverName
|
||||||
|
acc.AccountType = accType
|
||||||
|
|
||||||
return &acc, nil
|
return &acc, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ const deactivateAccountSQL = "" +
|
||||||
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
"UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||||
|
|
||||||
const selectAccountByLocalpartSQL = "" +
|
const selectAccountByLocalpartSQL = "" +
|
||||||
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
|
"SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
|
||||||
|
|
||||||
const selectPasswordHashSQL = "" +
|
const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||||
|
|
@ -148,9 +148,10 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var appserviceIDPtr sql.NullString
|
var appserviceIDPtr sql.NullString
|
||||||
var acc api.Account
|
var acc api.Account
|
||||||
|
var accType api.AccountType
|
||||||
|
|
||||||
stmt := s.selectAccountByLocalpartStmt
|
stmt := s.selectAccountByLocalpartStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
|
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &accType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||||
|
|
@ -163,6 +164,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
||||||
|
|
||||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
acc.ServerName = s.serverName
|
acc.ServerName = s.serverName
|
||||||
|
acc.AccountType = accType
|
||||||
|
|
||||||
return &acc, nil
|
return &acc, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue