From fe4723faa6afa347fb83bc2d140259904c27e83d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 14 Jul 2020 13:42:09 +0100 Subject: [PATCH] Fix createAccount and friends --- userapi/storage/accounts/sqlite3/storage.go | 55 +++++++++++++-------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index ab276ff10..8c6095fee 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -39,9 +39,10 @@ type Database struct { threepids threepidStatements serverName gomatrixserverlib.ServerName - accountWriter *sqlutil.TransactionWriter - profileWriter *sqlutil.TransactionWriter - threepidWriter *sqlutil.TransactionWriter + accountWriter *sqlutil.TransactionWriter + profileWriter *sqlutil.TransactionWriter + accountDataWriter *sqlutil.TransactionWriter + threepidWriter *sqlutil.TransactionWriter } // NewDatabase creates a new accounts and profiles database @@ -85,6 +86,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) serverName: serverName, accountWriter: sqlutil.NewTransactionWriter(), profileWriter: sqlutil.NewTransactionWriter(), + accountDataWriter: sqlutil.NewTransactionWriter(), threepidWriter: sqlutil.NewTransactionWriter(), }, nil } @@ -148,9 +150,9 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er } localpart := strconv.FormatInt(numLocalpart, 10) acc, err = d.createAccount(ctx, txn, localpart, "", "") - return err + return nil }) - return acc, err + return } // CreateAccount makes a new account with the given login name and password, and creates an empty profile @@ -160,9 +162,9 @@ func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *api.Account, err error) { // Create one account at a time else we can get 'database is locked'. - err = d.accountWriter.Do(d.db, func(txn *sql.Tx) error { + _ = d.accountWriter.Do(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) - return err + return nil }) return } @@ -179,24 +181,37 @@ func (d *Database) createAccount( return nil, err } } - if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { - if isConstraintError(err) { - return nil, sqlutil.ErrUserExists + + err = d.profileWriter.Do(d.db, func(txn *sql.Tx) error { + if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { + if isConstraintError(err) { + return sqlutil.ErrUserExists + } + return err } + return nil + }) + if err != nil { + _ = txn.Rollback() return nil, err } - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { + err = d.accountDataWriter.Do(d.db, func(txn *sql.Tx) error { + return d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ + "global": { + "content": [], + "override": [], + "room": [], + "sender": [], + "underride": [] + } + }`)) + }) + if err != nil { + _ = txn.Rollback() return nil, err } + return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) } @@ -208,7 +223,7 @@ func (d *Database) createAccount( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { - return d.accountWriter.Do(d.db, func(txn *sql.Tx) error { + return d.accountDataWriter.Do(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) }