Account tables

This commit is contained in:
Neil Alexander 2022-11-04 12:56:33 +00:00
parent d1c61f5f95
commit a0cc4c806c
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
15 changed files with 126 additions and 115 deletions

View file

@ -86,7 +86,7 @@ func Password(
} }
// Get the local part. // Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -95,6 +95,7 @@ func Password(
// Ask the user API to perform the password change. // Ask the user API to perform the password change.
passwordReq := &api.PerformPasswordUpdateRequest{ passwordReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
Password: r.NewPassword, Password: r.NewPassword,
} }
passwordRes := &api.PerformPasswordUpdateResponse{} passwordRes := &api.PerformPasswordUpdateResponse{}

View file

@ -588,12 +588,15 @@ func Register(
} }
// Auto generate a numeric username if r.Username is empty // Auto generate a numeric username if r.Username is empty
if r.Username == "" { if r.Username == "" {
res := &userapi.QueryNumericLocalpartResponse{} nreq := &userapi.QueryNumericLocalpartRequest{
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil { ServerName: cfg.Matrix.ServerName, // TODO: might not be right
}
nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
r.Username = strconv.FormatInt(res.ID, 10) r.Username = strconv.FormatInt(nres.ID, 10)
} }
// Is this an appservice registration? It will be if the access // Is this an appservice registration? It will be if the access

View file

@ -78,7 +78,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI QueryAcccessTokenAPI
LoginTokenInternalAPI LoginTokenInternalAPI
UserLoginAPI UserLoginAPI
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
@ -336,6 +336,7 @@ type PerformAccountCreationResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation // PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct { type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account. Localpart string // Required: The localpart for this account.
ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
Password string // Required: The new password to set. Password string // Required: The new password to set.
LogoutDevices bool // Optional: Whether to log out all user devices. LogoutDevices bool // Optional: Whether to log out all user devices.
} }
@ -601,12 +602,17 @@ type PerformSetAvatarURLResponse struct {
Changed bool `json:"changed"` Changed bool `json:"changed"`
} }
type QueryNumericLocalpartRequest struct {
ServerName gomatrixserverlib.ServerName
}
type QueryNumericLocalpartResponse struct { type QueryNumericLocalpartResponse struct {
ID int64 ID int64
} }
type QueryAccountAvailabilityRequest struct { type QueryAccountAvailabilityRequest struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
} }
type QueryAccountAvailabilityResponse struct { type QueryAccountAvailabilityResponse struct {
@ -614,7 +620,9 @@ type QueryAccountAvailabilityResponse struct {
} }
type QueryAccountByPasswordRequest struct { type QueryAccountByPasswordRequest struct {
Localpart, PlaintextPassword string Localpart string
ServerName gomatrixserverlib.ServerName
PlaintextPassword string
} }
type QueryAccountByPasswordResponse struct { type QueryAccountByPasswordResponse struct {

View file

@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
return err return err
} }
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error { func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
err := t.Impl.QueryNumericLocalpart(ctx, res) err := t.Impl.QueryNumericLocalpart(ctx, req, res)
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res)) util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
return err return err
} }

View file

@ -227,7 +227,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
} }
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
return err return err
} }
if req.LogoutDevices { if req.LogoutDevices {
@ -527,7 +527,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if !a.Config.Matrix.IsLocalServerName(domain) { if !a.Config.Matrix.IsLocalServerName(domain) {
return nil return nil
} }
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
if err != nil { if err != nil {
return err return err
} }
@ -568,7 +568,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered // Verify that the user is registered
account, err := a.DB.GetAccountByLocalpart(ctx, localpart) account, err := a.DB.GetAccountByLocalpart(ctx, localpart, a.Cfg.Matrix.ServerName) // TODO: which server name here?
// Verify that the account exists and either appServiceID matches or // Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces // it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -620,7 +620,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
return err return err
} }
err := a.DB.DeactivateAccount(ctx, req.Localpart) err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
res.AccountDeactivated = err == nil res.AccountDeactivated = err == nil
return err return err
} }
@ -883,8 +883,8 @@ func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetA
return err return err
} }
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
id, err := a.DB.GetNewNumericLocalpart(ctx) id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -894,12 +894,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
var err error var err error
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart) res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
return err return err
} }
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword) acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
switch err { switch err {
case sql.ErrNoRows: // user does not exist case sql.ErrNoRows: // user does not exist
return nil return nil

View file

@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if !a.Config.Matrix.IsLocalServerName(domain) { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
} }
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
res.Data = nil res.Data = nil
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil

View file

@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
func (h *httpUserInternalAPI) QueryNumericLocalpart( func (h *httpUserInternalAPI) QueryNumericLocalpart(
ctx context.Context, ctx context.Context,
request *api.QueryNumericLocalpartRequest,
response *api.QueryNumericLocalpartResponse, response *api.QueryNumericLocalpartResponse,
) error { ) error {
return httputil.CallInternalRPCAPI( return httputil.CallInternalRPCAPI(
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
h.httpClient, ctx, &struct{}{}, response, h.httpClient, ctx, request, response,
) )
} }

View file

@ -15,12 +15,9 @@
package inthttp package inthttp
import ( import (
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
) )
// nolint: gocyclo // nolint: gocyclo
@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
) )
// TODO: Look at the shape of this internalAPIMux.Handle(
internalAPIMux.Handle(QueryNumericLocalpartPath, QueryNumericLocalpartPath,
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
response := api.QueryNumericLocalpartResponse{}
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(

View file

@ -40,12 +40,12 @@ type Account interface {
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context) (int64, error) GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
} }
type AccountData interface { type AccountData interface {

View file

@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" + const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1" "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE" "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'" "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $2"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
@ -118,16 +118,18 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) InsertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if accountType != api.AccountTypeAppService { if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -143,34 +145,35 @@ func (s *accountsStatements) InsertAccount(
} }
func (s *accountsStatements) UpdatePassword( func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return return
} }
func (s *accountsStatements) DeactivateAccount( func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return return
} }
func (s *accountsStatements) SelectPasswordHash( func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return return
} }
func (s *accountsStatements) SelectAccountByLocalpart( func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString var appserviceIDPtr sql.NullString
var acc api.Account var acc api.Account
stmt := s.selectAccountByLocalpartStmt stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db") log.WithError(err).Error("Unable to retrieve user from the db")
@ -188,12 +191,12 @@ func (s *accountsStatements) SelectAccountByLocalpart(
} }
func (s *accountsStatements) SelectNewNumericLocalpart( func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt
if txn != nil { if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
return id + 1, err return id + 1, err
} }

View file

@ -68,9 +68,10 @@ const (
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart. // Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword( func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword string,
) (*api.Account, error) { ) (*api.Account, error) {
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -80,7 +81,7 @@ func (d *Database) GetAccountByPassword(
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err return nil, err
} }
return d.Accounts.SelectAccountByLocalpart(ctx, localpart) return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
} }
// GetProfileByLocalpart returns the profile associated with the given localpart. // GetProfileByLocalpart returns the profile associated with the given localpart.
@ -117,14 +118,15 @@ func (d *Database) SetDisplayName(
// SetPassword sets the account password to the given hash. // SetPassword sets the account password to the given hash.
func (d *Database) SetPassword( func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword string,
) error { ) error {
hash, err := d.hashPassword(plaintextPassword) hash, err := d.hashPassword(plaintextPassword)
if err != nil { if err != nil {
return err return err
} }
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePassword(ctx, localpart, hash) return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
}) })
} }
@ -139,7 +141,7 @@ func (d *Database) CreateAccount(
// For guest accounts, we create a new numeric local part // For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest { if accountType == api.AccountTypeGuest {
var numLocalpart int64 var numLocalpart int64
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
if err != nil { if err != nil {
return err return err
} }
@ -170,13 +172,13 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err return nil, err
} }
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
prbs, err := json.Marshal(pushRuleSets) prbs, err := json.Marshal(pushRuleSets)
if err != nil { if err != nil {
return nil, err return nil, err
@ -262,9 +264,9 @@ func (d *Database) GetAccountDataByType(
// GetNewNumericLocalpart generates and returns a new unused numeric localpart // GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart( func (d *Database) GetNewNumericLocalpart(
ctx context.Context, ctx context.Context, serverName gomatrixserverlib.ServerName,
) (int64, error) { ) (int64, error) {
return d.Accounts.SelectNewNumericLocalpart(ctx, nil) return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
} }
func (d *Database) hashPassword(plaintext string) (hash string, err error) { func (d *Database) hashPassword(plaintext string) (hash string, err error) {
@ -335,8 +337,8 @@ func (d *Database) GetThreePIDsForLocalpart(
// CheckAccountAvailability checks if the username/localpart is already present // CheckAccountAvailability checks if the username/localpart is already present
// in the database. // in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken. // If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return true, nil return true, nil
} }
@ -346,12 +348,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
// GetAccountByLocalpart returns the account associated with the given localpart. // GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally. // This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart. // Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) { ) (*api.Account, error) {
// try to get the account with lowercase localpart (majority) // try to get the account with lowercase localpart (majority)
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart)) acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
} }
return acc, err return acc, err
} }
@ -364,9 +366,9 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
} }
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. // DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.DeactivateAccount(ctx, localpart) return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
}) })
} }

View file

@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" + const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
@ -120,16 +120,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) InsertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error var err error
if accountType != api.AccountTypeAppService { if accountType != api.AccountTypeAppService {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else { } else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -145,34 +146,35 @@ func (s *accountsStatements) InsertAccount(
} }
func (s *accountsStatements) UpdatePassword( func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return return
} }
func (s *accountsStatements) DeactivateAccount( func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return return
} }
func (s *accountsStatements) SelectPasswordHash( func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return return
} }
func (s *accountsStatements) SelectAccountByLocalpart( func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString var appserviceIDPtr sql.NullString
var acc api.Account var acc api.Account
stmt := s.selectAccountByLocalpartStmt stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db") log.WithError(err).Error("Unable to retrieve user from the db")
@ -190,13 +192,13 @@ func (s *accountsStatements) SelectAccountByLocalpart(
} }
func (s *accountsStatements) SelectNewNumericLocalpart( func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt
if txn != nil { if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return 1, nil return 1, nil
} }

View file

@ -88,47 +88,47 @@ func Test_Accounts(t *testing.T) {
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount // verify the newly create account is the same as returned by CreateAccount
var accGet *api.Account var accGet *api.Account
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing") accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
assert.NoError(t, err, "failed to get account by password") assert.NoError(t, err, "failed to get account by password")
assert.Equal(t, accAlice, accGet) assert.Equal(t, accAlice, accGet)
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart) accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to get account by localpart") assert.NoError(t, err, "failed to get account by localpart")
assert.Equal(t, accAlice, accGet) assert.Equal(t, accAlice, accGet)
// check account availability // check account availability
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart) available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to checkout account availability") assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, false, available) assert.Equal(t, false, available)
available, err = db.CheckAccountAvailability(ctx, "unusedname") available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
assert.NoError(t, err, "failed to checkout account availability") assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, true, available) assert.Equal(t, true, available)
// get guest account numeric aliceLocalpart // get guest account numeric aliceLocalpart
first, err := db.GetNewNumericLocalpart(ctx) first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err, "failed to get new numeric localpart") assert.NoError(t, err, "failed to get new numeric localpart")
// Create a new account to verify the numeric localpart is updated // Create a new account to verify the numeric localpart is updated
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest) _, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
second, err := db.GetNewNumericLocalpart(ctx) second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err) assert.NoError(t, err)
assert.Greater(t, second, first) assert.Greater(t, second, first)
// update password for alice // update password for alice
err = db.SetPassword(ctx, aliceLocalpart, "newPassword") err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to update password") assert.NoError(t, err, "failed to update password")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to get account by new password") assert.NoError(t, err, "failed to get account by new password")
assert.Equal(t, accAlice, accGet) assert.Equal(t, accAlice, accGet)
// deactivate account // deactivate account
err = db.DeactivateAccount(ctx, aliceLocalpart) err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to deactivate account") assert.NoError(t, err, "failed to deactivate account")
// This should fail now, as the account is deactivated // This should fail now, as the account is deactivated
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") _, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.Error(t, err, "expected an error, got none") assert.Error(t, err, "expected an error, got none")
_, err = db.GetAccountByLocalpart(ctx, "unusename") _, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
assert.Error(t, err, "expected an error for non existent localpart") assert.Error(t, err, "expected an error for non existent localpart")
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart // create an empty localpart; this should never happen, but is required to test getting a numeric localpart

View file

@ -34,12 +34,12 @@ type AccountDataTable interface {
} }
type AccountsTable interface { type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
} }
type DevicesTable interface { type DevicesTable interface {

View file

@ -89,7 +89,7 @@ func mustMakeAccountAndDevice(
appServiceID = util.RandomString(16) appServiceID = util.RandomString(16)
} }
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType) _, err := accDB.InsertAccount(ctx, nil, localpart, "localhost", "", appServiceID, accType)
if err != nil { if err != nil {
t.Fatalf("unable to create account: %v", err) t.Fatalf("unable to create account: %v", err)
} }