Account tables

This commit is contained in:
Neil Alexander 2022-11-04 12:56:33 +00:00
parent d1c61f5f95
commit a0cc4c806c
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
15 changed files with 126 additions and 115 deletions

View file

@ -86,7 +86,7 @@ func Password(
}
// Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
@ -94,8 +94,9 @@ func Password(
// Ask the user API to perform the password change.
passwordReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart,
Password: r.NewPassword,
Localpart: localpart,
ServerName: domain,
Password: r.NewPassword,
}
passwordRes := &api.PerformPasswordUpdateResponse{}
if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {

View file

@ -588,12 +588,15 @@ func Register(
}
// Auto generate a numeric username if r.Username is empty
if r.Username == "" {
res := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil {
nreq := &userapi.QueryNumericLocalpartRequest{
ServerName: cfg.Matrix.ServerName, // TODO: might not be right
}
nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError()
}
r.Username = strconv.FormatInt(res.ID, 10)
r.Username = strconv.FormatInt(nres.ID, 10)
}
// Is this an appservice registration? It will be if the access

View file

@ -78,7 +78,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI
LoginTokenInternalAPI
UserLoginAPI
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
@ -335,9 +335,10 @@ type PerformAccountCreationResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account.
Password string // Required: The new password to set.
LogoutDevices bool // Optional: Whether to log out all user devices.
Localpart string // Required: The localpart for this account.
ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
Password string // Required: The new password to set.
LogoutDevices bool // Optional: Whether to log out all user devices.
}
// PerformAccountCreationResponse is the response for PerformAccountCreation
@ -601,12 +602,17 @@ type PerformSetAvatarURLResponse struct {
Changed bool `json:"changed"`
}
type QueryNumericLocalpartRequest struct {
ServerName gomatrixserverlib.ServerName
}
type QueryNumericLocalpartResponse struct {
ID int64
}
type QueryAccountAvailabilityRequest struct {
Localpart string
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryAccountAvailabilityResponse struct {
@ -614,7 +620,9 @@ type QueryAccountAvailabilityResponse struct {
}
type QueryAccountByPasswordRequest struct {
Localpart, PlaintextPassword string
Localpart string
ServerName gomatrixserverlib.ServerName
PlaintextPassword string
}
type QueryAccountByPasswordResponse struct {

View file

@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
return err
}
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error {
err := t.Impl.QueryNumericLocalpart(ctx, res)
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
err := t.Impl.QueryNumericLocalpart(ctx, req, res)
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
return err
}

View file

@ -227,7 +227,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
return err
}
if req.LogoutDevices {
@ -527,7 +527,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if !a.Config.Matrix.IsLocalServerName(domain) {
return nil
}
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
if err != nil {
return err
}
@ -568,7 +568,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
account, err := a.DB.GetAccountByLocalpart(ctx, localpart, a.Cfg.Matrix.ServerName) // TODO: which server name here?
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -620,7 +620,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
return err
}
err := a.DB.DeactivateAccount(ctx, req.Localpart)
err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
res.AccountDeactivated = err == nil
return err
}
@ -883,8 +883,8 @@ func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetA
return err
}
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
id, err := a.DB.GetNewNumericLocalpart(ctx)
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
if err != nil {
return err
}
@ -894,12 +894,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
var err error
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
return err
}
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
switch err {
case sql.ErrNoRows: // user does not exist
return nil

View file

@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
}
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil

View file

@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
func (h *httpUserInternalAPI) QueryNumericLocalpart(
ctx context.Context,
request *api.QueryNumericLocalpartRequest,
response *api.QueryNumericLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
h.httpClient, ctx, &struct{}{}, response,
h.httpClient, ctx, request, response,
)
}

View file

@ -15,12 +15,9 @@
package inthttp
import (
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
// nolint: gocyclo
@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
)
// TODO: Look at the shape of this
internalAPIMux.Handle(QueryNumericLocalpartPath,
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
response := api.QueryNumericLocalpartResponse{}
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryNumericLocalpartPath,
httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
)
internalAPIMux.Handle(

View file

@ -40,12 +40,12 @@ type Account interface {
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
}
type AccountData interface {

View file

@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
`
const insertAccountSQL = "" +
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $2"
type accountsStatements struct {
insertAccountStmt *sql.Stmt
@ -118,16 +118,18 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error
if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
@ -143,34 +145,35 @@ func (s *accountsStatements) InsertAccount(
}
func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return
}
func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return
}
func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return
}
func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@ -188,12 +191,12 @@ func (s *accountsStatements) SelectAccountByLocalpart(
}
func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
return id + 1, err
}

View file

@ -68,9 +68,10 @@ const (
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword string,
) (*api.Account, error) {
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
if err != nil {
return nil, err
}
@ -80,7 +81,7 @@ func (d *Database) GetAccountByPassword(
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
@ -117,14 +118,15 @@ func (d *Database) SetDisplayName(
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePassword(ctx, localpart, hash)
return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
})
}
@ -139,7 +141,7 @@ func (d *Database) CreateAccount(
// For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
if err != nil {
return err
}
@ -170,13 +172,13 @@ func (d *Database) createAccount(
return nil, err
}
}
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists
}
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
prbs, err := json.Marshal(pushRuleSets)
if err != nil {
return nil, err
@ -262,9 +264,9 @@ func (d *Database) GetAccountDataByType(
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
ctx context.Context, serverName gomatrixserverlib.ServerName,
) (int64, error) {
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
@ -335,8 +337,8 @@ func (d *Database) GetThreePIDsForLocalpart(
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
if err == sql.ErrNoRows {
return true, nil
}
@ -346,12 +348,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
// try to get the account with lowercase localpart (majority)
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart))
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
if err == sql.ErrNoRows {
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
}
return acc, err
}
@ -364,9 +366,9 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.DeactivateAccount(ctx, localpart)
return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
})
}

View file

@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
`
const insertAccountSQL = "" +
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
type accountsStatements struct {
db *sql.DB
@ -120,16 +120,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
var err error
if accountType != api.AccountTypeAppService {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
@ -145,34 +146,35 @@ func (s *accountsStatements) InsertAccount(
}
func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return
}
func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return
}
func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return
}
func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@ -190,13 +192,13 @@ func (s *accountsStatements) SelectAccountByLocalpart(
}
func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
if err == sql.ErrNoRows {
return 1, nil
}

View file

@ -88,47 +88,47 @@ func Test_Accounts(t *testing.T) {
assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount
var accGet *api.Account
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
assert.NoError(t, err, "failed to get account by password")
assert.Equal(t, accAlice, accGet)
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to get account by localpart")
assert.Equal(t, accAlice, accGet)
// check account availability
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, false, available)
available, err = db.CheckAccountAvailability(ctx, "unusedname")
available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, true, available)
// get guest account numeric aliceLocalpart
first, err := db.GetNewNumericLocalpart(ctx)
first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err, "failed to get new numeric localpart")
// Create a new account to verify the numeric localpart is updated
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
assert.NoError(t, err, "failed to create account")
second, err := db.GetNewNumericLocalpart(ctx)
second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err)
assert.Greater(t, second, first)
// update password for alice
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to update password")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to get account by new password")
assert.Equal(t, accAlice, accGet)
// deactivate account
err = db.DeactivateAccount(ctx, aliceLocalpart)
err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to deactivate account")
// This should fail now, as the account is deactivated
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.Error(t, err, "expected an error, got none")
_, err = db.GetAccountByLocalpart(ctx, "unusename")
_, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
assert.Error(t, err, "expected an error for non existent localpart")
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart

View file

@ -34,12 +34,12 @@ type AccountDataTable interface {
}
type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
}
type DevicesTable interface {

View file

@ -89,7 +89,7 @@ func mustMakeAccountAndDevice(
appServiceID = util.RandomString(16)
}
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
_, err := accDB.InsertAccount(ctx, nil, localpart, "localhost", "", appServiceID, accType)
if err != nil {
t.Fatalf("unable to create account: %v", err)
}