From 7f89298615888003cd71322e00313a572a2cd771 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 20 Aug 2020 17:25:32 +0100 Subject: [PATCH] Refactor TransactionWriter in key server --- keyserver/storage/sqlite3/device_keys_table.go | 4 ++-- keyserver/storage/sqlite3/key_changes_table.go | 4 ++-- .../storage/sqlite3/one_time_keys_table.go | 4 ++-- keyserver/storage/sqlite3/stale_device_lists.go | 17 +++++++++++++---- keyserver/storage/sqlite3/storage.go | 9 +++++---- 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index c95790be7..80648e330 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -71,10 +71,10 @@ type deviceKeysStatements struct { deleteAllDeviceKeysStmt *sql.Stmt } -func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { +func NewSqliteDeviceKeysTable(db *sql.DB, writer sqlutil.TransactionWriter) (tables.DeviceKeys, error) { s := &deviceKeysStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(deviceKeysSchema) if err != nil { diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index f451d657b..92c865414 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -57,10 +57,10 @@ type keyChangesStatements struct { selectKeyChangesStmt *sql.Stmt } -func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { +func NewSqliteKeyChangesTable(db *sql.DB, writer sqlutil.TransactionWriter) (tables.KeyChanges, error) { s := &keyChangesStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(keyChangesSchema) if err != nil { diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index c71cc47d1..e720e131e 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -68,10 +68,10 @@ type oneTimeKeysStatements struct { deleteOneTimeKeyStmt *sql.Stmt } -func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { +func NewSqliteOneTimeKeysTable(db *sql.DB, writer sqlutil.TransactionWriter) (tables.OneTimeKeys, error) { s := &oneTimeKeysStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(oneTimeKeysSchema) if err != nil { diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index a989476d1..d34f2f858 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -49,13 +50,18 @@ const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" type staleDeviceListsStatements struct { + db *sql.DB + writer sqlutil.TransactionWriter upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt } -func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { - s := &staleDeviceListsStatements{} +func NewSqliteStaleDeviceListsTable(db *sql.DB, writer sqlutil.TransactionWriter) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{ + db: db, + writer: writer, + } _, err := db.Exec(staleDeviceListsSchema) if err != nil { return nil, err @@ -77,8 +83,11 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.upsertStaleDeviceListStmt) + _, err = stmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err + }) } func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index bb2935582..950bd0cfc 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -25,19 +25,20 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) if err != nil { return nil, err } - otk, err := NewSqliteOneTimeKeysTable(db) + writer := sqlutil.NewTransactionWriter() + otk, err := NewSqliteOneTimeKeysTable(db, writer) if err != nil { return nil, err } - dk, err := NewSqliteDeviceKeysTable(db) + dk, err := NewSqliteDeviceKeysTable(db, writer) if err != nil { return nil, err } - kc, err := NewSqliteKeyChangesTable(db) + kc, err := NewSqliteKeyChangesTable(db, writer) if err != nil { return nil, err } - sdl, err := NewSqliteStaleDeviceListsTable(db) + sdl, err := NewSqliteStaleDeviceListsTable(db, writer) if err != nil { return nil, err }