diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 9f70885ad..900d1238f 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -54,6 +55,7 @@ const selectMaxStreamForUserSQL = "" + type deviceKeysStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt @@ -62,7 +64,8 @@ type deviceKeysStatements struct { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { s := &deviceKeysStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(deviceKeysSchema) if err != nil { @@ -141,14 +144,16 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn } func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { - for _, key := range keys { - now := time.Now().Unix() - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, - ) - if err != nil { - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err + } } - } - return nil + return nil + }) } diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index 32721eaea..02b9d193e 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -21,6 +21,7 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -51,13 +52,15 @@ const selectKeyChangesSQL = "" + type keyChangesStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertKeyChangeStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt } func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { s := &keyChangesStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(keyChangesSchema) if err != nil { @@ -73,8 +76,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { } func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + return err + }) } func (s *keyChangesStatements) SelectKeyChanges( diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index b35407cd4..f910479f5 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -60,6 +60,7 @@ const selectKeyByAlgorithmSQL = "" + type oneTimeKeysStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt selectKeysCountStmt *sql.Stmt @@ -69,7 +70,8 @@ type oneTimeKeysStatements struct { func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { s := &oneTimeKeysStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(oneTimeKeysSchema) if err != nil { @@ -150,7 +152,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api. UserID: keys.UserID, KeyCount: make(map[string]int), } - return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { + return counts, s.writer.Do(s.db, nil, func(txn *sql.Tx) error { for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( @@ -183,14 +185,17 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err } - return nil, err - } - _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return err + }) return map[string]json.RawMessage{ algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err