Account tables
This commit is contained in:
parent
d1c61f5f95
commit
a0cc4c806c
|
@ -86,7 +86,7 @@ func Password(
|
|||
}
|
||||
|
||||
// Get the local part.
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
|
@ -95,6 +95,7 @@ func Password(
|
|||
// Ask the user API to perform the password change.
|
||||
passwordReq := &api.PerformPasswordUpdateRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
Password: r.NewPassword,
|
||||
}
|
||||
passwordRes := &api.PerformPasswordUpdateResponse{}
|
||||
|
|
|
@ -588,12 +588,15 @@ func Register(
|
|||
}
|
||||
// Auto generate a numeric username if r.Username is empty
|
||||
if r.Username == "" {
|
||||
res := &userapi.QueryNumericLocalpartResponse{}
|
||||
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil {
|
||||
nreq := &userapi.QueryNumericLocalpartRequest{
|
||||
ServerName: cfg.Matrix.ServerName, // TODO: might not be right
|
||||
}
|
||||
nres := &userapi.QueryNumericLocalpartResponse{}
|
||||
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
r.Username = strconv.FormatInt(res.ID, 10)
|
||||
r.Username = strconv.FormatInt(nres.ID, 10)
|
||||
}
|
||||
|
||||
// Is this an appservice registration? It will be if the access
|
||||
|
|
|
@ -78,7 +78,7 @@ type ClientUserAPI interface {
|
|||
QueryAcccessTokenAPI
|
||||
LoginTokenInternalAPI
|
||||
UserLoginAPI
|
||||
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
||||
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
|
@ -336,6 +336,7 @@ type PerformAccountCreationResponse struct {
|
|||
// PerformAccountCreationRequest is the request for PerformAccountCreation
|
||||
type PerformPasswordUpdateRequest struct {
|
||||
Localpart string // Required: The localpart for this account.
|
||||
ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
|
||||
Password string // Required: The new password to set.
|
||||
LogoutDevices bool // Optional: Whether to log out all user devices.
|
||||
}
|
||||
|
@ -601,12 +602,17 @@ type PerformSetAvatarURLResponse struct {
|
|||
Changed bool `json:"changed"`
|
||||
}
|
||||
|
||||
type QueryNumericLocalpartRequest struct {
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryNumericLocalpartResponse struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
type QueryAccountAvailabilityRequest struct {
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryAccountAvailabilityResponse struct {
|
||||
|
@ -614,7 +620,9 @@ type QueryAccountAvailabilityResponse struct {
|
|||
}
|
||||
|
||||
type QueryAccountByPasswordRequest struct {
|
||||
Localpart, PlaintextPassword string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
PlaintextPassword string
|
||||
}
|
||||
|
||||
type QueryAccountByPasswordResponse struct {
|
||||
|
|
|
@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
|
|||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error {
|
||||
err := t.Impl.QueryNumericLocalpart(ctx, res)
|
||||
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
|
||||
err := t.Impl.QueryNumericLocalpart(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -227,7 +227,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
if req.LogoutDevices {
|
||||
|
@ -527,7 +527,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return nil
|
||||
}
|
||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
|
||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -568,7 +568,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
|||
|
||||
if localpart != "" { // AS is masquerading as another user
|
||||
// Verify that the user is registered
|
||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
|
||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart, a.Cfg.Matrix.ServerName) // TODO: which server name here?
|
||||
// Verify that the account exists and either appServiceID matches or
|
||||
// it belongs to the appservice user namespaces
|
||||
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
||||
|
@ -620,7 +620,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
|||
return err
|
||||
}
|
||||
|
||||
err := a.DB.DeactivateAccount(ctx, req.Localpart)
|
||||
err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
|
||||
res.AccountDeactivated = err == nil
|
||||
return err
|
||||
}
|
||||
|
@ -883,8 +883,8 @@ func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetA
|
|||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
|
||||
id, err := a.DB.GetNewNumericLocalpart(ctx)
|
||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
|
||||
id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -894,12 +894,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
|
|||
|
||||
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
|
||||
var err error
|
||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
|
||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
|
||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
|
||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
|
||||
switch err {
|
||||
case sql.ErrNoRows: // user does not exist
|
||||
return nil
|
||||
|
|
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
|
||||
}
|
||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
|
||||
res.Data = nil
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
|
|
|
@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
|
|||
|
||||
func (h *httpUserInternalAPI) QueryNumericLocalpart(
|
||||
ctx context.Context,
|
||||
request *api.QueryNumericLocalpartRequest,
|
||||
response *api.QueryNumericLocalpartResponse,
|
||||
) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
|
||||
h.httpClient, ctx, &struct{}{}, response,
|
||||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -15,12 +15,9 @@
|
|||
package inthttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// nolint: gocyclo
|
||||
|
@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
|
||||
)
|
||||
|
||||
// TODO: Look at the shape of this
|
||||
internalAPIMux.Handle(QueryNumericLocalpartPath,
|
||||
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
|
||||
response := api.QueryNumericLocalpartResponse{}
|
||||
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
internalAPIMux.Handle(
|
||||
QueryNumericLocalpartPath,
|
||||
httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
|
|
|
@ -40,12 +40,12 @@ type Account interface {
|
|||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
||||
GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||
SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
|
||||
}
|
||||
|
||||
type AccountData interface {
|
||||
|
|
|
@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
|||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
||||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $2"
|
||||
|
||||
type accountsStatements struct {
|
||||
insertAccountStmt *sql.Stmt
|
||||
|
@ -118,16 +118,18 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -143,34 +145,35 @@ func (s *accountsStatements) InsertAccount(
|
|||
}
|
||||
|
||||
func (s *accountsStatements) UpdatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) DeactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
|
@ -188,12 +191,12 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
|||
}
|
||||
|
||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||
return id + 1, err
|
||||
}
|
||||
|
|
|
@ -68,9 +68,10 @@ const (
|
|||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword string,
|
||||
) (*api.Account, error) {
|
||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
|
||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -80,7 +81,7 @@ func (d *Database) GetAccountByPassword(
|
|||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||
|
@ -117,14 +118,15 @@ func (d *Database) SetDisplayName(
|
|||
|
||||
// SetPassword sets the account password to the given hash.
|
||||
func (d *Database) SetPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword string,
|
||||
) error {
|
||||
hash, err := d.hashPassword(plaintextPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.UpdatePassword(ctx, localpart, hash)
|
||||
return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -139,7 +141,7 @@ func (d *Database) CreateAccount(
|
|||
// For guest accounts, we create a new numeric local part
|
||||
if accountType == api.AccountTypeGuest {
|
||||
var numLocalpart int64
|
||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
|
||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -170,13 +172,13 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -262,9 +264,9 @@ func (d *Database) GetAccountDataByType(
|
|||
|
||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||
func (d *Database) GetNewNumericLocalpart(
|
||||
ctx context.Context,
|
||||
ctx context.Context, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
|
||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||
|
@ -335,8 +337,8 @@ func (d *Database) GetThreePIDsForLocalpart(
|
|||
// CheckAccountAvailability checks if the username/localpart is already present
|
||||
// in the database.
|
||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return true, nil
|
||||
}
|
||||
|
@ -346,12 +348,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
|
|||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||
// This function assumes the request is authenticated or the account data is used only internally.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
// try to get the account with lowercase localpart (majority)
|
||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart))
|
||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request
|
||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
|
||||
}
|
||||
return acc, err
|
||||
}
|
||||
|
@ -364,9 +366,9 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
|||
}
|
||||
|
||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.DeactivateAccount(ctx, localpart)
|
||||
return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
|||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -120,16 +120,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -145,34 +146,35 @@ func (s *accountsStatements) InsertAccount(
|
|||
}
|
||||
|
||||
func (s *accountsStatements) UpdatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) DeactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
|
@ -190,13 +192,13 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
|||
}
|
||||
|
||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||
if err == sql.ErrNoRows {
|
||||
return 1, nil
|
||||
}
|
||||
|
|
|
@ -88,47 +88,47 @@ func Test_Accounts(t *testing.T) {
|
|||
assert.NoError(t, err, "failed to create account")
|
||||
// verify the newly create account is the same as returned by CreateAccount
|
||||
var accGet *api.Account
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
|
||||
assert.NoError(t, err, "failed to get account by password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "failed to get account by localpart")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// check account availability
|
||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
|
||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, false, available)
|
||||
|
||||
available, err = db.CheckAccountAvailability(ctx, "unusedname")
|
||||
available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, true, available)
|
||||
|
||||
// get guest account numeric aliceLocalpart
|
||||
first, err := db.GetNewNumericLocalpart(ctx)
|
||||
first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||
// Create a new account to verify the numeric localpart is updated
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
second, err := db.GetNewNumericLocalpart(ctx)
|
||||
second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, second, first)
|
||||
|
||||
// update password for alice
|
||||
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
||||
err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||
assert.NoError(t, err, "failed to update password")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||
assert.NoError(t, err, "failed to get account by new password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// deactivate account
|
||||
err = db.DeactivateAccount(ctx, aliceLocalpart)
|
||||
err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "failed to deactivate account")
|
||||
// This should fail now, as the account is deactivated
|
||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||
assert.Error(t, err, "expected an error, got none")
|
||||
|
||||
_, err = db.GetAccountByLocalpart(ctx, "unusename")
|
||||
_, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
|
||||
assert.Error(t, err, "expected an error for non existent localpart")
|
||||
|
||||
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
|
||||
|
|
|
@ -34,12 +34,12 @@ type AccountDataTable interface {
|
|||
}
|
||||
|
||||
type AccountsTable interface {
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
|
||||
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||
SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
|
||||
SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
|
||||
}
|
||||
|
||||
type DevicesTable interface {
|
||||
|
|
|
@ -89,7 +89,7 @@ func mustMakeAccountAndDevice(
|
|||
appServiceID = util.RandomString(16)
|
||||
}
|
||||
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "localhost", "", appServiceID, accType)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create account: %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue