diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index df215d5a8..a299861df 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -21,7 +21,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -143,39 +142,37 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de return counts, nil } -func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { +func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() counts := &api.OneTimeKeysCount{ DeviceID: keys.DeviceID, UserID: keys.UserID, KeyCount: make(map[string]int), } - return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { - for keyIDWithAlgo, keyJSON := range keys.KeyJSON { - algo, keyID := keys.Split(keyIDWithAlgo) - _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( - ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), - ) - if err != nil { - return err - } - } - rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) if err != nil { - return err + return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return err - } - counts.KeyCount[algorithm] = count + } + rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err } + counts.KeyCount[algorithm] = count + } - return rows.Err() - }) + return counts, rows.Err() } func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index 1c693f5b2..783303c0e 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -45,6 +45,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) } return &shared.Database{ DB: db, + Writer: sqlutil.NewDummyWriter(), OneTimeKeysTable: otk, DeviceKeysTable: dk, KeyChangesTable: kc, diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index a4c35a4bd..d4915afc1 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -27,6 +27,7 @@ import ( type Database struct { DB *sql.DB + Writer sqlutil.Writer OneTimeKeysTable tables.OneTimeKeys DeviceKeysTable tables.DeviceKeys KeyChangesTable tables.KeyChanges @@ -37,8 +38,12 @@ func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID str return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) } -func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { - return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys) +func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) + return nil + }) + return } func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { @@ -62,7 +67,7 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i } func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { - return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for _, userID := range clearUserIDs { err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) if err != nil { @@ -79,7 +84,7 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe for _, k := range keys { userIDToStreamID[k.UserID] = 0 } - return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for userID := range userIDToStreamID { streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) if err != nil { @@ -104,7 +109,7 @@ func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceI func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { var result []api.OneTimeKeys - err := sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for userID, deviceToAlgo := range userToDeviceToAlgorithm { for deviceID, algo := range deviceToAlgo { keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) @@ -126,7 +131,9 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st } func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID) + return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID) + }) } func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { @@ -141,5 +148,7 @@ func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserve // MarkDeviceListStale sets the stale bit for this user to isStale. func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) + return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) + }) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 2af337613..195429f08 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -63,7 +63,6 @@ const deleteAllDeviceKeysSQL = "" + type deviceKeysStatements struct { db *sql.DB - writer sqlutil.Writer upsertDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt @@ -71,10 +70,9 @@ type deviceKeysStatements struct { deleteAllDeviceKeysStmt *sql.Stmt } -func NewSqliteDeviceKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.DeviceKeys, error) { +func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { s := &deviceKeysStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(deviceKeysSchema) if err != nil { @@ -188,16 +186,14 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID } func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - for _, key := range keys { - now := time.Now().Unix() - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, - ) - if err != nil { - return err - } + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ) + if err != nil { + return err } - return nil - }) + } + return nil } diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index cd1784130..32721eaea 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -21,7 +21,6 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -52,15 +51,13 @@ const selectKeyChangesSQL = "" + type keyChangesStatements struct { db *sql.DB - writer sqlutil.Writer upsertKeyChangeStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt } -func NewSqliteKeyChangesTable(db *sql.DB, writer sqlutil.Writer) (tables.KeyChanges, error) { +func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { s := &keyChangesStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(keyChangesSchema) if err != nil { @@ -76,10 +73,8 @@ func NewSqliteKeyChangesTable(db *sql.DB, writer sqlutil.Writer) (tables.KeyChan } func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) - return err - }) + _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) + return err } func (s *keyChangesStatements) SelectKeyChanges( diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index d788f6768..1b6a74d6b 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -21,7 +21,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -60,7 +59,6 @@ const selectKeyByAlgorithmSQL = "" + type oneTimeKeysStatements struct { db *sql.DB - writer sqlutil.Writer upsertKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt selectKeysCountStmt *sql.Stmt @@ -68,10 +66,9 @@ type oneTimeKeysStatements struct { deleteOneTimeKeyStmt *sql.Stmt } -func NewSqliteOneTimeKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.OneTimeKeys, error) { +func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { s := &oneTimeKeysStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(oneTimeKeysSchema) if err != nil { @@ -145,39 +142,39 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de return counts, nil } -func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { +func (s *oneTimeKeysStatements) InsertOneTimeKeys( + ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys, +) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() counts := &api.OneTimeKeysCount{ DeviceID: keys.DeviceID, UserID: keys.UserID, KeyCount: make(map[string]int), } - return counts, s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - for keyIDWithAlgo, keyJSON := range keys.KeyJSON { - algo, keyID := keys.Split(keyIDWithAlgo) - _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( - ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), - ) - if err != nil { - return err - } - } - rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) if err != nil { - return err + return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return err - } - counts.KeyCount[algorithm] = count + } + rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err } + counts.KeyCount[algorithm] = count + } - return rows.Err() - }) + return counts, rows.Err() } func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( @@ -185,17 +182,17 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) - if err != nil { - if err == sql.ErrNoRows { - return nil - } - return err + err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil } - _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) - return err - }) + return nil, err + } + _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + if err != nil { + return nil, err + } if keyJSON == "" { return nil, nil } diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index 8b6f88135..fc2cc37c4 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -51,16 +50,14 @@ const selectStaleDeviceListsSQL = "" + type staleDeviceListsStatements struct { db *sql.DB - writer sqlutil.Writer upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt } -func NewSqliteStaleDeviceListsTable(db *sql.DB, writer sqlutil.Writer) (tables.StaleDeviceLists, error) { +func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { s := &staleDeviceListsStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(staleDeviceListsSchema) if err != nil { @@ -83,11 +80,8 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.upsertStaleDeviceListStmt) - _, err = stmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) - return err - }) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err } func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index 1a2a237f6..1d5382c06 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -25,25 +25,25 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) if err != nil { return nil, err } - writer := sqlutil.NewExclusiveWriter() - otk, err := NewSqliteOneTimeKeysTable(db, writer) + otk, err := NewSqliteOneTimeKeysTable(db) if err != nil { return nil, err } - dk, err := NewSqliteDeviceKeysTable(db, writer) + dk, err := NewSqliteDeviceKeysTable(db) if err != nil { return nil, err } - kc, err := NewSqliteKeyChangesTable(db, writer) + kc, err := NewSqliteKeyChangesTable(db) if err != nil { return nil, err } - sdl, err := NewSqliteStaleDeviceListsTable(db, writer) + sdl, err := NewSqliteStaleDeviceListsTable(db) if err != nil { return nil, err } return &shared.Database{ DB: db, + Writer: sqlutil.NewExclusiveWriter(), OneTimeKeysTable: otk, DeviceKeysTable: dk, KeyChangesTable: kc, diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index f97e871f6..b70c9bce6 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -26,7 +26,7 @@ import ( type OneTimeKeys interface { SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. // Returns an empty map if the key does not exist. SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)