diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 9f70885ad..900d1238f 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -54,6 +55,7 @@ const selectMaxStreamForUserSQL = "" + type deviceKeysStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt @@ -62,7 +64,8 @@ type deviceKeysStatements struct { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { s := &deviceKeysStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(deviceKeysSchema) if err != nil { @@ -141,14 +144,16 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn } func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { - for _, key := range keys { - now := time.Now().Unix() - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, - ) - if err != nil { - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err + } } - } - return nil + return nil + }) } diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index 32721eaea..02b9d193e 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -21,6 +21,7 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -51,13 +52,15 @@ const selectKeyChangesSQL = "" + type keyChangesStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertKeyChangeStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt } func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { s := &keyChangesStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(keyChangesSchema) if err != nil { @@ -73,8 +76,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { } func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + return err + }) } func (s *keyChangesStatements) SelectKeyChanges( diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index b35407cd4..f910479f5 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -60,6 +60,7 @@ const selectKeyByAlgorithmSQL = "" + type oneTimeKeysStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt selectKeysCountStmt *sql.Stmt @@ -69,7 +70,8 @@ type oneTimeKeysStatements struct { func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { s := &oneTimeKeysStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(oneTimeKeysSchema) if err != nil { @@ -150,7 +152,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api. UserID: keys.UserID, KeyCount: make(map[string]int), } - return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { + return counts, s.writer.Do(s.db, nil, func(txn *sql.Tx) error { for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( @@ -183,14 +185,17 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err } - return nil, err - } - _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return err + }) return map[string]json.RawMessage{ algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index b3e45e6cc..949d9dd6c 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -2,6 +2,10 @@ package storage import ( "context" + "fmt" + "io/ioutil" + "log" + "os" "reflect" "testing" @@ -11,6 +15,21 @@ import ( var ctx = context.Background() +func MustCreateDatabase(t *testing.T) (Database, func()) { + tmpfile, err := ioutil.TempFile("", "keyserver_storage_test") + if err != nil { + log.Fatal(err) + } + t.Logf("Database %s", tmpfile.Name()) + db, err := NewDatabase(fmt.Sprintf("file://%s", tmpfile.Name()), nil) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + return db, func() { + os.Remove(tmpfile.Name()) + } +} + func MustNotError(t *testing.T, err error) { t.Helper() if err == nil { @@ -20,10 +39,8 @@ func MustNotError(t *testing.T, err error) { } func TestKeyChanges(t *testing.T) { - db, err := NewDatabase("file::memory:", nil) - if err != nil { - t.Fatalf("Failed to NewDatabase: %s", err) - } + db, clean := MustCreateDatabase(t) + defer clean() MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) @@ -40,10 +57,8 @@ func TestKeyChanges(t *testing.T) { } func TestKeyChangesNoDupes(t *testing.T) { - db, err := NewDatabase("file::memory:", nil) - if err != nil { - t.Fatalf("Failed to NewDatabase: %s", err) - } + db, clean := MustCreateDatabase(t) + defer clean() MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost")) @@ -60,10 +75,8 @@ func TestKeyChangesNoDupes(t *testing.T) { } func TestKeyChangesUpperLimit(t *testing.T) { - db, err := NewDatabase("file::memory:", nil) - if err != nil { - t.Fatalf("Failed to NewDatabase: %s", err) - } + db, clean := MustCreateDatabase(t) + defer clean() MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) @@ -82,10 +95,9 @@ func TestKeyChangesUpperLimit(t *testing.T) { // The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, // and that they are returned correctly when querying for device keys. func TestDeviceKeysStreamIDGeneration(t *testing.T) { - db, err := NewDatabase("file::memory:", nil) - if err != nil { - t.Fatalf("Failed to NewDatabase: %s", err) - } + var err error + db, clean := MustCreateDatabase(t) + defer clean() alice := "@alice:TestDeviceKeysStreamIDGeneration" bob := "@bob:TestDeviceKeysStreamIDGeneration" msgs := []api.DeviceMessage{