diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index 9b5192a02..6f213b7ce 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -93,15 +93,15 @@ func (s *profilesStatements) selectProfileByLocalpart( } func (s *profilesStatements) setAvatarURL( - ctx context.Context, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { - _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) + _, err = txn.Stmt(s.setAvatarURLStmt).ExecContext(ctx, avatarURL, localpart) return } func (s *profilesStatements) setDisplayName( - ctx context.Context, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { - _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) + _, err = txn.Stmt(s.setDisplayNameStmt).ExecContext(ctx, displayName, localpart) return } diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 72b27c8bf..ab276ff10 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -20,7 +20,6 @@ import ( "encoding/json" "errors" "strconv" - "sync" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -40,7 +39,9 @@ type Database struct { threepids threepidStatements serverName gomatrixserverlib.ServerName - createAccountMu sync.Mutex + accountWriter *sqlutil.TransactionWriter + profileWriter *sqlutil.TransactionWriter + threepidWriter *sqlutil.TransactionWriter } // NewDatabase creates a new accounts and profiles database @@ -74,7 +75,18 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err = t.prepare(db); err != nil { return nil, err } - return &Database{db, partitions, a, p, ac, t, serverName, sync.Mutex{}}, nil + return &Database{ + db: db, + PartitionOffsetStatements: partitions, + accounts: a, + profiles: p, + accountDatas: ac, + threepids: t, + serverName: serverName, + accountWriter: sqlutil.NewTransactionWriter(), + profileWriter: sqlutil.NewTransactionWriter(), + threepidWriter: sqlutil.NewTransactionWriter(), + }, nil } // GetAccountByPassword returns the account associated with the given localpart and password. @@ -105,7 +117,9 @@ func (d *Database) GetProfileByLocalpart( func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, ) error { - return d.profiles.setAvatarURL(ctx, localpart, avatarURL) + return d.profileWriter.Do(d.db, func(txn *sql.Tx) error { + return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) + }) } // SetDisplayName updates the display name of the profile associated with the given @@ -113,7 +127,9 @@ func (d *Database) SetAvatarURL( func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, ) error { - return d.profiles.setDisplayName(ctx, localpart, displayName) + return d.profileWriter.Do(d.db, func(txn *sql.Tx) error { + return d.profiles.setDisplayName(ctx, txn, localpart, displayName) + }) } // CreateGuestAccount makes a new guest account and creates an empty profile @@ -124,9 +140,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er // when the first txn upgrades to a write txn. We also need to lock the account creation else we can // race with CreateAccount // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. - d.createAccountMu.Lock() - defer d.createAccountMu.Unlock() - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + err = d.accountWriter.Do(d.db, func(txn *sql.Tx) error { var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) if err != nil { @@ -146,9 +160,7 @@ func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *api.Account, err error) { // Create one account at a time else we can get 'database is locked'. - d.createAccountMu.Lock() - defer d.createAccountMu.Unlock() - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + err = d.accountWriter.Do(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err }) @@ -196,7 +208,7 @@ func (d *Database) createAccount( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.accountWriter.Do(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) } @@ -247,7 +259,7 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.threepidWriter.Do(d.db, func(txn *sql.Tx) error { user, err := d.threepids.selectLocalpartForThreePID( ctx, txn, threepid, medium, ) @@ -270,7 +282,9 @@ func (d *Database) SaveThreePIDAssociation( func (d *Database) RemoveThreePIDAssociation( ctx context.Context, threepid string, medium string, ) (err error) { - return d.threepids.deleteThreePID(ctx, threepid, medium) + return d.threepidWriter.Do(d.db, func(txn *sql.Tx) error { + return d.threepids.deleteThreePID(ctx, txn, threepid, medium) + }) } // GetLocalpartForThreePID looks up the localpart associated with a given third-party diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 0200dee7f..ed3702c3b 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -124,7 +124,8 @@ func (s *threepidStatements) insertThreePID( } func (s *threepidStatements) deleteThreePID( - ctx context.Context, threepid string, medium string) (err error) { - _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) + ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium) return }