From e0f8c58bba060044f113af98c0ac616dc29f9ab7 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 6 Mar 2020 17:04:14 +0000 Subject: [PATCH] Actually use the txn when creating accounts --- .../accounts/postgres/account_data_table.go | 4 ++-- .../storage/accounts/postgres/accounts_table.go | 4 ++-- .../storage/accounts/postgres/profile_table.go | 4 ++-- .../auth/storage/accounts/postgres/storage.go | 14 +++++++++----- clientapi/auth/storage/accounts/sqlite3/storage.go | 3 ++- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/clientapi/auth/storage/accounts/postgres/account_data_table.go b/clientapi/auth/storage/accounts/postgres/account_data_table.go index d0cfcc0cf..4573999b4 100644 --- a/clientapi/auth/storage/accounts/postgres/account_data_table.go +++ b/clientapi/auth/storage/accounts/postgres/account_data_table.go @@ -72,9 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, ) (err error) { - stmt := s.insertAccountDataStmt + stmt := txn.Stmt(s.insertAccountDataStmt) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) return } diff --git a/clientapi/auth/storage/accounts/postgres/accounts_table.go b/clientapi/auth/storage/accounts/postgres/accounts_table.go index 89acee1f5..85c1938a1 100644 --- a/clientapi/auth/storage/accounts/postgres/accounts_table.go +++ b/clientapi/auth/storage/accounts/postgres/accounts_table.go @@ -91,10 +91,10 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) insertAccount( - ctx context.Context, localpart, hash, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ) (*authtypes.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 - stmt := s.insertAccountStmt + stmt := txn.Stmt(s.insertAccountStmt) var err error if appserviceID == "" { diff --git a/clientapi/auth/storage/accounts/postgres/profile_table.go b/clientapi/auth/storage/accounts/postgres/profile_table.go index 38c76c40f..d2cbeb8e6 100644 --- a/clientapi/auth/storage/accounts/postgres/profile_table.go +++ b/clientapi/auth/storage/accounts/postgres/profile_table.go @@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { } func (s *profilesStatements) insertProfile( - ctx context.Context, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, ) (err error) { - _, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") + _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return } diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go index 7623d4e18..8115dca43 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/clientapi/auth/storage/accounts/postgres/storage.go @@ -123,7 +123,8 @@ func (d *Database) SetDisplayName( // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - numLocalpart, err := d.accounts.selectNewNumericLocalpart(ctx, txn) + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) if err != nil { return err } @@ -160,13 +161,14 @@ func (d *Database) createAccount( return nil, err } } - if err := d.profiles.insertProfile(ctx, localpart); err != nil { + if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { if common.IsUniqueConstraintViolationErr(err) { return nil, nil } return nil, err } - if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{ + + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ "global": { "content": [], "override": [], @@ -177,7 +179,7 @@ func (d *Database) createAccount( }`); err != nil { return nil, err } - return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) + return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) } // SaveMembership saves the user matching a given localpart as a member of a given @@ -284,7 +286,9 @@ func (d *Database) newMembership( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType, content string, ) error { - return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) + }) } // GetAccountData returns account data related to a given localpart diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go index 63d47f939..4b685a08b 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -123,7 +123,8 @@ func (d *Database) SetDisplayName( // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - numLocalpart, err := d.accounts.selectNewNumericLocalpart(ctx, txn) + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) if err != nil { return err }