Fix createAccount and friends

This commit is contained in:
Neil Alexander 2020-07-14 13:42:09 +01:00
parent 88199c8ee0
commit fe4723faa6
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -39,9 +39,10 @@ type Database struct {
threepids threepidStatements
serverName gomatrixserverlib.ServerName
accountWriter *sqlutil.TransactionWriter
profileWriter *sqlutil.TransactionWriter
threepidWriter *sqlutil.TransactionWriter
accountWriter *sqlutil.TransactionWriter
profileWriter *sqlutil.TransactionWriter
accountDataWriter *sqlutil.TransactionWriter
threepidWriter *sqlutil.TransactionWriter
}
// NewDatabase creates a new accounts and profiles database
@ -85,6 +86,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
serverName: serverName,
accountWriter: sqlutil.NewTransactionWriter(),
profileWriter: sqlutil.NewTransactionWriter(),
accountDataWriter: sqlutil.NewTransactionWriter(),
threepidWriter: sqlutil.NewTransactionWriter(),
}, nil
}
@ -148,9 +150,9 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er
}
localpart := strconv.FormatInt(numLocalpart, 10)
acc, err = d.createAccount(ctx, txn, localpart, "", "")
return err
return nil
})
return acc, err
return
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
@ -160,9 +162,9 @@ func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'.
err = d.accountWriter.Do(d.db, func(txn *sql.Tx) error {
_ = d.accountWriter.Do(d.db, func(txn *sql.Tx) error {
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
return err
return nil
})
return
}
@ -179,24 +181,37 @@ func (d *Database) createAccount(
return nil, err
}
}
if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if isConstraintError(err) {
return nil, sqlutil.ErrUserExists
err = d.profileWriter.Do(d.db, func(txn *sql.Tx) error {
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if isConstraintError(err) {
return sqlutil.ErrUserExists
}
return err
}
return nil
})
if err != nil {
_ = txn.Rollback()
return nil, err
}
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)); err != nil {
err = d.accountDataWriter.Do(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`))
})
if err != nil {
_ = txn.Rollback()
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
}
@ -208,7 +223,7 @@ func (d *Database) createAccount(
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return d.accountWriter.Do(d.db, func(txn *sql.Tx) error {
return d.accountDataWriter.Do(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}