From a0cc4c806ccd45bec45a593af4bc90fd66c02149 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Nov 2022 12:56:33 +0000 Subject: [PATCH] Account tables --- clientapi/routing/password.go | 7 ++-- clientapi/routing/register.go | 9 +++-- userapi/api/api.go | 20 +++++++---- userapi/api/api_trace.go | 4 +-- userapi/internal/api.go | 16 ++++----- userapi/internal/api_logintoken.go | 2 +- userapi/inthttp/client.go | 3 +- userapi/inthttp/server.go | 15 ++------ userapi/storage/interface.go | 12 +++---- userapi/storage/postgres/accounts_table.go | 41 ++++++++++++---------- userapi/storage/shared/storage.go | 36 ++++++++++--------- userapi/storage/sqlite3/accounts_table.go | 40 +++++++++++---------- userapi/storage/storage_test.go | 22 ++++++------ userapi/storage/tables/interface.go | 12 +++---- userapi/storage/tables/stats_table_test.go | 2 +- 15 files changed, 126 insertions(+), 115 deletions(-) diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 6dc9af508..81a85cf9c 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -86,7 +86,7 @@ func Password( } // Get the local part. - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() @@ -94,8 +94,9 @@ func Password( // Ask the user API to perform the password change. passwordReq := &api.PerformPasswordUpdateRequest{ - Localpart: localpart, - Password: r.NewPassword, + Localpart: localpart, + ServerName: domain, + Password: r.NewPassword, } passwordRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index b9ebb0518..7ae46e4ed 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -588,12 +588,15 @@ func Register( } // Auto generate a numeric username if r.Username is empty if r.Username == "" { - res := &userapi.QueryNumericLocalpartResponse{} - if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil { + nreq := &userapi.QueryNumericLocalpartRequest{ + ServerName: cfg.Matrix.ServerName, // TODO: might not be right + } + nres := &userapi.QueryNumericLocalpartResponse{} + if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") return jsonerror.InternalServerError() } - r.Username = strconv.FormatInt(res.ID, 10) + r.Username = strconv.FormatInt(nres.ID, 10) } // Is this an appservice registration? It will be if the access diff --git a/userapi/api/api.go b/userapi/api/api.go index 8d7f783de..53d888b47 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -78,7 +78,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI - QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error + QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error @@ -335,9 +335,10 @@ type PerformAccountCreationResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformPasswordUpdateRequest struct { - Localpart string // Required: The localpart for this account. - Password string // Required: The new password to set. - LogoutDevices bool // Optional: Whether to log out all user devices. + Localpart string // Required: The localpart for this account. + ServerName gomatrixserverlib.ServerName // Required: The domain for this account. + Password string // Required: The new password to set. + LogoutDevices bool // Optional: Whether to log out all user devices. } // PerformAccountCreationResponse is the response for PerformAccountCreation @@ -601,12 +602,17 @@ type PerformSetAvatarURLResponse struct { Changed bool `json:"changed"` } +type QueryNumericLocalpartRequest struct { + ServerName gomatrixserverlib.ServerName +} + type QueryNumericLocalpartResponse struct { ID int64 } type QueryAccountAvailabilityRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryAccountAvailabilityResponse struct { @@ -614,7 +620,9 @@ type QueryAccountAvailabilityResponse struct { } type QueryAccountByPasswordRequest struct { - Localpart, PlaintextPassword string + Localpart string + ServerName gomatrixserverlib.ServerName + PlaintextPassword string } type QueryAccountByPasswordResponse struct { diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 90834f7e3..ce661770f 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet return err } -func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error { - err := t.Impl.QueryNumericLocalpart(ctx, res) +func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error { + err := t.Impl.QueryNumericLocalpart(ctx, req, res) util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res)) return err } diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 9a0de9c12..1c0de2f0c 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -227,7 +227,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P } func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { - if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil { return err } if req.LogoutDevices { @@ -527,7 +527,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc if !a.Config.Matrix.IsLocalServerName(domain) { return nil } - acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) + acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain) if err != nil { return err } @@ -568,7 +568,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe if localpart != "" { // AS is masquerading as another user // Verify that the user is registered - account, err := a.DB.GetAccountByLocalpart(ctx, localpart) + account, err := a.DB.GetAccountByLocalpart(ctx, localpart, a.Cfg.Matrix.ServerName) // TODO: which server name here? // Verify that the account exists and either appServiceID matches or // it belongs to the appservice user namespaces if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { @@ -620,7 +620,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a return err } - err := a.DB.DeactivateAccount(ctx, req.Localpart) + err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName) res.AccountDeactivated = err == nil return err } @@ -883,8 +883,8 @@ func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetA return err } -func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { - id, err := a.DB.GetNewNumericLocalpart(ctx) +func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error { + id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName) if err != nil { return err } @@ -894,12 +894,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { var err error - res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart) + res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName) return err } func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { - acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword) + acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword) switch err { case sql.ErrNoRows: // user does not exist return nil diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index 87f25e5e2..3b211db5b 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) } - if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { + if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil { res.Data = nil if err == sql.ErrNoRows { return nil diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index aa5d46d9f..87ae058c2 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL( func (h *httpUserInternalAPI) QueryNumericLocalpart( ctx context.Context, + request *api.QueryNumericLocalpartRequest, response *api.QueryNumericLocalpartResponse, ) error { return httputil.CallInternalRPCAPI( "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, - h.httpClient, ctx, &struct{}{}, response, + h.httpClient, ctx, request, response, ) } diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 99148b760..661fecfae 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -15,12 +15,9 @@ package inthttp import ( - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) // nolint: gocyclo @@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), ) - // TODO: Look at the shape of this - internalAPIMux.Handle(QueryNumericLocalpartPath, - httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse { - response := api.QueryNumericLocalpartResponse{} - if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + QueryNumericLocalpartPath, + httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart), ) internalAPIMux.Handle( diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 4c12c42e6..6413f88fb 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -40,12 +40,12 @@ type Account interface { // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) - GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) - GetNewNumericLocalpart(ctx context.Context) (int64, error) - CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) - DeactivateAccount(ctx context.Context, localpart string) (err error) - SetPassword(ctx context.Context, localpart string, plaintextPassword string) error + GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error) + GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) + CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) + GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) + DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error } type AccountData interface { diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 1448c888b..754880a06 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( ` const insertAccountSQL = "" + - "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + - "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3" const deactivateAccountSQL = "" + - "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" + "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'" + "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $2" type accountsStatements struct { insertAccountStmt *sql.Stmt @@ -118,16 +118,18 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error if accountType != api.AccountTypeAppService { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -143,34 +145,35 @@ func (s *accountsStatements) InsertAccount( } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart, passwordHash string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) return } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -188,12 +191,12 @@ func (s *accountsStatements) SelectAccountByLocalpart( } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - err = stmt.QueryRowContext(ctx).Scan(&id) + err = stmt.QueryRowContext(ctx, serverName).Scan(&id) return id + 1, err } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index dfef128b5..f2cfe021d 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -68,9 +68,10 @@ const ( // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword string, ) (*api.Account, error) { - hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) + hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName) if err != nil { return nil, err } @@ -80,7 +81,7 @@ func (d *Database) GetAccountByPassword( if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { return nil, err } - return d.Accounts.SelectAccountByLocalpart(ctx, localpart) + return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) } // GetProfileByLocalpart returns the profile associated with the given localpart. @@ -117,14 +118,15 @@ func (d *Database) SetDisplayName( // SetPassword sets the account password to the given hash. func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword string, ) error { hash, err := d.hashPassword(plaintextPassword) if err != nil { return err } return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.Accounts.UpdatePassword(ctx, localpart, hash) + return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash) }) } @@ -139,7 +141,7 @@ func (d *Database) CreateAccount( // For guest accounts, we create a new numeric local part if accountType == api.AccountTypeGuest { var numLocalpart int64 - numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) + numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName) if err != nil { return err } @@ -170,13 +172,13 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { + if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil { return nil, sqlutil.ErrUserExists } if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { return nil, err } - pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) prbs, err := json.Marshal(pushRuleSets) if err != nil { return nil, err @@ -262,9 +264,9 @@ func (d *Database) GetAccountDataByType( // GetNewNumericLocalpart generates and returns a new unused numeric localpart func (d *Database) GetNewNumericLocalpart( - ctx context.Context, + ctx context.Context, serverName gomatrixserverlib.ServerName, ) (int64, error) { - return d.Accounts.SelectNewNumericLocalpart(ctx, nil) + return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName) } func (d *Database) hashPassword(plaintext string) (hash string, err error) { @@ -335,8 +337,8 @@ func (d *Database) GetThreePIDsForLocalpart( // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) { + _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) if err == sql.ErrNoRows { return true, nil } @@ -346,12 +348,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // GetAccountByLocalpart returns the account associated with the given localpart. // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { // try to get the account with lowercase localpart (majority) - acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart)) + acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName) if err == sql.ErrNoRows { - acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request + acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request } return acc, err } @@ -364,9 +366,9 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi } // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { +func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.Accounts.DeactivateAccount(ctx, localpart) + return d.Accounts.DeactivateAccount(ctx, localpart, serverName) }) } diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index ae82503de..cfe6c98fa 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( ` const insertAccountSQL = "" + - "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + - "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3" const deactivateAccountSQL = "" + - "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" + "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1" type accountsStatements struct { db *sql.DB @@ -120,16 +120,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, + hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error if accountType != api.AccountTypeAppService { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -145,34 +146,35 @@ func (s *accountsStatements) InsertAccount( } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart, passwordHash string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) return } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -190,13 +192,13 @@ func (s *accountsStatements) SelectAccountByLocalpart( } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - err = stmt.QueryRowContext(ctx).Scan(&id) + err = stmt.QueryRowContext(ctx, serverName).Scan(&id) if err == sql.ErrNoRows { return 1, nil } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 23f354083..ada077ce1 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -88,47 +88,47 @@ func Test_Accounts(t *testing.T) { assert.NoError(t, err, "failed to create account") // verify the newly create account is the same as returned by CreateAccount var accGet *api.Account - accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing") + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing") assert.NoError(t, err, "failed to get account by password") assert.Equal(t, accAlice, accGet) - accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart) + accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to get account by localpart") assert.Equal(t, accAlice, accGet) // check account availability - available, err := db.CheckAccountAvailability(ctx, aliceLocalpart) + available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, false, available) - available, err = db.CheckAccountAvailability(ctx, "unusedname") + available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, true, available) // get guest account numeric aliceLocalpart - first, err := db.GetNewNumericLocalpart(ctx) + first, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err, "failed to get new numeric localpart") // Create a new account to verify the numeric localpart is updated _, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest) assert.NoError(t, err, "failed to create account") - second, err := db.GetNewNumericLocalpart(ctx) + second, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err) assert.Greater(t, second, first) // update password for alice - err = db.SetPassword(ctx, aliceLocalpart, "newPassword") + err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to update password") - accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to get account by new password") assert.Equal(t, accAlice, accGet) // deactivate account - err = db.DeactivateAccount(ctx, aliceLocalpart) + err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to deactivate account") // This should fail now, as the account is deactivated - _, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + _, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.Error(t, err, "expected an error, got none") - _, err = db.GetAccountByLocalpart(ctx, "unusename") + _, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain) assert.Error(t, err, "expected an error for non existent localpart") // create an empty localpart; this should never happen, but is required to test getting a numeric localpart diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 77216f00e..32b7c30c6 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -34,12 +34,12 @@ type AccountDataTable interface { } type AccountsTable interface { - InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) - UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) - DeactivateAccount(ctx context.Context, localpart string) (err error) - SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) - SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) - SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) + InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error) } type DevicesTable interface { diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index a547423bc..ae2915bc4 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -89,7 +89,7 @@ func mustMakeAccountAndDevice( appServiceID = util.RandomString(16) } - _, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType) + _, err := accDB.InsertAccount(ctx, nil, localpart, "localhost", "", appServiceID, accType) if err != nil { t.Fatalf("unable to create account: %v", err) }