diff --git a/clientapi/auth/storage/accounts/postgres/accounts_table.go b/clientapi/auth/storage/accounts/postgres/accounts_table.go index 6b8ed3728..89acee1f5 100644 --- a/clientapi/auth/storage/accounts/postgres/accounts_table.go +++ b/clientapi/auth/storage/accounts/postgres/accounts_table.go @@ -146,8 +146,12 @@ func (s *accountsStatements) selectAccountByLocalpart( } func (s *accountsStatements) selectNewNumericLocalpart( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (id int64, err error) { - err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) + stmt := s.selectNewNumericLocalpartStmt + if txn != nil { + stmt = txn.Stmt(stmt) + } + err = stmt.QueryRowContext(ctx).Scan(&id) return } diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go index eb1f2236d..7623d4e18 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/clientapi/auth/storage/accounts/postgres/storage.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "errors" + "strconv" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/common" @@ -118,8 +119,19 @@ func (d *Database) SetDisplayName( return d.profiles.setDisplayName(ctx, localpart, displayName) } -func (d *Database) CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) { - return nil, nil +// CreateGuestAccount makes a new guest account and creates an empty profile +// for this account. +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + numLocalpart, err := d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart := strconv.FormatInt(numLocalpart, 10) + acc, err = d.createAccount(ctx, txn, localpart, "", "") + return err + }) + return acc, err } // CreateAccount makes a new account with the given login name and password, and creates an empty profile @@ -127,6 +139,16 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (*authtypes.Account, // account already exists, it will return nil, nil. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, +) (acc *authtypes.Account, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + return err + }) + return +} + +func (d *Database) createAccount( + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ) (*authtypes.Account, error) { var err error @@ -292,7 +314,7 @@ func (d *Database) GetAccountDataByType( func (d *Database) GetNewNumericLocalpart( ctx context.Context, ) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx) + return d.accounts.selectNewNumericLocalpart(ctx, nil) } func hashPassword(plaintext string) (hash string, err error) { diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go index a980071ce..3cb1f9fb2 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -119,6 +119,8 @@ func (d *Database) SetDisplayName( return d.profiles.setDisplayName(ctx, localpart, displayName) } +// CreateGuestAccount makes a new guest account and creates an empty profile +// for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { numLocalpart, err := d.accounts.selectNewNumericLocalpart(ctx, txn)