diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 4ae980d7c..53736fa16 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "strings" "github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/internal" @@ -63,13 +64,13 @@ const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" type joinedHostsStatements struct { - db *sql.DB - writer *sqlutil.TransactionWriter - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt - selectJoinedHostsForRoomsStmt *sql.Stmt + db *sql.DB + writer *sqlutil.TransactionWriter + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic } func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -93,9 +94,6 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { return } - if s.selectJoinedHostsForRoomsStmt, err = db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { - return - } return } @@ -168,7 +166,8 @@ func (s *joinedHostsStatements) SelectJoinedHostsForRooms( iRoomIDs[i] = roomIDs[i] } - rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, iRoomIDs...) + sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) if err != nil { return nil, err } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 9f70885ad..900d1238f 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -54,6 +55,7 @@ const selectMaxStreamForUserSQL = "" + type deviceKeysStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt @@ -62,7 +64,8 @@ type deviceKeysStatements struct { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { s := &deviceKeysStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(deviceKeysSchema) if err != nil { @@ -141,14 +144,16 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn } func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { - for _, key := range keys { - now := time.Now().Unix() - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, - ) - if err != nil { - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err + } } - } - return nil + return nil + }) } diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index 32721eaea..02b9d193e 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -21,6 +21,7 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -51,13 +52,15 @@ const selectKeyChangesSQL = "" + type keyChangesStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertKeyChangeStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt } func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { s := &keyChangesStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(keyChangesSchema) if err != nil { @@ -73,8 +76,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { } func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + return err + }) } func (s *keyChangesStatements) SelectKeyChanges( diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index b35407cd4..f910479f5 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -60,6 +60,7 @@ const selectKeyByAlgorithmSQL = "" + type oneTimeKeysStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt selectKeysCountStmt *sql.Stmt @@ -69,7 +70,8 @@ type oneTimeKeysStatements struct { func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { s := &oneTimeKeysStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(oneTimeKeysSchema) if err != nil { @@ -150,7 +152,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api. UserID: keys.UserID, KeyCount: make(map[string]int), } - return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { + return counts, s.writer.Do(s.db, nil, func(txn *sql.Tx) error { for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( @@ -183,14 +185,17 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + return err } - return nil, err - } - _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return err + }) return map[string]json.RawMessage{ algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 4e1c1ff16..4f131baa0 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -2,6 +2,10 @@ package storage import ( "context" + "fmt" + "io/ioutil" + "log" + "os" "reflect" "testing" @@ -12,6 +16,23 @@ import ( var ctx = context.Background() +func MustCreateDatabase(t *testing.T) (Database, func()) { + tmpfile, err := ioutil.TempFile("", "keyserver_storage_test") + if err != nil { + log.Fatal(err) + } + t.Logf("Database %s", tmpfile.Name()) + db, err := NewDatabase(&config.DatabaseOptions{ + ConnectionString: config.DataSource(fmt.Sprintf("file://%s", tmpfile.Name())), + }) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + return db, func() { + os.Remove(tmpfile.Name()) + } +} + func MustNotError(t *testing.T, err error) { t.Helper() if err == nil { @@ -21,12 +42,8 @@ func MustNotError(t *testing.T, err error) { } func TestKeyChanges(t *testing.T) { - db, err := NewDatabase(&config.DatabaseOptions{ - ConnectionString: config.DataSource("file::memory:"), - }) - if err != nil { - t.Fatalf("Failed to NewDatabase: %s", err) - } + db, clean := MustCreateDatabase(t) + defer clean() MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) @@ -43,12 +60,8 @@ func TestKeyChanges(t *testing.T) { } func TestKeyChangesNoDupes(t *testing.T) { - db, err := NewDatabase(&config.DatabaseOptions{ - ConnectionString: config.DataSource("file::memory:"), - }) - if err != nil { - t.Fatalf("Failed to NewDatabase: %s", err) - } + db, clean := MustCreateDatabase(t) + defer clean() MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost")) @@ -65,12 +78,8 @@ func TestKeyChangesNoDupes(t *testing.T) { } func TestKeyChangesUpperLimit(t *testing.T) { - db, err := NewDatabase(&config.DatabaseOptions{ - ConnectionString: "file::memory:", - }) - if err != nil { - t.Fatalf("Failed to NewDatabase: %s", err) - } + db, clean := MustCreateDatabase(t) + defer clean() MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) @@ -89,12 +98,9 @@ func TestKeyChangesUpperLimit(t *testing.T) { // The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, // and that they are returned correctly when querying for device keys. func TestDeviceKeysStreamIDGeneration(t *testing.T) { - db, err := NewDatabase(&config.DatabaseOptions{ - ConnectionString: "file::memory:", - }) - if err != nil { - t.Fatalf("Failed to NewDatabase: %s", err) - } + var err error + db, clean := MustCreateDatabase(t) + defer clean() alice := "@alice:TestDeviceKeysStreamIDGeneration" bob := "@bob:TestDeviceKeysStreamIDGeneration" msgs := []api.DeviceMessage{