From 3c81c569a7f468650bbc8c865d40663b966397ff Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 6 Mar 2020 17:38:47 +0000 Subject: [PATCH] bugfix for database is locked on guest reg --- clientapi/auth/storage/accounts/sqlite3/storage.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go index 4b685a08b..9124640c6 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "errors" "strconv" + "sync" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/common" @@ -40,6 +41,8 @@ type Database struct { threepids threepidStatements filter filterStatements serverName gomatrixserverlib.ServerName + + createGuestAccountMu sync.Mutex } // NewDatabase creates a new accounts and profiles database @@ -77,7 +80,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err = f.prepare(db); err != nil { return nil, err } - return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil + return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil } // GetAccountByPassword returns the account associated with the given localpart and password. @@ -123,6 +126,13 @@ func (d *Database) SetDisplayName( // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + // We need to lock so we sequentially create numeric localparts. If we don't, two calls to + // this function will cause the same number to be selected and one will fail with 'database is locked' + // when the first txn upgrades to a write txn. + // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. + d.createGuestAccountMu.Lock() + defer d.createGuestAccountMu.Unlock() + var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) if err != nil {