From b24747b305a0770fdd746655e702aa1c1c049765 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 19 Aug 2020 15:38:27 +0100 Subject: [PATCH] Transaction writer changes, move roomserver writers (#1285) * Updated TransactionWriters, moved locks in roomserver, various other tweaks * Fix redaction deadlocks * Fix lint issue * Rename SQLiteTransactionWriter to ExclusiveTransactionWriter * Fix us not sending transactions through in latest events updater --- .../sqlite3/appservice_events_table.go | 2 +- .../storage/sqlite3/txn_id_counter_table.go | 2 +- .../sqlite3/current_room_state_table.go | 2 +- .../storage/postgres/blacklist_table.go | 20 ++--- .../storage/sqlite3/blacklist_table.go | 2 +- .../storage/sqlite3/joined_hosts_table.go | 2 +- .../storage/sqlite3/queue_edus_table.go | 2 +- .../storage/sqlite3/queue_json_table.go | 2 +- .../storage/sqlite3/queue_pdus_table.go | 2 +- .../storage/sqlite3/room_table.go | 2 +- internal/sqlutil/sql.go | 71 +----------------- internal/sqlutil/writer_dummy.go | 22 ++++++ internal/sqlutil/writer_exclusive.go | 75 +++++++++++++++++++ .../storage/sqlite3/device_keys_table.go | 2 +- .../storage/sqlite3/key_changes_table.go | 2 +- .../storage/sqlite3/one_time_keys_table.go | 2 +- .../storage/sqlite3/media_repository_table.go | 2 +- roomserver/internal/input_latest_events.go | 36 +++++---- roomserver/state/state.go | 22 ++++-- roomserver/storage/postgres/storage.go | 1 + .../storage/shared/latest_events_updater.go | 26 +++++-- .../storage/shared/membership_updater.go | 34 +++++---- roomserver/storage/shared/storage.go | 43 +++++++---- .../storage/sqlite3/event_json_table.go | 12 +-- .../storage/sqlite3/event_state_keys_table.go | 28 +++---- .../storage/sqlite3/event_types_table.go | 27 ++++--- roomserver/storage/sqlite3/events_table.go | 54 ++++++------- roomserver/storage/sqlite3/invite_table.go | 66 +++++++--------- .../storage/sqlite3/membership_table.go | 26 +++---- .../storage/sqlite3/previous_events_table.go | 18 ++--- roomserver/storage/sqlite3/published_table.go | 16 ++-- .../storage/sqlite3/redactions_table.go | 22 ++---- .../storage/sqlite3/room_aliases_table.go | 25 ++----- roomserver/storage/sqlite3/rooms_table.go | 46 +++++------- .../storage/sqlite3/state_block_table.go | 37 ++++----- .../storage/sqlite3/state_snapshot_table.go | 29 +++---- roomserver/storage/sqlite3/storage.go | 32 ++++---- .../storage/sqlite3/transactions_table.go | 20 ++--- .../storage/sqlite3/server_key_table.go | 2 +- syncapi/storage/shared/syncserver.go | 2 +- syncapi/storage/sqlite3/account_data_table.go | 2 +- .../sqlite3/backwards_extremities_table.go | 2 +- .../sqlite3/current_room_state_table.go | 2 +- syncapi/storage/sqlite3/filter_table.go | 2 +- syncapi/storage/sqlite3/invites_table.go | 2 +- .../sqlite3/output_room_events_table.go | 2 +- .../output_room_events_topology_table.go | 2 +- .../storage/sqlite3/send_to_device_table.go | 2 +- syncapi/storage/sqlite3/stream_id_table.go | 2 +- .../accounts/sqlite3/account_data_table.go | 2 +- .../accounts/sqlite3/accounts_table.go | 2 +- .../storage/accounts/sqlite3/profile_table.go | 2 +- .../accounts/sqlite3/threepid_table.go | 2 +- .../storage/devices/sqlite3/devices_table.go | 2 +- 54 files changed, 432 insertions(+), 434 deletions(-) create mode 100644 internal/sqlutil/writer_dummy.go create mode 100644 internal/sqlutil/writer_exclusive.go diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go index da31f2359..5cc07ed34 100644 --- a/appservice/storage/sqlite3/appservice_events_table.go +++ b/appservice/storage/sqlite3/appservice_events_table.go @@ -67,7 +67,7 @@ const ( type eventsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter selectEventsByApplicationServiceIDStmt *sql.Stmt countEventsByApplicationServiceIDStmt *sql.Stmt insertEventStmt *sql.Stmt diff --git a/appservice/storage/sqlite3/txn_id_counter_table.go b/appservice/storage/sqlite3/txn_id_counter_table.go index 501ab5aa7..0ae0feeea 100644 --- a/appservice/storage/sqlite3/txn_id_counter_table.go +++ b/appservice/storage/sqlite3/txn_id_counter_table.go @@ -38,7 +38,7 @@ const selectTxnIDSQL = ` type txnStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter selectTxnIDStmt *sql.Stmt } diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index 5c7e8b0a7..9d2fe6e04 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -83,7 +83,7 @@ const selectKnownUsersSQL = "" + type currentRoomStateStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt diff --git a/federationsender/storage/postgres/blacklist_table.go b/federationsender/storage/postgres/blacklist_table.go index 8de6feec3..f92c59e54 100644 --- a/federationsender/storage/postgres/blacklist_table.go +++ b/federationsender/storage/postgres/blacklist_table.go @@ -42,7 +42,6 @@ const deleteBlacklistSQL = "" + type blacklistStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertBlacklistStmt *sql.Stmt selectBlacklistStmt *sql.Stmt deleteBlacklistStmt *sql.Stmt @@ -50,8 +49,7 @@ type blacklistStatements struct { func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { s = &blacklistStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err = db.Exec(blacklistSchema) if err != nil { @@ -75,11 +73,9 @@ func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { func (s *blacklistStatements) InsertBlacklist( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) - _, err := stmt.ExecContext(ctx, serverName) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err } // selectRoomForUpdate locks the row for the room and returns the last_event_id. @@ -105,9 +101,7 @@ func (s *blacklistStatements) SelectBlacklist( func (s *blacklistStatements) DeleteBlacklist( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) - _, err := stmt.ExecContext(ctx, serverName) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err } diff --git a/federationsender/storage/sqlite3/blacklist_table.go b/federationsender/storage/sqlite3/blacklist_table.go index a14fe0c40..b23bfcba4 100644 --- a/federationsender/storage/sqlite3/blacklist_table.go +++ b/federationsender/storage/sqlite3/blacklist_table.go @@ -42,7 +42,7 @@ const deleteBlacklistSQL = "" + type blacklistStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertBlacklistStmt *sql.Stmt selectBlacklistStmt *sql.Stmt deleteBlacklistStmt *sql.Stmt diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 53736fa16..5dc18f4ec 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -65,7 +65,7 @@ const selectJoinedHostsForRoomsSQL = "" + type joinedHostsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt diff --git a/federationsender/storage/sqlite3/queue_edus_table.go b/federationsender/storage/sqlite3/queue_edus_table.go index cd11a0ea8..2abcc105d 100644 --- a/federationsender/storage/sqlite3/queue_edus_table.go +++ b/federationsender/storage/sqlite3/queue_edus_table.go @@ -64,7 +64,7 @@ const selectQueueServerNamesSQL = "" + type queueEDUsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt diff --git a/federationsender/storage/sqlite3/queue_json_table.go b/federationsender/storage/sqlite3/queue_json_table.go index 46dfd9ab1..867ffd44b 100644 --- a/federationsender/storage/sqlite3/queue_json_table.go +++ b/federationsender/storage/sqlite3/queue_json_table.go @@ -50,7 +50,7 @@ const selectJSONSQL = "" + type queueJSONStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertJSONStmt *sql.Stmt //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go index 1474bfc02..538ba3db8 100644 --- a/federationsender/storage/sqlite3/queue_pdus_table.go +++ b/federationsender/storage/sqlite3/queue_pdus_table.go @@ -71,7 +71,7 @@ const selectQueuePDUsServerNamesSQL = "" + type queuePDUsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertQueuePDUStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsByTransactionStmt *sql.Stmt diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go index 517938745..9a439fada 100644 --- a/federationsender/storage/sqlite3/room_table.go +++ b/federationsender/storage/sqlite3/room_table.go @@ -44,7 +44,7 @@ const updateRoomSQL = "" + type roomStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertRoomStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt updateRoomStmt *sql.Stmt diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 95467c636..002d77183 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -19,8 +19,6 @@ import ( "errors" "fmt" "runtime" - - "go.uber.org/atomic" ) // ErrUserExists is returned if a username already exists in the database. @@ -52,7 +50,7 @@ func EndTransaction(txn Transaction, succeeded *bool) error { func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { txn, err := db.Begin() if err != nil { - return + return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err) } succeeded := false defer func() { @@ -106,69 +104,6 @@ func SQLiteDriverName() string { return "sqlite3" } -// TransactionWriter allows queuing database writes so that you don't -// contend on database locks in, e.g. SQLite. Only one task will run -// at a time on a given TransactionWriter. -type TransactionWriter struct { - running atomic.Bool - todo chan transactionWriterTask -} - -func NewTransactionWriter() *TransactionWriter { - return &TransactionWriter{ - todo: make(chan transactionWriterTask), - } -} - -// transactionWriterTask represents a specific task. -type transactionWriterTask struct { - db *sql.DB - txn *sql.Tx - f func(txn *sql.Tx) error - wait chan error -} - -// Do queues a task to be run by a TransactionWriter. The function -// provided will be ran within a transaction as supplied by the -// txn parameter if one is supplied, and if not, will take out a -// new transaction from the database supplied in the database -// parameter. Either way, this will block until the task is done. -func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { - if w.todo == nil { - return errors.New("not initialised") - } - if !w.running.Load() { - go w.run() - } - task := transactionWriterTask{ - db: db, - txn: txn, - f: f, - wait: make(chan error, 1), - } - w.todo <- task - return <-task.wait -} - -// run processes the tasks for a given transaction writer. Only one -// of these goroutines will run at a time. A transaction will be -// opened using the database object from the task and then this will -// be passed as a parameter to the task function. -func (w *TransactionWriter) run() { - if !w.running.CAS(false, true) { - return - } - defer w.running.Store(false) - for task := range w.todo { - if task.txn != nil { - task.wait <- task.f(task.txn) - } else if task.db != nil { - task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { - return task.f(txn) - }) - } else { - panic("expected database or transaction but got neither") - } - close(task.wait) - } +type TransactionWriter interface { + Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error } diff --git a/internal/sqlutil/writer_dummy.go b/internal/sqlutil/writer_dummy.go new file mode 100644 index 000000000..e6ab81f68 --- /dev/null +++ b/internal/sqlutil/writer_dummy.go @@ -0,0 +1,22 @@ +package sqlutil + +import ( + "database/sql" +) + +type DummyTransactionWriter struct { +} + +func NewDummyTransactionWriter() TransactionWriter { + return &DummyTransactionWriter{} +} + +func (w *DummyTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { + if txn == nil { + return WithTransaction(db, func(txn *sql.Tx) error { + return f(txn) + }) + } else { + return f(txn) + } +} diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go new file mode 100644 index 000000000..2e3666aec --- /dev/null +++ b/internal/sqlutil/writer_exclusive.go @@ -0,0 +1,75 @@ +package sqlutil + +import ( + "database/sql" + "errors" + + "go.uber.org/atomic" +) + +// ExclusiveTransactionWriter allows queuing database writes so that you don't +// contend on database locks in, e.g. SQLite. Only one task will run +// at a time on a given ExclusiveTransactionWriter. +type ExclusiveTransactionWriter struct { + running atomic.Bool + todo chan transactionWriterTask +} + +func NewTransactionWriter() TransactionWriter { + return &ExclusiveTransactionWriter{ + todo: make(chan transactionWriterTask), + } +} + +// transactionWriterTask represents a specific task. +type transactionWriterTask struct { + db *sql.DB + txn *sql.Tx + f func(txn *sql.Tx) error + wait chan error +} + +// Do queues a task to be run by a TransactionWriter. The function +// provided will be ran within a transaction as supplied by the +// txn parameter if one is supplied, and if not, will take out a +// new transaction from the database supplied in the database +// parameter. Either way, this will block until the task is done. +func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { + if w.todo == nil { + return errors.New("not initialised") + } + if !w.running.Load() { + go w.run() + } + task := transactionWriterTask{ + db: db, + txn: txn, + f: f, + wait: make(chan error, 1), + } + w.todo <- task + return <-task.wait +} + +// run processes the tasks for a given transaction writer. Only one +// of these goroutines will run at a time. A transaction will be +// opened using the database object from the task and then this will +// be passed as a parameter to the task function. +func (w *ExclusiveTransactionWriter) run() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for task := range w.todo { + if task.txn != nil { + task.wait <- task.f(task.txn) + } else if task.db != nil { + task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { + return task.f(txn) + }) + } else { + panic("expected database or transaction but got neither") + } + close(task.wait) + } +} diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index a4d71fe13..c95790be7 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -63,7 +63,7 @@ const deleteAllDeviceKeysSQL = "" + type deviceKeysStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter upsertDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index 02b9d193e..f451d657b 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -52,7 +52,7 @@ const selectKeyChangesSQL = "" + type keyChangesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter upsertKeyChangeStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt } diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index 907966a7a..c71cc47d1 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -60,7 +60,7 @@ const selectKeyByAlgorithmSQL = "" + type oneTimeKeysStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter upsertKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt selectKeysCountStmt *sql.Stmt diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index f53f164d4..ff6ddf3da 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -62,7 +62,7 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user type mediaStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index 0158c8f7f..3be5218d5 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -57,7 +57,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( ) (err error) { updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) if err != nil { - return + return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) } succeeded := false defer func() { @@ -79,7 +79,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( } if err = u.doUpdateLatestEvents(); err != nil { - return err + return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } succeeded = true @@ -137,7 +137,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // don't need to do anything, as we've handled it already. hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID) if err != nil { - return err + return fmt.Errorf("u.updater.HasEventBeenSent: %w", err) } else if hasBeenSent { return nil } @@ -145,7 +145,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // Update the roomserver_previous_events table with references. This // is effectively tracking the structure of the DAG. if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil { - return err + return fmt.Errorf("u.updater.StorePreviousEvents: %w", err) } // Get the event reference for our new event. This will be used when @@ -156,7 +156,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // in the room. If it is then it isn't a latest event. alreadyReferenced, err := u.updater.IsReferenced(eventReference) if err != nil { - return err + return fmt.Errorf("u.updater.IsReferenced: %w", err) } // Work out what the latest events are. @@ -173,19 +173,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // Now that we know what the latest events are, it's time to get the // latest state. if err = u.latestState(); err != nil { - return err + return fmt.Errorf("u.latestState: %w", err) } // If we need to generate any output events then here's where we do it. // TODO: Move this! updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added) if err != nil { - return err + return fmt.Errorf("u.api.updateMemberships: %w", err) } update, err := u.makeOutputNewRoomEvent() if err != nil { - return err + return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) } updates = append(updates, *update) @@ -198,14 +198,18 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // necessary bookkeeping we'll keep the event sending synchronous for now. if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { - return err + return fmt.Errorf("u.api.WriteOutputEvents: %w", err) } if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { - return err + return fmt.Errorf("u.updater.SetLatestEvents: %w", err) } - return u.updater.MarkEventAsSent(u.stateAtEvent.EventNID) + if err = u.updater.MarkEventAsSent(u.stateAtEvent.EventNID); err != nil { + return fmt.Errorf("u.updater.MarkEventAsSent: %w", err) + } + + return nil } func (u *latestEventsUpdater) latestState() error { @@ -225,7 +229,7 @@ func (u *latestEventsUpdater) latestState() error { u.ctx, u.roomNID, latestStateAtEvents, ) if err != nil { - return err + return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) } // If we are overwriting the state then we should make sure that we @@ -244,7 +248,7 @@ func (u *latestEventsUpdater) latestState() error { u.ctx, u.oldStateNID, u.newStateNID, ) if err != nil { - return err + return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err) } // Also work out the state before the event removes and the event @@ -252,7 +256,11 @@ func (u *latestEventsUpdater) latestState() error { u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots( u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, ) - return err + if err != nil { + return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err) + } + + return nil } func calculateLatest( diff --git a/roomserver/state/state.go b/roomserver/state/state.go index d5be4a901..b9ad4a504 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -558,7 +558,11 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // 2) There weren't any prev_events for this event so the state is // empty. metrics.algorithm = "empty_state" - return metrics.stop(v.db.AddState(ctx, roomNID, nil, nil)) + stateNID, err := v.db.AddState(ctx, roomNID, nil, nil) + if err != nil { + err = fmt.Errorf("v.db.AddState: %w", err) + } + return metrics.stop(stateNID, err) } if len(prevStates) == 1 { @@ -578,22 +582,30 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( ) if err != nil { metrics.algorithm = "_load_state_blocks" - return metrics.stop(0, err) + return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err)) } stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs if len(stateBlockNIDs) < maxStateBlockNIDs { // 4) The number of state data blocks is small enough that we can just // add the state event as a block of size one to the end of the blocks. metrics.algorithm = "single_delta" - return metrics.stop(v.db.AddState( + stateNID, err := v.db.AddState( ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, - )) + ) + if err != nil { + err = fmt.Errorf("v.db.AddState: %w", err) + } + return metrics.stop(stateNID, err) } // If there are too many deltas then we need to calculate the full state // So fall through to calculateAndStoreStateAfterManyEvents } - return v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) + stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) + if err != nil { + return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) + } + return stateNID, nil } // maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 52ff479ba..0b7ed225a 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -98,6 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { } d.Database = shared.Database{ DB: db, + Writer: sqlutil.NewDummyTransactionWriter(), EventTypesTable: eventTypes, EventStateKeysTable: eventStateKeys, EventJSONTable: eventJSON, diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go index 21b168a4f..e9a0f6982 100644 --- a/roomserver/storage/shared/latest_events_updater.go +++ b/roomserver/storage/shared/latest_events_updater.go @@ -3,6 +3,7 @@ package shared import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -65,12 +66,14 @@ func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { // StorePreviousEvents implements types.RoomRecentEventsUpdater func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return err + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } } - } - return nil + return nil + }) } // IsReferenced implements types.RoomRecentEventsUpdater @@ -82,7 +85,7 @@ func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.Even if err == sql.ErrNoRows { return false, nil } - return false, err + return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) } // SetLatestEvents implements types.RoomRecentEventsUpdater @@ -94,7 +97,12 @@ func (u *LatestEventsUpdater) SetLatestEvents( for i := range latest { eventNIDs[i] = latest[i].EventNID } - return u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { + return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) + } + return nil + }) } // HasEventBeenSent implements types.RoomRecentEventsUpdater @@ -104,7 +112,9 @@ func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, e // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, u.txn, eventNID) + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) + }) } func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 5955844f9..329813bfc 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -3,6 +3,7 @@ package shared import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -41,9 +42,14 @@ func (d *Database) membershipUpdaterTxn( targetUserNID types.EventStateKeyNID, targetLocal bool, ) (*MembershipUpdater, error) { - - if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { - return nil, err + err := d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + return fmt.Errorf("d.MembershipTable.InsertMembership: %w", err) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("u.d.Writer.Do: %w", err) } membership, err := d.MembershipTable.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) @@ -75,19 +81,19 @@ func (u *MembershipUpdater) IsLeave() bool { func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) if err != nil { - return false, err + return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } inserted, err := u.d.InvitesTable.InsertInviteEvent( u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), ) if err != nil { - return false, err + return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) } if u.membership != tables.MembershipStateInvite { if err = u.d.MembershipTable.UpdateMembership( u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, ); err != nil { - return false, err + return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } return inserted, nil @@ -99,7 +105,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } // If this is a join event update, there is no invite to update @@ -108,14 +114,14 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) } } // Look up the NID of the new join event nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.EventNIDs: %w", err) } if u.membership != tables.MembershipStateJoin || isUpdate { @@ -123,7 +129,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], ); err != nil { - return nil, err + return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -134,19 +140,19 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err) } // Look up the NID of the new leave event nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.EventNIDs: %w", err) } if u.membership != tables.MembershipStateLeaveOrBan { @@ -154,7 +160,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], ); err != nil { - return nil, err + return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } return inviteEventIDs, nil diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 00179e336..45020d551 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -27,6 +27,7 @@ const redactionsArePermanent = false type Database struct { DB *sql.DB + Writer sqlutil.TransactionWriter EventsTable tables.Events EventJSONTable tables.EventJSON EventTypesTable tables.EventTypes @@ -83,20 +84,23 @@ func (d *Database) AddState( stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if len(state) > 0 { var stateBlockNID types.StateBlockNID stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) if err != nil { - return err + return fmt.Errorf("d.StateBlockTable.BulkInsertStateData: %w", err) } stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) } stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs) - return err + if err != nil { + return fmt.Errorf("d.StateSnapshotTable.InsertState: %w", err) + } + return nil }) if err != nil { - return 0, err + return 0, fmt.Errorf("d.Writer.Do: %w", err) } return } @@ -110,7 +114,9 @@ func (d *Database) EventNIDs( func (d *Database) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) + }) } func (d *Database) StateAtEventIDs( @@ -221,7 +227,9 @@ func (d *Database) GetRoomVersionForRoomNID( } func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) + }) } func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { @@ -239,15 +247,21 @@ func (d *Database) GetCreatorIDForAlias( } func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) + }) } func (d *Database) GetMembership( ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, ) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) + var requestSenderUserNID types.EventStateKeyNID + err = d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + requestSenderUserNID, err = d.assignStateKeyNID(ctx, nil, requestSenderUserID) + return err + }) if err != nil { - return + return 0, false, fmt.Errorf("d.assignStateKeyNID: %w", err) } senderMembershipEventNID, senderMembership, err := @@ -350,6 +364,7 @@ func (d *Database) GetLatestEventsForUpdate( return NewLatestEventsUpdater(ctx, d, txn, roomNID) } +// nolint:gocyclo func (d *Database) StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, @@ -365,10 +380,10 @@ func (d *Database) StoreEvent( err error ) - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if txnAndSessionID != nil { if err = d.TransactionsTable.InsertTransaction( - ctx, txn, txnAndSessionID.TransactionID, + ctx, nil, txnAndSessionID.TransactionID, txnAndSessionID.SessionID, event.Sender(), event.EventID(), ); err != nil { return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err) @@ -433,7 +448,7 @@ func (d *Database) StoreEvent( return nil }) if err != nil { - return 0, types.StateAtEvent{}, nil, "", err + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) } return roomNID, types.StateAtEvent{ @@ -449,7 +464,9 @@ func (d *Database) StoreEvent( } func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error { - return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish) + }) } func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index e8118ad76..3cd44b1dc 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -49,15 +49,13 @@ const bulkSelectEventJSONSQL = ` type eventJSONStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt *sql.Stmt } -func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventJSON, error) { +func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { s := &eventJSONStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventJSONSchema) if err != nil { @@ -72,10 +70,8 @@ func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) - return err - }) + _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + return err } func (s *eventJSONStatements) BulkSelectEventJSON( diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index c8ad052bf..345df8c62 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -64,17 +64,15 @@ const bulkSelectEventStateKeyNIDSQL = ` type eventStateKeyStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventStateKeyNIDStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyStmt *sql.Stmt } -func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventStateKeys, error) { +func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventStateKeysSchema) if err != nil { @@ -91,19 +89,15 @@ func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { - var eventStateKeyNID int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) - res, err := insertStmt.ExecContext(ctx, eventStateKey) - if err != nil { - return err - } - eventStateKeyNID, err = res.LastInsertId() - if err != nil { - return err - } - return nil - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) + res, err := insertStmt.ExecContext(ctx, eventStateKey) + if err != nil { + return 0, err + } + eventStateKeyNID, err := res.LastInsertId() + if err != nil { + return 0, err + } return types.EventStateKeyNID(eventStateKeyNID), err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 4a645789d..26e2bf843 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "strings" "github.com/matrix-org/dendrite/internal" @@ -78,17 +79,15 @@ const bulkSelectEventTypeNIDSQL = ` type eventTypeStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventTypeNIDStmt *sql.Stmt insertEventTypeNIDResultStmt *sql.Stmt selectEventTypeNIDStmt *sql.Stmt bulkSelectEventTypeNIDStmt *sql.Stmt } -func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventTypes, error) { +func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventTypesSchema) if err != nil { @@ -104,18 +103,18 @@ func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta } func (s *eventTypeStatements) InsertEventTypeNID( - ctx context.Context, tx *sql.Tx, eventType string, + ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error { - insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt) - resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt) - _, err := insertStmt.ExecContext(ctx, eventType) - if err != nil { - return err - } - return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt) + resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt) + _, err := insertStmt.ExecContext(ctx, eventType) + if err != nil { + return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + } + if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil { + return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err) + } return types.EventTypeNID(eventTypeNID), err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 0e39755cb..26ea1d415 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -99,7 +99,6 @@ const selectRoomNIDForEventNIDSQL = "" + type eventStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt @@ -115,10 +114,9 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Events, error) { +func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventsSchema) if err != nil { @@ -155,22 +153,19 @@ func (s *eventStatements) InsertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID var eventNID int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - result, err := insertStmt.ExecContext( - ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, - ) - if err != nil { - return err - } - modified, err := result.RowsAffected() - if modified == 0 && err == nil { - return sql.ErrNoRows - } - eventNID, err = result.LastInsertId() - return err - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + result, err := insertStmt.ExecContext( + ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + ) + if err != nil { + return 0, 0, err + } + modified, err := result.RowsAffected() + if modified == 0 && err == nil { + return 0, 0, sql.ErrNoRows + } + eventNID, err = result.LastInsertId() return types.EventNID(eventNID), 0, err } @@ -286,11 +281,8 @@ func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) UpdateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt) - _, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) - return err - }) + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + return err } func (s *eventStatements) SelectEventSentToOutput( @@ -302,11 +294,9 @@ func (s *eventStatements) SelectEventSentToOutput( } func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) - _, err := updateStmt.ExecContext(ctx, int64(eventNID)) - return err - }) + updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := updateStmt.ExecContext(ctx, int64(eventNID)) + return err } func (s *eventStatements) SelectEventID( @@ -334,7 +324,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) if err != nil { - return nil, err + return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err) } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") results := make([]types.StateAtEventAndReference, len(eventNIDs)) @@ -481,7 +471,7 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result) if err != nil { - return 0, err + return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) } return result, nil } diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 1305f4a8a..327be6a03 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -64,17 +64,15 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni type inviteStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertInviteEventStmt *sql.Stmt selectInviteActiveForUserInRoomStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt selectInvitesAboutToRetireStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Invites, error) { +func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { s := &inviteStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(inviteSchema) if err != nil { @@ -96,20 +94,17 @@ func (s *inviteStatements) InsertInviteEvent( inviteEventJSON []byte, ) (bool, error) { var count int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) - result, err := stmt.ExecContext( - ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, - ) - if err != nil { - return err - } - count, err = result.RowsAffected() - if err != nil { - return err - } - return nil - }) + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + result, err := stmt.ExecContext( + ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, + ) + if err != nil { + return false, err + } + count, err = result.RowsAffected() + if err != nil { + return false, err + } return count != 0, err } @@ -117,26 +112,23 @@ func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - // gather all the event IDs we will retire - stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) - rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) - if err != nil { - return err + // gather all the event IDs we will retire + stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) + rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed") + for rows.Next() { + var inviteEventID string + if err = rows.Scan(&inviteEventID); err != nil { + return } - defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed") - for rows.Next() { - var inviteEventID string - if err = rows.Scan(&inviteEventID); err != nil { - return err - } - eventIDs = append(eventIDs, inviteEventID) - } - // now retire the invites - stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) - _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) - return err - }) + eventIDs = append(eventIDs, inviteEventID) + } + // now retire the invites + stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) + _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) return } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 7b69cee32..b3ee69c00 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -77,7 +77,6 @@ const updateMembershipSQL = "" + type membershipStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt @@ -88,10 +87,9 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func NewSqliteMembershipTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Membership, error) { +func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { s := &membershipStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(membershipSchema) if err != nil { @@ -115,11 +113,9 @@ func (s *membershipStatements) InsertMembership( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) + return err } func (s *membershipStatements) SelectMembershipForUpdate( @@ -201,11 +197,9 @@ func (s *membershipStatements) UpdateMembership( senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) - _, err := stmt.ExecContext( - ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) + _, err := stmt.ExecContext( + ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, + ) + return err } diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index ff804861c..d28a42c69 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -54,15 +54,13 @@ const selectPreviousEventExistsSQL = ` type previousEventStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertPreviousEventStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt } -func NewSqlitePrevEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.PreviousEvents, error) { +func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { s := &previousEventStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(previousEventSchema) if err != nil { @@ -82,13 +80,11 @@ func (s *previousEventStatements) InsertPreviousEvent( previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) - _, err := stmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) + _, err := stmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ) + return err } // Check if the event reference exists diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index a4a47aec9..1d6ccd561 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -45,16 +44,14 @@ const selectPublishedSQL = "" + type publishedStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter upsertPublishedStmt *sql.Stmt selectAllPublishedStmt *sql.Stmt selectPublishedStmt *sql.Stmt } -func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Published, error) { +func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { s := &publishedStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(publishedSchema) if err != nil { @@ -69,12 +66,9 @@ func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab func (s *publishedStatements) UpsertRoomPublished( ctx context.Context, roomID string, published bool, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) - _, err := stmt.ExecContext(ctx, roomID, published) - return err - }) +) error { + _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published) + return err } func (s *publishedStatements) SelectPublishedFromRoomID( diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go index ad900a4ec..a2179357c 100644 --- a/roomserver/storage/sqlite3/redactions_table.go +++ b/roomserver/storage/sqlite3/redactions_table.go @@ -53,17 +53,15 @@ const markRedactionValidatedSQL = "" + type redactionStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertRedactionStmt *sql.Stmt selectRedactionInfoByRedactionEventIDStmt *sql.Stmt selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt markRedactionValidatedStmt *sql.Stmt } -func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Redactions, error) { +func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { s := &redactionStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(redactionsSchema) if err != nil { @@ -81,11 +79,9 @@ func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta func (s *redactionStatements) InsertRedaction( ctx context.Context, txn *sql.Tx, info tables.RedactionInfo, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) - _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) + _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) + return err } func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( @@ -121,9 +117,7 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted( func (s *redactionStatements) MarkRedactionValidated( ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) - _, err := stmt.ExecContext(ctx, redactionEventID, validated) - return err - }) + stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) + _, err := stmt.ExecContext(ctx, redactionEventID, validated) + return err } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index deba3ff55..a16e97aa5 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -57,7 +56,6 @@ const deleteRoomAliasSQL = ` type roomAliasesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertRoomAliasStmt *sql.Stmt selectRoomIDFromAliasStmt *sql.Stmt selectAliasesFromRoomIDStmt *sql.Stmt @@ -65,10 +63,9 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.RoomAliases, error) { +func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { s := &roomAliasesStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(roomAliasesSchema) if err != nil { @@ -85,12 +82,9 @@ func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (t func (s *roomAliasesStatements) InsertRoomAlias( ctx context.Context, alias string, roomID string, creatorUserID string, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt) - _, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID) - return err - }) +) error { + _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) + return err } func (s *roomAliasesStatements) SelectRoomIDFromAlias( @@ -138,10 +132,7 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) DeleteRoomAlias( ctx context.Context, alias string, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt) - _, err := stmt.ExecContext(ctx, alias) - return err - }) +) error { + _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias) + return err } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 8bbec5080..6541cc0cb 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -66,7 +66,6 @@ const selectRoomVersionForRoomNIDSQL = "" + type roomStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt @@ -76,10 +75,9 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func NewSqliteRoomsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Rooms, error) { +func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { s := &roomStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(roomsSchema) if err != nil { @@ -100,20 +98,14 @@ func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (roomNID types.RoomNID, err error) { - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) - _, err = insertStmt.ExecContext(ctx, roomID, roomVersion) - if err != nil { - return fmt.Errorf("insertStmt.ExecContext: %w", err) - } - roomNID, err = s.SelectRoomNID(ctx, txn, roomID) - if err != nil { - return fmt.Errorf("s.SelectRoomNID: %w", err) - } - return nil - }) + insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) + _, err = insertStmt.ExecContext(ctx, roomID, roomVersion) if err != nil { - return types.RoomNID(0), err + return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + } + roomNID, err = s.SelectRoomNID(ctx, txn, roomID) + if err != nil { + return 0, fmt.Errorf("s.SelectRoomNID: %w", err) } return } @@ -170,17 +162,15 @@ func (s *roomStatements) UpdateLatestEventNIDs( lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) - _, err := stmt.ExecContext( - ctx, - eventNIDsAsArray(eventNIDs), - int64(lastEventSentNID), - int64(stateSnapshotNID), - roomNID, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) + _, err := stmt.ExecContext( + ctx, + eventNIDsAsArray(eventNIDs), + int64(lastEventSentNID), + int64(stateSnapshotNID), + roomNID, + ) + return err } func (s *roomStatements) SelectRoomVersionForRoomID( diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 3e28e450b..8033903f5 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -74,17 +74,15 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" + type stateBlockStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertStateDataStmt *sql.Stmt selectNextStateBlockNIDStmt *sql.Stmt bulkSelectStateBlockEntriesStmt *sql.Stmt bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func NewSqliteStateBlockTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateBlock, error) { +func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(stateDataSchema) if err != nil { @@ -107,25 +105,22 @@ func (s *stateBlockStatements) BulkInsertStateData( return 0, nil } var stateBlockNID types.StateBlockNID - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + if err != nil { + return 0, err + } + for _, entry := range entries { + _, err = txn.Stmt(s.insertStateDataStmt).ExecContext( + ctx, + int64(stateBlockNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), + ) if err != nil { - return err + return 0, err } - for _, entry := range entries { - _, err := txn.Stmt(s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return err - } - } - return nil - }) + } return stateBlockNID, err } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 799904ff6..392c2a671 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -50,15 +50,13 @@ const bulkSelectStateBlockNIDsSQL = "" + type stateSnapshotStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertStateStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt } -func NewSqliteStateSnapshotTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateSnapshot, error) { +func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(stateSnapshotSchema) if err != nil { @@ -78,19 +76,16 @@ func (s *stateSnapshotStatements) InsertState( if err != nil { return } - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := txn.Stmt(s.insertStateStmt) - res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) - if err != nil { - return err - } - lastRowID, err := res.LastInsertId() - if err != nil { - return err - } - stateNID = types.StateSnapshotNID(lastRowID) - return nil - }) + insertStmt := txn.Stmt(s.insertStateStmt) + res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) + if err != nil { + return 0, err + } + lastRowID, err := res.LastInsertId() + if err != nil { + return 0, err + } + stateNID = types.StateSnapshotNID(lastRowID) return } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 724316373..8e3af6b7a 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -41,6 +41,7 @@ type Database struct { invites tables.Invites membership tables.Membership db *sql.DB + writer sqlutil.TransactionWriter } // Open a sqlite database. @@ -51,7 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - writer := sqlutil.NewTransactionWriter() + d.writer = sqlutil.NewTransactionWriter() //d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA read_uncommitted = true;") @@ -61,64 +62,65 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { // which it will never obtain. d.db.SetMaxOpenConns(20) - d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db, writer) + d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) if err != nil { return nil, err } - d.eventTypes, err = NewSqliteEventTypesTable(d.db, writer) + d.eventTypes, err = NewSqliteEventTypesTable(d.db) if err != nil { return nil, err } - d.eventJSON, err = NewSqliteEventJSONTable(d.db, writer) + d.eventJSON, err = NewSqliteEventJSONTable(d.db) if err != nil { return nil, err } - d.events, err = NewSqliteEventsTable(d.db, writer) + d.events, err = NewSqliteEventsTable(d.db) if err != nil { return nil, err } - d.rooms, err = NewSqliteRoomsTable(d.db, writer) + d.rooms, err = NewSqliteRoomsTable(d.db) if err != nil { return nil, err } - d.transactions, err = NewSqliteTransactionsTable(d.db, writer) + d.transactions, err = NewSqliteTransactionsTable(d.db) if err != nil { return nil, err } - stateBlock, err := NewSqliteStateBlockTable(d.db, writer) + stateBlock, err := NewSqliteStateBlockTable(d.db) if err != nil { return nil, err } - stateSnapshot, err := NewSqliteStateSnapshotTable(d.db, writer) + stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) if err != nil { return nil, err } - d.prevEvents, err = NewSqlitePrevEventsTable(d.db, writer) + d.prevEvents, err = NewSqlitePrevEventsTable(d.db) if err != nil { return nil, err } - roomAliases, err := NewSqliteRoomAliasesTable(d.db, writer) + roomAliases, err := NewSqliteRoomAliasesTable(d.db) if err != nil { return nil, err } - d.invites, err = NewSqliteInvitesTable(d.db, writer) + d.invites, err = NewSqliteInvitesTable(d.db) if err != nil { return nil, err } - d.membership, err = NewSqliteMembershipTable(d.db, writer) + d.membership, err = NewSqliteMembershipTable(d.db) if err != nil { return nil, err } - published, err := NewSqlitePublishedTable(d.db, writer) + published, err := NewSqlitePublishedTable(d.db) if err != nil { return nil, err } - redactions, err := NewSqliteRedactionsTable(d.db, writer) + redactions, err := NewSqliteRedactionsTable(d.db) if err != nil { return nil, err } d.Database = shared.Database{ DB: d.db, + Writer: sqlutil.NewTransactionWriter(), EventsTable: d.events, EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 65c18a8a9..029122c5e 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -45,15 +45,13 @@ const selectTransactionEventIDSQL = ` type transactionStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertTransactionStmt *sql.Stmt selectTransactionEventIDStmt *sql.Stmt } -func NewSqliteTransactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Transactions, error) { +func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { s := &transactionStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(transactionsSchema) if err != nil { @@ -72,14 +70,12 @@ func (s *transactionStatements) InsertTransaction( sessionID int64, userID string, eventID string, -) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) - _, err := stmt.ExecContext( - ctx, transactionID, sessionID, userID, eventID, - ) - return err - }) +) error { + stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) + _, err := stmt.ExecContext( + ctx, transactionID, sessionID, userID, eventID, + ) + return err } func (s *transactionStatements) SelectTransactionEventID( diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index 423292a54..b829eae74 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -63,7 +63,7 @@ const upsertServerKeysSQL = "" + type serverKeyStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter bulkSelectServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index dd5b838ce..fdbf6758d 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -45,7 +45,7 @@ type Database struct { BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice Filter tables.Filter - SendToDeviceWriter *sqlutil.TransactionWriter + SendToDeviceWriter sqlutil.TransactionWriter EDUCache *cache.EDUCache } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 609cef141..248ec9267 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -51,7 +51,7 @@ const selectMaxAccountDataIDSQL = "" + type accountDataStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 1aeb041f4..d96f2fe57 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -49,7 +49,7 @@ const deleteBackwardExtremitySQL = "" + type backwardExtremitiesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 6edc99aa0..77a21543f 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -85,7 +85,7 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 3e8a46551..338b0b500 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -52,7 +52,7 @@ const insertFilterSQL = "" + type filterStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter selectFilterStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt insertFilterStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 19e7a7c68..0bbd79f77 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -59,7 +59,7 @@ const selectMaxInviteIDSQL = "" + type inviteEventsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 12b4dbabe..0d1546507 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -105,7 +105,7 @@ const selectStateInRangeSQL = "" + type outputRoomEventsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 2e71e8f33..5c4ab005f 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -67,7 +67,7 @@ const selectMaxPositionInTopologySQL = "" + type outputRoomEventsTopologyStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 88b319fb3..53786589c 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -73,7 +73,7 @@ const deleteSendToDeviceMessagesSQL = ` type sendToDeviceStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertSendToDeviceMessageStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index cf3eed5ba..1971e7f3b 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -28,7 +28,7 @@ const selectStreamIDStmt = "" + type streamIDStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter increaseStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt } diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index cb54412ab..9b40e6579 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -51,7 +51,7 @@ const selectAccountDataByTypeSQL = "" + type accountDataStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertAccountDataStmt *sql.Stmt selectAccountDataStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 27c3d845a..586bcab91 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" + type accountsStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertAccountStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index d4c404ca3..cd35d2982 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" + type profilesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 0104e8346..3000d7c43 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -54,7 +54,7 @@ const deleteThreePIDSQL = "" + type threepidStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter selectLocalpartForThreePIDStmt *sql.Stmt selectThreePIDsForLocalpartStmt *sql.Stmt insertThreePIDStmt *sql.Stmt diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index 9b535aab9..962e63b03 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" + type devicesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter + writer sqlutil.TransactionWriter insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt