Refactor TransactionWriter in user API

This commit is contained in:
Neil Alexander 2020-08-20 17:41:29 +01:00
parent 66d0134e2a
commit 2a9bb642ef
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
7 changed files with 20 additions and 18 deletions

View file

@ -57,9 +57,9 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(accountDataSchema)
if err != nil {
return

View file

@ -67,9 +67,9 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(accountsSchema)
if err != nil {
return

View file

@ -61,9 +61,9 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(profilesSchema)
if err != nil {
return

View file

@ -64,16 +64,16 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
if err = d.accounts.prepare(db, d.writer, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
if err = d.profiles.prepare(db, d.writer); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
if err = d.accountDatas.prepare(db, d.writer); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
if err = d.threepids.prepare(db, d.writer); err != nil {
return nil, err
}
return d, nil

View file

@ -61,9 +61,9 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(threepidSchema)
if err != nil {
return

View file

@ -91,9 +91,9 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(devicesSchema)
if err != nil {
return

View file

@ -34,6 +34,7 @@ var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
writer sqlutil.TransactionWriter
devices devicesStatements
}
@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
writer := sqlutil.NewTransactionWriter()
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
return &Database{db, writer, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
@ -88,7 +90,7 @@ func (d *Database) CreateDevice(
displayName *string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
@ -108,7 +110,7 @@ func (d *Database) CreateDevice(
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err