Add policy_version to insertAccount statement

This commit is contained in:
Till Faelligen 2022-02-15 14:10:50 +01:00
parent d19518fca5
commit f8bebe5e5a
8 changed files with 28 additions and 28 deletions

View file

@ -245,12 +245,12 @@ type QuerySearchProfilesResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation // PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformAccountCreationRequest struct { type PerformAccountCreationRequest struct {
AccountType AccountType // Required: whether this is a guest or user account AccountType AccountType // Required: whether this is a guest or user account
Localpart string // Required: The localpart for this account. Ignored if account type is guest. Localpart string // Required: The localpart for this account. Ignored if account type is guest.
PolicyVersion string // optional: the privacy policy this account has accepted
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
Password string // optional: if missing then this account will be a passwordless account Password string // optional: if missing then this account will be a passwordless account
OnConflict Conflict OnConflict Conflict
} }
// PerformAccountCreationResponse is the response for PerformAccountCreation // PerformAccountCreationResponse is the response for PerformAccountCreation

View file

@ -67,7 +67,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc res.Account = acc
return nil return nil
} }
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID) acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict { switch req.OnConflict {

View file

@ -32,7 +32,7 @@ type Database interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error) CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string) (*api.Account, error)
CreateGuestAccount(ctx context.Context) (*api.Account, error) CreateGuestAccount(ctx context.Context) (*api.Account, error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)

View file

@ -50,7 +50,7 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -113,16 +113,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if appserviceID == "" { if appserviceID == "" {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, "")
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, policyVersion)
} }
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -166,7 +166,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
return err return err
} }
localpart := strconv.FormatInt(numLocalpart, 10) localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "") acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
return err return err
}) })
return acc, err return acc, err
@ -176,17 +176,17 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, sqlutil.ErrUserExists. // account already exists, it will return nil, sqlutil.ErrUserExists.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion)
return err return err
}) })
return return
} }
func (d *Database) createAccount( func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string,
) (*api.Account, error) { ) (*api.Account, error) {
var account *api.Account var account *api.Account
var err error var err error
@ -198,7 +198,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil {
if sqlutil.IsUniqueConstraintViolationErr(err) { if sqlutil.IsUniqueConstraintViolationErr(err) {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }

View file

@ -48,7 +48,7 @@ CREATE TABLE IF NOT EXISTS account_accounts (
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@ -113,16 +113,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) insertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error var err error
if appserviceID == "" { if appserviceID == "" {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, "")
} else { } else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, policyVersion)
} }
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -192,7 +192,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
return err return err
} }
localpart := strconv.FormatInt(numLocalpart, 10) localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "") acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
return err return err
}) })
return acc, err return acc, err
@ -202,7 +202,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'. // Create one account at a time else we can get 'database is locked'.
d.profilesMu.Lock() d.profilesMu.Lock()
@ -212,7 +212,7 @@ func (d *Database) CreateAccount(
defer d.accountDatasMu.Unlock() defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock() defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion)
return err return err
}) })
return return
@ -221,7 +221,7 @@ func (d *Database) CreateAccount(
// WARNING! This function assumes that the relevant mutexes have already // WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount( func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string,
) (*api.Account, error) { ) (*api.Account, error) {
var err error var err error
var account *api.Account var account *api.Account
@ -233,7 +233,7 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {

View file

@ -73,7 +73,7 @@ func TestQueryProfile(t *testing.T) {
aliceAvatarURL := "mxc://example.com/alice" aliceAvatarURL := "mxc://example.com/alice"
aliceDisplayName := "Alice" aliceDisplayName := "Alice"
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "")
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }
@ -151,7 +151,7 @@ func TestLoginToken(t *testing.T) {
t.Run("tokenLoginFlow", func(t *testing.T) { t.Run("tokenLoginFlow", func(t *testing.T) {
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "") _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "")
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }