Account tables
This commit is contained in:
parent
d1c61f5f95
commit
a0cc4c806c
|
@ -86,7 +86,7 @@ func Password(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the local part.
|
// Get the local part.
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -95,6 +95,7 @@ func Password(
|
||||||
// Ask the user API to perform the password change.
|
// Ask the user API to perform the password change.
|
||||||
passwordReq := &api.PerformPasswordUpdateRequest{
|
passwordReq := &api.PerformPasswordUpdateRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: domain,
|
||||||
Password: r.NewPassword,
|
Password: r.NewPassword,
|
||||||
}
|
}
|
||||||
passwordRes := &api.PerformPasswordUpdateResponse{}
|
passwordRes := &api.PerformPasswordUpdateResponse{}
|
||||||
|
|
|
@ -588,12 +588,15 @@ func Register(
|
||||||
}
|
}
|
||||||
// Auto generate a numeric username if r.Username is empty
|
// Auto generate a numeric username if r.Username is empty
|
||||||
if r.Username == "" {
|
if r.Username == "" {
|
||||||
res := &userapi.QueryNumericLocalpartResponse{}
|
nreq := &userapi.QueryNumericLocalpartRequest{
|
||||||
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil {
|
ServerName: cfg.Matrix.ServerName, // TODO: might not be right
|
||||||
|
}
|
||||||
|
nres := &userapi.QueryNumericLocalpartResponse{}
|
||||||
|
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
r.Username = strconv.FormatInt(res.ID, 10)
|
r.Username = strconv.FormatInt(nres.ID, 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is this an appservice registration? It will be if the access
|
// Is this an appservice registration? It will be if the access
|
||||||
|
|
|
@ -78,7 +78,7 @@ type ClientUserAPI interface {
|
||||||
QueryAcccessTokenAPI
|
QueryAcccessTokenAPI
|
||||||
LoginTokenInternalAPI
|
LoginTokenInternalAPI
|
||||||
UserLoginAPI
|
UserLoginAPI
|
||||||
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
|
@ -336,6 +336,7 @@ type PerformAccountCreationResponse struct {
|
||||||
// PerformAccountCreationRequest is the request for PerformAccountCreation
|
// PerformAccountCreationRequest is the request for PerformAccountCreation
|
||||||
type PerformPasswordUpdateRequest struct {
|
type PerformPasswordUpdateRequest struct {
|
||||||
Localpart string // Required: The localpart for this account.
|
Localpart string // Required: The localpart for this account.
|
||||||
|
ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
|
||||||
Password string // Required: The new password to set.
|
Password string // Required: The new password to set.
|
||||||
LogoutDevices bool // Optional: Whether to log out all user devices.
|
LogoutDevices bool // Optional: Whether to log out all user devices.
|
||||||
}
|
}
|
||||||
|
@ -601,12 +602,17 @@ type PerformSetAvatarURLResponse struct {
|
||||||
Changed bool `json:"changed"`
|
Changed bool `json:"changed"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryNumericLocalpartRequest struct {
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
type QueryNumericLocalpartResponse struct {
|
type QueryNumericLocalpartResponse struct {
|
||||||
ID int64
|
ID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryAccountAvailabilityRequest struct {
|
type QueryAccountAvailabilityRequest struct {
|
||||||
Localpart string
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryAccountAvailabilityResponse struct {
|
type QueryAccountAvailabilityResponse struct {
|
||||||
|
@ -614,7 +620,9 @@ type QueryAccountAvailabilityResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryAccountByPasswordRequest struct {
|
type QueryAccountByPasswordRequest struct {
|
||||||
Localpart, PlaintextPassword string
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
PlaintextPassword string
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryAccountByPasswordResponse struct {
|
type QueryAccountByPasswordResponse struct {
|
||||||
|
|
|
@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error {
|
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
|
||||||
err := t.Impl.QueryNumericLocalpart(ctx, res)
|
err := t.Impl.QueryNumericLocalpart(ctx, req, res)
|
||||||
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
|
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -227,7 +227,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
||||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if req.LogoutDevices {
|
if req.LogoutDevices {
|
||||||
|
@ -527,7 +527,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
||||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
|
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -568,7 +568,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
||||||
|
|
||||||
if localpart != "" { // AS is masquerading as another user
|
if localpart != "" { // AS is masquerading as another user
|
||||||
// Verify that the user is registered
|
// Verify that the user is registered
|
||||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
|
account, err := a.DB.GetAccountByLocalpart(ctx, localpart, a.Cfg.Matrix.ServerName) // TODO: which server name here?
|
||||||
// Verify that the account exists and either appServiceID matches or
|
// Verify that the account exists and either appServiceID matches or
|
||||||
// it belongs to the appservice user namespaces
|
// it belongs to the appservice user namespaces
|
||||||
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
||||||
|
@ -620,7 +620,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err := a.DB.DeactivateAccount(ctx, req.Localpart)
|
err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
|
||||||
res.AccountDeactivated = err == nil
|
res.AccountDeactivated = err == nil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -883,8 +883,8 @@ func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetA
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
|
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
|
||||||
id, err := a.DB.GetNewNumericLocalpart(ctx)
|
id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -894,12 +894,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
|
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
|
||||||
var err error
|
var err error
|
||||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
|
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
|
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
|
||||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
|
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
|
||||||
switch err {
|
switch err {
|
||||||
case sql.ErrNoRows: // user does not exist
|
case sql.ErrNoRows: // user does not exist
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
|
||||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
|
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
|
||||||
}
|
}
|
||||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
|
||||||
res.Data = nil
|
res.Data = nil
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
|
||||||
|
|
||||||
func (h *httpUserInternalAPI) QueryNumericLocalpart(
|
func (h *httpUserInternalAPI) QueryNumericLocalpart(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
request *api.QueryNumericLocalpartRequest,
|
||||||
response *api.QueryNumericLocalpartResponse,
|
response *api.QueryNumericLocalpartResponse,
|
||||||
) error {
|
) error {
|
||||||
return httputil.CallInternalRPCAPI(
|
return httputil.CallInternalRPCAPI(
|
||||||
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
|
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
|
||||||
h.httpClient, ctx, &struct{}{}, response,
|
h.httpClient, ctx, request, response,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,12 +15,9 @@
|
||||||
package inthttp
|
package inthttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
|
@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
|
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Look at the shape of this
|
internalAPIMux.Handle(
|
||||||
internalAPIMux.Handle(QueryNumericLocalpartPath,
|
QueryNumericLocalpartPath,
|
||||||
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
|
||||||
response := api.QueryNumericLocalpartResponse{}
|
|
||||||
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
|
|
||||||
return util.ErrorResponse(err)
|
|
||||||
}
|
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
internalAPIMux.Handle(
|
internalAPIMux.Handle(
|
||||||
|
|
|
@ -40,12 +40,12 @@ type Account interface {
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
|
||||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountData interface {
|
type AccountData interface {
|
||||||
|
|
|
@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||||
|
|
||||||
const deactivateAccountSQL = "" +
|
const deactivateAccountSQL = "" +
|
||||||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectAccountByLocalpartSQL = "" +
|
const selectAccountByLocalpartSQL = "" +
|
||||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectPasswordHashSQL = "" +
|
const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
|
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $2"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
@ -118,16 +118,18 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) InsertAccount(
|
func (s *accountsStatements) InsertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
hash, appserviceID string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if accountType != api.AccountTypeAppService {
|
if accountType != api.AccountTypeAppService {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||||
} else {
|
} else {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -143,34 +145,35 @@ func (s *accountsStatements) InsertAccount(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) UpdatePassword(
|
func (s *accountsStatements) UpdatePassword(
|
||||||
ctx context.Context, localpart, passwordHash string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
passwordHash string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) DeactivateAccount(
|
func (s *accountsStatements) DeactivateAccount(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectPasswordHash(
|
func (s *accountsStatements) SelectPasswordHash(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (hash string, err error) {
|
) (hash string, err error) {
|
||||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var appserviceIDPtr sql.NullString
|
var appserviceIDPtr sql.NullString
|
||||||
var acc api.Account
|
var acc api.Account
|
||||||
|
|
||||||
stmt := s.selectAccountByLocalpartStmt
|
stmt := s.selectAccountByLocalpartStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||||
|
@ -188,12 +191,12 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
stmt := s.selectNewNumericLocalpartStmt
|
stmt := s.selectNewNumericLocalpartStmt
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||||
return id + 1, err
|
return id + 1, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,9 +68,10 @@ const (
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
func (d *Database) GetAccountByPassword(
|
func (d *Database) GetAccountByPassword(
|
||||||
ctx context.Context, localpart, plaintextPassword string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
plaintextPassword string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
|
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -80,7 +81,7 @@ func (d *Database) GetAccountByPassword(
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
|
@ -117,14 +118,15 @@ func (d *Database) SetDisplayName(
|
||||||
|
|
||||||
// SetPassword sets the account password to the given hash.
|
// SetPassword sets the account password to the given hash.
|
||||||
func (d *Database) SetPassword(
|
func (d *Database) SetPassword(
|
||||||
ctx context.Context, localpart, plaintextPassword string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
plaintextPassword string,
|
||||||
) error {
|
) error {
|
||||||
hash, err := d.hashPassword(plaintextPassword)
|
hash, err := d.hashPassword(plaintextPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||||
return d.Accounts.UpdatePassword(ctx, localpart, hash)
|
return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,7 +141,7 @@ func (d *Database) CreateAccount(
|
||||||
// For guest accounts, we create a new numeric local part
|
// For guest accounts, we create a new numeric local part
|
||||||
if accountType == api.AccountTypeGuest {
|
if accountType == api.AccountTypeGuest {
|
||||||
var numLocalpart int64
|
var numLocalpart int64
|
||||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
|
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -170,13 +172,13 @@ func (d *Database) createAccount(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||||
prbs, err := json.Marshal(pushRuleSets)
|
prbs, err := json.Marshal(pushRuleSets)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -262,9 +264,9 @@ func (d *Database) GetAccountDataByType(
|
||||||
|
|
||||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||||
func (d *Database) GetNewNumericLocalpart(
|
func (d *Database) GetNewNumericLocalpart(
|
||||||
ctx context.Context,
|
ctx context.Context, serverName gomatrixserverlib.ServerName,
|
||||||
) (int64, error) {
|
) (int64, error) {
|
||||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
|
return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||||
|
@ -335,8 +337,8 @@ func (d *Database) GetThreePIDsForLocalpart(
|
||||||
// CheckAccountAvailability checks if the username/localpart is already present
|
// CheckAccountAvailability checks if the username/localpart is already present
|
||||||
// in the database.
|
// in the database.
|
||||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -346,12 +348,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
|
||||||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||||
// This function assumes the request is authenticated or the account data is used only internally.
|
// This function assumes the request is authenticated or the account data is used only internally.
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
// try to get the account with lowercase localpart (majority)
|
// try to get the account with lowercase localpart (majority)
|
||||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart))
|
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request
|
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
|
||||||
}
|
}
|
||||||
return acc, err
|
return acc, err
|
||||||
}
|
}
|
||||||
|
@ -364,9 +366,9 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
||||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
|
||||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||||
return d.Accounts.DeactivateAccount(ctx, localpart)
|
return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -52,22 +52,22 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||||
|
|
||||||
const deactivateAccountSQL = "" +
|
const deactivateAccountSQL = "" +
|
||||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectAccountByLocalpartSQL = "" +
|
const selectAccountByLocalpartSQL = "" +
|
||||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectPasswordHashSQL = "" +
|
const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
|
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -120,16 +120,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) InsertAccount(
|
func (s *accountsStatements) InsertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
hash, appserviceID string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := s.insertAccountStmt
|
stmt := s.insertAccountStmt
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if accountType != api.AccountTypeAppService {
|
if accountType != api.AccountTypeAppService {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||||
} else {
|
} else {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -145,34 +146,35 @@ func (s *accountsStatements) InsertAccount(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) UpdatePassword(
|
func (s *accountsStatements) UpdatePassword(
|
||||||
ctx context.Context, localpart, passwordHash string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
passwordHash string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) DeactivateAccount(
|
func (s *accountsStatements) DeactivateAccount(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectPasswordHash(
|
func (s *accountsStatements) SelectPasswordHash(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (hash string, err error) {
|
) (hash string, err error) {
|
||||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var appserviceIDPtr sql.NullString
|
var appserviceIDPtr sql.NullString
|
||||||
var acc api.Account
|
var acc api.Account
|
||||||
|
|
||||||
stmt := s.selectAccountByLocalpartStmt
|
stmt := s.selectAccountByLocalpartStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||||
|
@ -190,13 +192,13 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
stmt := s.selectNewNumericLocalpartStmt
|
stmt := s.selectNewNumericLocalpartStmt
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return 1, nil
|
return 1, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,47 +88,47 @@ func Test_Accounts(t *testing.T) {
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
// verify the newly create account is the same as returned by CreateAccount
|
// verify the newly create account is the same as returned by CreateAccount
|
||||||
var accGet *api.Account
|
var accGet *api.Account
|
||||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
|
||||||
assert.NoError(t, err, "failed to get account by password")
|
assert.NoError(t, err, "failed to get account by password")
|
||||||
assert.Equal(t, accAlice, accGet)
|
assert.Equal(t, accAlice, accGet)
|
||||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "failed to get account by localpart")
|
assert.NoError(t, err, "failed to get account by localpart")
|
||||||
assert.Equal(t, accAlice, accGet)
|
assert.Equal(t, accAlice, accGet)
|
||||||
|
|
||||||
// check account availability
|
// check account availability
|
||||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
|
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "failed to checkout account availability")
|
assert.NoError(t, err, "failed to checkout account availability")
|
||||||
assert.Equal(t, false, available)
|
assert.Equal(t, false, available)
|
||||||
|
|
||||||
available, err = db.CheckAccountAvailability(ctx, "unusedname")
|
available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
|
||||||
assert.NoError(t, err, "failed to checkout account availability")
|
assert.NoError(t, err, "failed to checkout account availability")
|
||||||
assert.Equal(t, true, available)
|
assert.Equal(t, true, available)
|
||||||
|
|
||||||
// get guest account numeric aliceLocalpart
|
// get guest account numeric aliceLocalpart
|
||||||
first, err := db.GetNewNumericLocalpart(ctx)
|
first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||||
// Create a new account to verify the numeric localpart is updated
|
// Create a new account to verify the numeric localpart is updated
|
||||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
|
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
second, err := db.GetNewNumericLocalpart(ctx)
|
second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Greater(t, second, first)
|
assert.Greater(t, second, first)
|
||||||
|
|
||||||
// update password for alice
|
// update password for alice
|
||||||
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||||
assert.NoError(t, err, "failed to update password")
|
assert.NoError(t, err, "failed to update password")
|
||||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||||
assert.NoError(t, err, "failed to get account by new password")
|
assert.NoError(t, err, "failed to get account by new password")
|
||||||
assert.Equal(t, accAlice, accGet)
|
assert.Equal(t, accAlice, accGet)
|
||||||
|
|
||||||
// deactivate account
|
// deactivate account
|
||||||
err = db.DeactivateAccount(ctx, aliceLocalpart)
|
err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "failed to deactivate account")
|
assert.NoError(t, err, "failed to deactivate account")
|
||||||
// This should fail now, as the account is deactivated
|
// This should fail now, as the account is deactivated
|
||||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||||
assert.Error(t, err, "expected an error, got none")
|
assert.Error(t, err, "expected an error, got none")
|
||||||
|
|
||||||
_, err = db.GetAccountByLocalpart(ctx, "unusename")
|
_, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
|
||||||
assert.Error(t, err, "expected an error for non existent localpart")
|
assert.Error(t, err, "expected an error for non existent localpart")
|
||||||
|
|
||||||
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
|
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
|
||||||
|
|
|
@ -34,12 +34,12 @@ type AccountDataTable interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountsTable interface {
|
type AccountsTable interface {
|
||||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
|
||||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DevicesTable interface {
|
type DevicesTable interface {
|
||||||
|
|
|
@ -89,7 +89,7 @@ func mustMakeAccountAndDevice(
|
||||||
appServiceID = util.RandomString(16)
|
appServiceID = util.RandomString(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
|
_, err := accDB.InsertAccount(ctx, nil, localpart, "localhost", "", appServiceID, accType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create account: %v", err)
|
t.Fatalf("unable to create account: %v", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue