Refactor TransactionWriter in user API

This commit is contained in:
Neil Alexander 2020-08-20 17:41:29 +01:00
parent 66d0134e2a
commit 2a9bb642ef
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
7 changed files with 20 additions and 18 deletions

View file

@ -57,9 +57,9 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(accountDataSchema) _, err = db.Exec(accountDataSchema)
if err != nil { if err != nil {
return return

View file

@ -67,9 +67,9 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(accountsSchema) _, err = db.Exec(accountsSchema)
if err != nil { if err != nil {
return return

View file

@ -61,9 +61,9 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(profilesSchema) _, err = db.Exec(profilesSchema)
if err != nil { if err != nil {
return return

View file

@ -64,16 +64,16 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = partitions.Prepare(db, d.writer, "account"); err != nil { if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err return nil, err
} }
if err = d.accounts.prepare(db, serverName); err != nil { if err = d.accounts.prepare(db, d.writer, serverName); err != nil {
return nil, err return nil, err
} }
if err = d.profiles.prepare(db); err != nil { if err = d.profiles.prepare(db, d.writer); err != nil {
return nil, err return nil, err
} }
if err = d.accountDatas.prepare(db); err != nil { if err = d.accountDatas.prepare(db, d.writer); err != nil {
return nil, err return nil, err
} }
if err = d.threepids.prepare(db); err != nil { if err = d.threepids.prepare(db, d.writer); err != nil {
return nil, err return nil, err
} }
return d, nil return d, nil

View file

@ -61,9 +61,9 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt
} }
func (s *threepidStatements) prepare(db *sql.DB) (err error) { func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(threepidSchema) _, err = db.Exec(threepidSchema)
if err != nil { if err != nil {
return return

View file

@ -91,9 +91,9 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) {
s.db = db s.db = db
s.writer = sqlutil.NewTransactionWriter() s.writer = writer
_, err = db.Exec(devicesSchema) _, err = db.Exec(devicesSchema)
if err != nil { if err != nil {
return return

View file

@ -34,6 +34,7 @@ var deviceIDByteLength = 6
// Database represents a device database. // Database represents a device database.
type Database struct { type Database struct {
db *sql.DB db *sql.DB
writer sqlutil.TransactionWriter
devices devicesStatements devices devicesStatements
} }
@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil { if err != nil {
return nil, err return nil, err
} }
writer := sqlutil.NewTransactionWriter()
d := devicesStatements{} d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil { if err = d.prepare(db, writer, serverName); err != nil {
return nil, err return nil, err
} }
return &Database{db, d}, nil return &Database{db, writer, d}, nil
} }
// GetDeviceByAccessToken returns the device matching the given access token. // GetDeviceByAccessToken returns the device matching the given access token.
@ -88,7 +90,7 @@ func (d *Database) CreateDevice(
displayName *string, displayName *string,
) (dev *api.Device, returnErr error) { ) (dev *api.Device, returnErr error) {
if deviceID != nil { if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error var err error
// Revoke existing tokens for this device // Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
@ -108,7 +110,7 @@ func (d *Database) CreateDevice(
return return
} }
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err return err