Add policy_version to insertAccount statement
This commit is contained in:
parent
d19518fca5
commit
f8bebe5e5a
|
@ -247,7 +247,7 @@ type QuerySearchProfilesResponse struct {
|
||||||
type PerformAccountCreationRequest struct {
|
type PerformAccountCreationRequest struct {
|
||||||
AccountType AccountType // Required: whether this is a guest or user account
|
AccountType AccountType // Required: whether this is a guest or user account
|
||||||
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
Localpart string // Required: The localpart for this account. Ignored if account type is guest.
|
||||||
|
PolicyVersion string // optional: the privacy policy this account has accepted
|
||||||
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
AppServiceID string // optional: the application service ID (not user ID) creating this account, if any.
|
||||||
Password string // optional: if missing then this account will be a passwordless account
|
Password string // optional: if missing then this account will be a passwordless account
|
||||||
OnConflict Conflict
|
OnConflict Conflict
|
||||||
|
|
|
@ -67,7 +67,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
res.Account = acc
|
res.Account = acc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID)
|
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.PolicyVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||||
switch req.OnConflict {
|
switch req.OnConflict {
|
||||||
|
|
|
@ -32,7 +32,7 @@ type Database interface {
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error)
|
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string) (*api.Account, error)
|
||||||
CreateGuestAccount(ctx context.Context) (*api.Account, error)
|
CreateGuestAccount(ctx context.Context) (*api.Account, error)
|
||||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||||
|
|
|
@ -50,7 +50,7 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||||
|
@ -113,16 +113,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) insertAccount(
|
func (s *accountsStatements) insertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if appserviceID == "" {
|
if appserviceID == "" {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, "")
|
||||||
} else {
|
} else {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, policyVersion)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -166,7 +166,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return acc, err
|
return acc, err
|
||||||
|
@ -176,17 +176,17 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, sqlutil.ErrUserExists.
|
// account already exists, it will return nil, sqlutil.ErrUserExists.
|
||||||
func (d *Database) CreateAccount(
|
func (d *Database) CreateAccount(
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string,
|
||||||
) (acc *api.Account, err error) {
|
) (acc *api.Account, err error) {
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) createAccount(
|
func (d *Database) createAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var account *api.Account
|
var account *api.Account
|
||||||
var err error
|
var err error
|
||||||
|
@ -198,7 +198,7 @@ func (d *Database) createAccount(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil {
|
||||||
if sqlutil.IsUniqueConstraintViolationErr(err) {
|
if sqlutil.IsUniqueConstraintViolationErr(err) {
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
|
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, policy_version) VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||||
|
@ -113,16 +113,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) insertAccount(
|
func (s *accountsStatements) insertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
|
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID, policyVersion string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := s.insertAccountStmt
|
stmt := s.insertAccountStmt
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if appserviceID == "" {
|
if appserviceID == "" {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, "")
|
||||||
} else {
|
} else {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, policyVersion)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -192,7 +192,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
localpart := strconv.FormatInt(numLocalpart, 10)
|
localpart := strconv.FormatInt(numLocalpart, 10)
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, "", "")
|
acc, err = d.createAccount(ctx, txn, localpart, "", "", "")
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return acc, err
|
return acc, err
|
||||||
|
@ -202,7 +202,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
func (d *Database) CreateAccount(
|
func (d *Database) CreateAccount(
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string,
|
ctx context.Context, localpart, plaintextPassword, appserviceID, policyVersion string,
|
||||||
) (acc *api.Account, err error) {
|
) (acc *api.Account, err error) {
|
||||||
// Create one account at a time else we can get 'database is locked'.
|
// Create one account at a time else we can get 'database is locked'.
|
||||||
d.profilesMu.Lock()
|
d.profilesMu.Lock()
|
||||||
|
@ -212,7 +212,7 @@ func (d *Database) CreateAccount(
|
||||||
defer d.accountDatasMu.Unlock()
|
defer d.accountDatasMu.Unlock()
|
||||||
defer d.accountsMu.Unlock()
|
defer d.accountsMu.Unlock()
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
|
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, policyVersion)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -221,7 +221,7 @@ func (d *Database) CreateAccount(
|
||||||
// WARNING! This function assumes that the relevant mutexes have already
|
// WARNING! This function assumes that the relevant mutexes have already
|
||||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||||
func (d *Database) createAccount(
|
func (d *Database) createAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
|
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID, policyVersion string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var err error
|
var err error
|
||||||
var account *api.Account
|
var account *api.Account
|
||||||
|
@ -233,7 +233,7 @@ func (d *Database) createAccount(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil {
|
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, policyVersion); err != nil {
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
||||||
|
|
|
@ -73,7 +73,7 @@ func TestQueryProfile(t *testing.T) {
|
||||||
aliceAvatarURL := "mxc://example.com/alice"
|
aliceAvatarURL := "mxc://example.com/alice"
|
||||||
aliceDisplayName := "Alice"
|
aliceDisplayName := "Alice"
|
||||||
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
|
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
|
||||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "")
|
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to make account: %s", err)
|
t.Fatalf("failed to make account: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -151,7 +151,7 @@ func TestLoginToken(t *testing.T) {
|
||||||
t.Run("tokenLoginFlow", func(t *testing.T) {
|
t.Run("tokenLoginFlow", func(t *testing.T) {
|
||||||
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
|
userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{})
|
||||||
|
|
||||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "")
|
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to make account: %s", err)
|
t.Fatalf("failed to make account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue