From d9a1d20422dd666d43be73cad5d886a1e73aa7a2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 14 Jul 2020 13:52:42 +0100 Subject: [PATCH] Protect SQLite calls with mutexes (replaces #1200) --- userapi/storage/accounts/sqlite3/storage.go | 61 +++++++++++++++++++-- 1 file changed, 55 insertions(+), 6 deletions(-) diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 72b27c8bf..9fe3aa91d 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -40,7 +40,10 @@ type Database struct { threepids threepidStatements serverName gomatrixserverlib.ServerName - createAccountMu sync.Mutex + accountsMu sync.Mutex + profilesMu sync.Mutex + accountDatasMu sync.Mutex + threepidsMu sync.Mutex } // NewDatabase creates a new accounts and profiles database @@ -74,7 +77,15 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err = t.prepare(db); err != nil { return nil, err } - return &Database{db, partitions, a, p, ac, t, serverName, sync.Mutex{}}, nil + return &Database{ + db: db, + PartitionOffsetStatements: partitions, + accounts: a, + profiles: p, + accountDatas: ac, + threepids: t, + serverName: serverName, + }, nil } // GetAccountByPassword returns the account associated with the given localpart and password. @@ -82,6 +93,8 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) func (d *Database) GetAccountByPassword( ctx context.Context, localpart, plaintextPassword string, ) (*api.Account, error) { + d.accountsMu.Lock() + defer d.accountsMu.Unlock() hash, err := d.accounts.selectPasswordHash(ctx, localpart) if err != nil { return nil, err @@ -97,6 +110,8 @@ func (d *Database) GetAccountByPassword( func (d *Database) GetProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { + d.profilesMu.Lock() + defer d.profilesMu.Unlock() return d.profiles.selectProfileByLocalpart(ctx, localpart) } @@ -105,6 +120,8 @@ func (d *Database) GetProfileByLocalpart( func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, ) error { + d.profilesMu.Lock() + defer d.profilesMu.Unlock() return d.profiles.setAvatarURL(ctx, localpart, avatarURL) } @@ -113,6 +130,8 @@ func (d *Database) SetAvatarURL( func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, ) error { + d.profilesMu.Lock() + defer d.profilesMu.Unlock() return d.profiles.setDisplayName(ctx, localpart, displayName) } @@ -124,8 +143,12 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er // when the first txn upgrades to a write txn. We also need to lock the account creation else we can // race with CreateAccount // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. - d.createAccountMu.Lock() - defer d.createAccountMu.Unlock() + d.profilesMu.Lock() + d.accountDatasMu.Lock() + d.accountsMu.Lock() + defer d.profilesMu.Unlock() + defer d.accountDatasMu.Unlock() + defer d.accountsMu.Unlock() err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) @@ -146,8 +169,12 @@ func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, ) (acc *api.Account, err error) { // Create one account at a time else we can get 'database is locked'. - d.createAccountMu.Lock() - defer d.createAccountMu.Unlock() + d.profilesMu.Lock() + d.accountDatasMu.Lock() + d.accountsMu.Lock() + defer d.profilesMu.Unlock() + defer d.accountDatasMu.Unlock() + defer d.accountsMu.Unlock() err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err @@ -155,6 +182,8 @@ func (d *Database) CreateAccount( return } +// WARNING! This function assumes that the relevant mutexes have already +// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ) (*api.Account, error) { @@ -196,6 +225,8 @@ func (d *Database) createAccount( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { + d.accountDatasMu.Lock() + defer d.accountDatasMu.Unlock() return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) @@ -209,6 +240,8 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( rooms map[string]map[string]json.RawMessage, err error, ) { + d.accountDatasMu.Lock() + defer d.accountDatasMu.Unlock() return d.accountDatas.selectAccountData(ctx, localpart) } @@ -219,6 +252,8 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { + d.accountDatasMu.Lock() + defer d.accountDatasMu.Unlock() return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, ) @@ -228,6 +263,8 @@ func (d *Database) GetAccountDataByType( func (d *Database) GetNewNumericLocalpart( ctx context.Context, ) (int64, error) { + d.accountsMu.Lock() + defer d.accountsMu.Unlock() return d.accounts.selectNewNumericLocalpart(ctx, nil) } @@ -247,6 +284,8 @@ var Err3PIDInUse = errors.New("This third-party identifier is already in use") func (d *Database) SaveThreePIDAssociation( ctx context.Context, threepid, localpart, medium string, ) (err error) { + d.threepidsMu.Lock() + defer d.threepidsMu.Unlock() return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { user, err := d.threepids.selectLocalpartForThreePID( ctx, txn, threepid, medium, @@ -270,6 +309,8 @@ func (d *Database) SaveThreePIDAssociation( func (d *Database) RemoveThreePIDAssociation( ctx context.Context, threepid string, medium string, ) (err error) { + d.threepidsMu.Lock() + defer d.threepidsMu.Unlock() return d.threepids.deleteThreePID(ctx, threepid, medium) } @@ -281,6 +322,8 @@ func (d *Database) RemoveThreePIDAssociation( func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, ) (localpart string, err error) { + d.threepidsMu.Lock() + defer d.threepidsMu.Unlock() return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) } @@ -291,6 +334,8 @@ func (d *Database) GetLocalpartForThreePID( func (d *Database) GetThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { + d.threepidsMu.Lock() + defer d.threepidsMu.Unlock() return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) } @@ -298,6 +343,8 @@ func (d *Database) GetThreePIDsForLocalpart( // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { + d.accountsMu.Lock() + defer d.accountsMu.Unlock() _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) if err == sql.ErrNoRows { return true, nil @@ -310,5 +357,7 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, ) (*api.Account, error) { + d.accountsMu.Lock() + defer d.accountsMu.Unlock() return d.accounts.selectAccountByLocalpart(ctx, localpart) }