From e708cb73aa26d5e840f1398ce1350c7a4e0b78fe Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 21 Aug 2020 09:55:46 +0100 Subject: [PATCH] Move writers up a layer in sync API --- internal/sqlutil/writer_exclusive.go | 2 +- syncapi/storage/postgres/syncserver.go | 2 +- syncapi/storage/shared/syncserver.go | 31 ++++++---- syncapi/storage/sqlite3/account_data_table.go | 20 +++---- .../sqlite3/backwards_extremities_table.go | 19 ++---- .../sqlite3/current_room_state_table.go | 42 ++++++-------- syncapi/storage/sqlite3/filter_table.go | 58 +++++++++---------- syncapi/storage/sqlite3/invites_table.go | 39 +++++-------- .../sqlite3/output_room_events_table.go | 47 +++++++-------- .../output_room_events_topology_table.go | 18 +++--- .../storage/sqlite3/send_to_device_table.go | 24 +++----- syncapi/storage/sqlite3/stream_id_table.go | 17 ++---- syncapi/storage/sqlite3/syncserver.go | 20 +++---- 13 files changed, 144 insertions(+), 195 deletions(-) diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go index eff73713d..75eeba3d8 100644 --- a/internal/sqlutil/writer_exclusive.go +++ b/internal/sqlutil/writer_exclusive.go @@ -68,7 +68,7 @@ func (w *ExclusiveWriter) run() { return task.f(txn) }) } else { - panic("expected database or transaction but got neither") + task.wait <- task.f(nil) } close(task.wait) } diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 5365674b0..36e8de67f 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -80,6 +80,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e } d.Database = shared.Database{ DB: d.db, + Writer: sqlutil.NewDummyWriter(), Invites: invites, AccountData: accountData, OutputEvents: events, @@ -88,7 +89,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e BackwardExtremities: backwardExtremities, Filter: filter, SendToDevice: sendToDevice, - SendToDeviceWriter: sqlutil.NewExclusiveWriter(), EDUCache: cache.New(), } return &d, nil diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 28eadfea1..699a66472 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -37,6 +37,7 @@ import ( // For now this contains the shared functions type Database struct { DB *sql.DB + Writer sqlutil.Writer Invites tables.Invites AccountData tables.AccountData OutputEvents tables.Events @@ -45,7 +46,6 @@ type Database struct { BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice Filter tables.Filter - SendToDeviceWriter sqlutil.Writer EDUCache *cache.EDUCache } @@ -129,10 +129,7 @@ func (d *Database) GetStateEvent( func (d *Database) GetStateEventsForRoom( ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { - stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter) - return err - }) + stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter) return } @@ -171,15 +168,23 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition func (d *Database) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ) (sp types.StreamPosition, err error) { - return d.Invites.InsertInviteEvent(ctx, nil, inviteEvent) + _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + sp, err = d.Invites.InsertInviteEvent(ctx, nil, inviteEvent) + return nil + }) + return } // RetireInviteEvent removes an old invite event from the database. // Returns an error if there was a problem communicating with the database. func (d *Database) RetireInviteEvent( ctx context.Context, inviteEventID string, -) (types.StreamPosition, error) { - return d.Invites.DeleteInviteEvent(ctx, inviteEventID) +) (sp types.StreamPosition, err error) { + _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + sp, err = d.Invites.DeleteInviteEvent(ctx, inviteEventID) + return nil + }) + return } // GetAccountDataInRange returns all account data for a given user inserted or @@ -203,7 +208,7 @@ func (d *Database) GetAccountDataInRange( func (d *Database) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, ) (sp types.StreamPosition, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) return err }) @@ -233,6 +238,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. +// This function should always be called within a sqlutil.Writer for safety in SQLite. func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { return err @@ -271,7 +277,7 @@ func (d *Database) WriteEvent( addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool, ) (pduPosition types.StreamPosition, returnErr error) { - returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error pos, err := d.OutputEvents.InsertEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, @@ -300,6 +306,7 @@ func (d *Database) WriteEvent( return pduPosition, returnErr } +// This function should always be called within a sqlutil.Writer for safety in SQLite. func (d *Database) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, @@ -1110,7 +1117,7 @@ func (d *Database) StoreNewSendForDeviceMessage( } // Delegate the database write task to the SendToDeviceWriter. It'll guarantee // that we don't lock the table for writes in more than one place. - err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.AddSendToDeviceEvent( ctx, txn, userID, deviceID, string(j), ) @@ -1175,7 +1182,7 @@ func (d *Database) CleanSendToDeviceUpdates( // If we need to write to the database then we'll ask the SendToDeviceWriter to // do that for us. It'll guarantee that we don't lock the table for writes in // more than one place. - err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { // Delete any send-to-device messages marked for deletion. if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 81004455f..72c46e48d 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -51,17 +50,15 @@ const selectMaxAccountDataIDSQL = "" + type accountDataStatements struct { db *sql.DB - writer sqlutil.Writer streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt } -func NewSqliteAccountDataTable(db *sql.DB, writer sqlutil.Writer, streamID *streamIDStatements) (tables.AccountData, error) { +func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ db: db, - writer: writer, streamIDStatements: streamID, } _, err := db.Exec(accountDataSchema) @@ -84,15 +81,12 @@ func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { - return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - var err error - pos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return err - } - _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) - return err - }) + pos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) + return } func (s *accountDataStatements) SelectAccountDataInRange( diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 123739e08..116c33dc4 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) @@ -49,16 +48,14 @@ const deleteBackwardExtremitySQL = "" + type backwardExtremitiesStatements struct { db *sql.DB - writer sqlutil.Writer insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt } -func NewSqliteBackwardsExtremitiesTable(db *sql.DB, writer sqlutil.Writer) (tables.BackwardsExtremities, error) { +func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { s := &backwardExtremitiesStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(backwardExtremitiesSchema) if err != nil { @@ -79,10 +76,8 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB, writer sqlutil.Writer) (tabl func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) - return err - }) + _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) + return err } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( @@ -110,8 +105,6 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) - return err - }) + _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) + return err } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 1bbaaf234..94557aa16 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -85,7 +85,6 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { db *sql.DB - writer sqlutil.Writer streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -95,10 +94,9 @@ type currentRoomStateStatements struct { selectStateEventStmt *sql.Stmt } -func NewSqliteCurrentRoomStateTable(db *sql.DB, writer sqlutil.Writer, streamID *streamIDStatements) (tables.CurrentRoomState, error) { +func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ db: db, - writer: writer, streamIDStatements: streamID, } _, err := db.Exec(currentRoomStateSchema) @@ -200,11 +198,9 @@ func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - _, err := stmt.ExecContext(ctx, eventID) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + _, err := stmt.ExecContext(ctx, eventID) + return err } func (s *currentRoomStateStatements) UpsertRoomState( @@ -225,22 +221,20 @@ func (s *currentRoomStateStatements) UpsertRoomState( } // upsert state event - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) - _, err := stmt.ExecContext( - ctx, - event.RoomID(), - event.EventID(), - event.Type(), - event.Sender(), - containsURL, - *event.StateKey(), - headeredJSON, - membership, - addedAt, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) + _, err = stmt.ExecContext( + ctx, + event.RoomID(), + event.EventID(), + event.Type(), + event.Sender(), + containsURL, + *event.StateKey(), + headeredJSON, + membership, + addedAt, + ) + return err } func minOfInts(a, b int) int { diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index edabbef61..3092bcd7d 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -20,7 +20,6 @@ import ( "encoding/json" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -52,20 +51,18 @@ const insertFilterSQL = "" + type filterStatements struct { db *sql.DB - writer sqlutil.Writer selectFilterStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt insertFilterStmt *sql.Stmt } -func NewSqliteFilterTable(db *sql.DB, writer sqlutil.Writer) (tables.Filter, error) { +func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { _, err := db.Exec(filterSchema) if err != nil { return nil, err } s := &filterStatements{ - db: db, - writer: writer, + db: db, } if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return nil, err @@ -114,33 +111,30 @@ func (s *filterStatements) InsertFilter( return "", err } - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - // Check if filter already exists in the database using its localpart and content - // - // This can result in a race condition when two clients try to insert the - // same filter and localpart at the same time, however this is not a - // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) - if err != nil && err != sql.ErrNoRows { - return err - } - // If it does, return the existing ID - if existingFilterID != "" { - return nil - } + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil && err != sql.ErrNoRows { + return "", err + } + // If it does, return the existing ID + if existingFilterID != "" { + return existingFilterID, nil + } - // Otherwise insert the filter and return the new ID - res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) - if err != nil { - return err - } - rowid, err := res.LastInsertId() - if err != nil { - return err - } - filterID = fmt.Sprintf("%d", rowid) - return nil - }) + // Otherwise insert the filter and return the new ID + res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + if err != nil { + return "", err + } + rowid, err := res.LastInsertId() + if err != nil { + return "", err + } + filterID = fmt.Sprintf("%d", rowid) return } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 72a19528b..45862efbb 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -59,7 +59,6 @@ const selectMaxInviteIDSQL = "" + type inviteEventsStatements struct { db *sql.DB - writer sqlutil.Writer streamIDStatements *streamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt @@ -67,10 +66,9 @@ type inviteEventsStatements struct { selectMaxInviteIDStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB, writer sqlutil.Writer, streamID *streamIDStatements) (tables.Invites, error) { +func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ db: db, - writer: writer, streamIDStatements: streamID, } _, err := db.Exec(inviteEventsSchema) @@ -100,23 +98,21 @@ func (s *inviteEventsStatements) InsertInviteEvent( return } - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - var headeredJSON []byte - headeredJSON, err = json.Marshal(inviteEvent) - if err != nil { - return err - } + var headeredJSON []byte + headeredJSON, err = json.Marshal(inviteEvent) + if err != nil { + return + } - _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( - ctx, - streamPos, - inviteEvent.RoomID(), - inviteEvent.EventID(), - *inviteEvent.StateKey(), - headeredJSON, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + _, err = stmt.ExecContext( + ctx, + streamPos, + inviteEvent.RoomID(), + inviteEvent.EventID(), + *inviteEvent.StateKey(), + headeredJSON, + ) return } @@ -127,10 +123,7 @@ func (s *inviteEventsStatements) DeleteInviteEvent( if err != nil { return streamPos, err } - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) - return err - }) + _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) return streamPos, err } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 2e6f37848..f10d01066 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -105,7 +105,6 @@ const selectStateInRangeSQL = "" + type outputRoomEventsStatements struct { db *sql.DB - writer sqlutil.Writer streamIDStatements *streamIDStatements insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -117,10 +116,9 @@ type outputRoomEventsStatements struct { updateEventJSONStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB, writer sqlutil.Writer, streamID *streamIDStatements) (tables.Events, error) { +func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { s := &outputRoomEventsStatements{ db: db, - writer: writer, streamIDStatements: streamID, } _, err := db.Exec(outputRoomEventsSchema) @@ -159,10 +157,8 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event if err != nil { return err } - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) - return err - }) + _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + return err } // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. @@ -308,26 +304,23 @@ func (s *outputRoomEventsStatements) InsertEvent( if err != nil { return 0, err } - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - _, ierr := insertStmt.ExecContext( - ctx, - streamPos, - event.RoomID(), - event.EventID(), - headeredJSON, - event.Type(), - event.Sender(), - containsURL, - string(addStateJSON), - string(removeStateJSON), - sessionID, - txnID, - excludeFromSync, - excludeFromSync, - ) - return ierr - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + _, err = insertStmt.ExecContext( + ctx, + streamPos, + event.RoomID(), + event.EventID(), + headeredJSON, + event.Type(), + event.Sender(), + containsURL, + string(addStateJSON), + string(removeStateJSON), + sessionID, + txnID, + excludeFromSync, + excludeFromSync, + ) return streamPos, err } diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 8aeb0041f..d8c97b7e3 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -67,7 +67,6 @@ const selectMaxPositionInTopologySQL = "" + type outputRoomEventsTopologyStatements struct { db *sql.DB - writer sqlutil.Writer insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt @@ -75,10 +74,9 @@ type outputRoomEventsTopologyStatements struct { selectMaxPositionInTopologyStmt *sql.Stmt } -func NewSqliteTopologyTable(db *sql.DB, writer sqlutil.Writer) (tables.Topology, error) { +func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { s := &outputRoomEventsTopologyStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(outputRoomEventsTopologySchema) if err != nil { @@ -107,13 +105,11 @@ func NewSqliteTopologyTable(db *sql.DB, writer sqlutil.Writer) (tables.Topology, func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) - _, err := stmt.ExecContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), pos, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) + _, err = stmt.ExecContext( + ctx, event.EventID(), event.Depth(), event.RoomID(), pos, + ) + return } func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index cbe067f49..fbc759b12 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -73,16 +73,14 @@ const deleteSendToDeviceMessagesSQL = ` type sendToDeviceStatements struct { db *sql.DB - writer sqlutil.Writer insertSendToDeviceMessageStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt } -func NewSqliteSendToDeviceTable(db *sql.DB, writer sqlutil.Writer) (tables.SendToDevice, error) { +func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { s := &sendToDeviceStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(sendToDeviceSchema) if err != nil { @@ -103,10 +101,8 @@ func NewSqliteSendToDeviceTable(db *sql.DB, writer sqlutil.Writer) (tables.SendT func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) - return err - }) + _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return } func (s *sendToDeviceStatements) CountSendToDeviceMessages( @@ -163,10 +159,8 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( for k, v := range nids { params[k+1] = v } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.ExecContext(ctx, query, params...) - return err - }) + _, err = txn.ExecContext(ctx, query, params...) + return } func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( @@ -177,8 +171,6 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( for k, v := range nids { params[k] = v } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.ExecContext(ctx, query, params...) - return err - }) + _, err = txn.ExecContext(ctx, query, params...) + return } diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index be2c0387f..e6bdc4fcb 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -28,14 +28,12 @@ const selectStreamIDStmt = "" + type streamIDStatements struct { db *sql.DB - writer sqlutil.Writer increaseStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt } -func (s *streamIDStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { +func (s *streamIDStatements) prepare(db *sql.DB) (err error) { s.db = db - s.writer = writer _, err = db.Exec(streamIDTableSchema) if err != nil { return @@ -52,14 +50,9 @@ func (s *streamIDStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err err func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil { - return ierr - } - if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { - return serr - } - return nil - }) + if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { + return + } + err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) return } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 7d8dbf8ac..81197bb76 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -56,43 +56,44 @@ func (d *SyncServerDatasource) prepare() (err error) { if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { return err } - if err = d.streamID.prepare(d.db, d.writer); err != nil { + if err = d.streamID.prepare(d.db); err != nil { return err } - accountData, err := NewSqliteAccountDataTable(d.db, d.writer, &d.streamID) + accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) if err != nil { return err } - events, err := NewSqliteEventsTable(d.db, d.writer, &d.streamID) + events, err := NewSqliteEventsTable(d.db, &d.streamID) if err != nil { return err } - roomState, err := NewSqliteCurrentRoomStateTable(d.db, d.writer, &d.streamID) + roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID) if err != nil { return err } - invites, err := NewSqliteInvitesTable(d.db, d.writer, &d.streamID) + invites, err := NewSqliteInvitesTable(d.db, &d.streamID) if err != nil { return err } - topology, err := NewSqliteTopologyTable(d.db, d.writer) + topology, err := NewSqliteTopologyTable(d.db) if err != nil { return err } - bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db, d.writer) + bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db) if err != nil { return err } - sendToDevice, err := NewSqliteSendToDeviceTable(d.db, d.writer) + sendToDevice, err := NewSqliteSendToDeviceTable(d.db) if err != nil { return err } - filter, err := NewSqliteFilterTable(d.db, d.writer) + filter, err := NewSqliteFilterTable(d.db) if err != nil { return err } d.Database = shared.Database{ DB: d.db, + Writer: sqlutil.NewExclusiveWriter(), Invites: invites, AccountData: accountData, OutputEvents: events, @@ -101,7 +102,6 @@ func (d *SyncServerDatasource) prepare() (err error) { Topology: topology, Filter: filter, SendToDevice: sendToDevice, - SendToDeviceWriter: sqlutil.NewExclusiveWriter(), EDUCache: cache.New(), } return nil