From 2a9bb642ef6e6c1bb0ff67f36293f228595531f4 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 20 Aug 2020 17:41:29 +0100 Subject: [PATCH] Refactor TransactionWriter in user API --- userapi/storage/accounts/sqlite3/account_data_table.go | 4 ++-- userapi/storage/accounts/sqlite3/accounts_table.go | 4 ++-- userapi/storage/accounts/sqlite3/profile_table.go | 4 ++-- userapi/storage/accounts/sqlite3/storage.go | 8 ++++---- userapi/storage/accounts/sqlite3/threepid_table.go | 4 ++-- userapi/storage/devices/sqlite3/devices_table.go | 4 ++-- userapi/storage/devices/sqlite3/storage.go | 10 ++++++---- 7 files changed, 20 insertions(+), 18 deletions(-) diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index 9b40e6579..69498c65c 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -57,9 +57,9 @@ type accountDataStatements struct { selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { +func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(accountDataSchema) if err != nil { return diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 586bcab91..c7281406e 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -67,9 +67,9 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { +func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(accountsSchema) if err != nil { return diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index cd35d2982..bff573667 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -61,9 +61,9 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { +func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(profilesSchema) if err != nil { return diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 6e4dfaf3d..aa3d3c137 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -64,16 +64,16 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = partitions.Prepare(db, d.writer, "account"); err != nil { return nil, err } - if err = d.accounts.prepare(db, serverName); err != nil { + if err = d.accounts.prepare(db, d.writer, serverName); err != nil { return nil, err } - if err = d.profiles.prepare(db); err != nil { + if err = d.profiles.prepare(db, d.writer); err != nil { return nil, err } - if err = d.accountDatas.prepare(db); err != nil { + if err = d.accountDatas.prepare(db, d.writer); err != nil { return nil, err } - if err = d.threepids.prepare(db); err != nil { + if err = d.threepids.prepare(db, d.writer); err != nil { return nil, err } return d, nil diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 3000d7c43..362196eaf 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -61,9 +61,9 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { +func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(threepidSchema) if err != nil { return diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index 962e63b03..c902adc76 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -91,9 +91,9 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { +func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.TransactionWriter, server gomatrixserverlib.ServerName) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(devicesSchema) if err != nil { return diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 1f2b59f30..99d4b1fb9 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -34,6 +34,7 @@ var deviceIDByteLength = 6 // Database represents a device database. type Database struct { db *sql.DB + writer sqlutil.TransactionWriter devices devicesStatements } @@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, err } + writer := sqlutil.NewTransactionWriter() d := devicesStatements{} - if err = d.prepare(db, serverName); err != nil { + if err = d.prepare(db, writer, serverName); err != nil { return nil, err } - return &Database{db, d}, nil + return &Database{db, writer, d}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -88,7 +90,7 @@ func (d *Database) CreateDevice( displayName *string, ) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { @@ -108,7 +110,7 @@ func (d *Database) CreateDevice( return } - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var err error dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) return err