Add policy_version to insertAccount statement
This commit is contained in:
parent
d19518fca5
commit
f8bebe5e5a
|
@ -247,7 +247,7 @@ type QuerySearchProfilesResponse struct {
|
|||
type PerformAccountCreationRequest struct {
|
||||
AccountType AccountType // Required: whether this is a guest or user account
|
||||
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
||||
|
||||
PolicyVersion string // optional: the privacy policy this account has accepted
|
||||
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
||||
Password string // optional: if missing then this account will be a passwordless account
|
||||
OnConflict Conflict
|
||||
|
|
|
@ -67,7 +67,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
res.Account = acc
|
||||
return nil
|
||||
}
|
||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID)
|
||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
switch req.OnConflict {
|
||||
|
|
|
@ -32,7 +32,7 @@ type Database interface {
|
|||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error)
|
||||
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string) (*api.Account, error)
|
||||
CreateGuestAccount(ctx context.Context) (*api.Account, error)
|
||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
|
|
|
@ -50,7 +50,7 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
|||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
@ -113,16 +113,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) insertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
if appserviceID == "" {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, "")
|
||||
} else {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, policyVersion)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -166,7 +166,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
|||
return err
|
||||
}
|
||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
||||
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
|
||||
return err
|
||||
})
|
||||
return acc, err
|
||||
|
@ -176,17 +176,17 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
|||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, sqlutil.ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string,
|
||||
) (acc *api.Account, err error) {
|
||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string,
|
||||
) (*api.Account, error) {
|
||||
var account *api.Account
|
||||
var err error
|
||||
|
@ -198,7 +198,7 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil {
|
||||
if sqlutil.IsUniqueConstraintViolationErr(err) {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
|
@ -113,16 +113,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) insertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
|
||||
var err error
|
||||
if appserviceID == "" {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, "")
|
||||
} else {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, policyVersion)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -192,7 +192,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
|||
return err
|
||||
}
|
||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
||||
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
|
||||
return err
|
||||
})
|
||||
return acc, err
|
||||
|
@ -202,7 +202,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
|||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string,
|
||||
) (acc *api.Account, err error) {
|
||||
// Create one account at a time else we can get 'database is locked'.
|
||||
d.profilesMu.Lock()
|
||||
|
@ -212,7 +212,7 @@ func (d *Database) CreateAccount(
|
|||
defer d.accountDatasMu.Unlock()
|
||||
defer d.accountsMu.Unlock()
|
||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
@ -221,7 +221,7 @@ func (d *Database) CreateAccount(
|
|||
// WARNING! This function assumes that the relevant mutexes have already
|
||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string,
|
||||
) (*api.Account, error) {
|
||||
var err error
|
||||
var account *api.Account
|
||||
|
@ -233,7 +233,7 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
||||
|
|
|
@ -73,7 +73,7 @@ func TestQueryProfile(t *testing.T) {
|
|||
aliceAvatarURL := "mxc://example.com/alice"
|
||||
aliceDisplayName := "Alice"
|
||||
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "")
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
@ -151,7 +151,7 @@ func TestLoginToken(t *testing.T) {
|
|||
t.Run("tokenLoginFlow", func(t *testing.T) {
|
||||
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
|
||||
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "")
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue