Fix createAccount and friends

This commit is contained in:
Neil Alexander 2020-07-14 13:42:09 +01:00
parent 88199c8ee0
commit fe4723faa6
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -39,9 +39,10 @@ type Database struct {
threepids threepidStatements threepids threepidStatements
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
accountWriter *sqlutil.TransactionWriter accountWriter *sqlutil.TransactionWriter
profileWriter *sqlutil.TransactionWriter profileWriter *sqlutil.TransactionWriter
threepidWriter *sqlutil.TransactionWriter accountDataWriter *sqlutil.TransactionWriter
threepidWriter *sqlutil.TransactionWriter
} }
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
@ -85,6 +86,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
serverName: serverName, serverName: serverName,
accountWriter: sqlutil.NewTransactionWriter(), accountWriter: sqlutil.NewTransactionWriter(),
profileWriter: sqlutil.NewTransactionWriter(), profileWriter: sqlutil.NewTransactionWriter(),
accountDataWriter: sqlutil.NewTransactionWriter(),
threepidWriter: sqlutil.NewTransactionWriter(), threepidWriter: sqlutil.NewTransactionWriter(),
}, nil }, nil
} }
@ -148,9 +150,9 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
} }
localpart := strconv.FormatInt(numLocalpart, 10) localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "") acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err return nil
}) })
return acc, err return
} }
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
@ -160,9 +162,9 @@ func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'. // Create one account at a time else we can get 'database is locked'.
err = d.accountWriter.Do(d.db, func(txn *sql.Tx) error { _ = d.accountWriter.Do(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err return nil
}) })
return return
} }
@ -179,24 +181,37 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if isConstraintError(err) { err = d.profileWriter.Do(d.db, func(txn *sql.Tx) error {
return nil, sqlutil.ErrUserExists if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if isConstraintError(err) {
return sqlutil.ErrUserExists
}
return err
} }
return nil
})
if err != nil {
_ = txn.Rollback()
return nil, err return nil, err
} }
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ err = d.accountDataWriter.Do(d.db, func(txn *sql.Tx) error {
"global": { return d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"content": [], "global": {
"override": [], "content": [],
"room": [], "override": [],
"sender": [], "room": [],
"underride": [] "sender": [],
} "underride": []
}`)); err != nil { }
}`))
})
if err != nil {
_ = txn.Rollback()
return nil, err return nil, err
} }
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
} }
@ -208,7 +223,7 @@ func (d *Database) createAccount(
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
return d.accountWriter.Do(d.db, func(txn *sql.Tx) error { return d.accountDataWriter.Do(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
}) })
} }