diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go index 479f2213c..da31f2359 100644 --- a/appservice/storage/sqlite3/appservice_events_table.go +++ b/appservice/storage/sqlite3/appservice_events_table.go @@ -21,6 +21,7 @@ import ( "encoding/json" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -65,6 +66,8 @@ const ( ) type eventsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter selectEventsByApplicationServiceIDStmt *sql.Stmt countEventsByApplicationServiceIDStmt *sql.Stmt insertEventStmt *sql.Stmt @@ -73,6 +76,8 @@ type eventsStatements struct { } func (s *eventsStatements) prepare(db *sql.DB) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(appserviceEventsSchema) if err != nil { return @@ -217,13 +222,15 @@ func (s *eventsStatements) insertEvent( return err } - _, err = s.insertEventStmt.ExecContext( - ctx, - appServiceID, - eventJSON, - -1, // No transaction ID yet - ) - return + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.insertEventStmt.ExecContext( + ctx, + appServiceID, + eventJSON, + -1, // No transaction ID yet + ) + return err + }) } // updateTxnIDForEvents sets the transactionID for a collection of events. Done @@ -234,8 +241,10 @@ func (s *eventsStatements) updateTxnIDForEvents( appserviceID string, maxID, txnID int, ) (err error) { - _, err = s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID) - return + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID) + return err + }) } // deleteEventsBeforeAndIncludingID removes events matching given IDs from the database. @@ -244,6 +253,8 @@ func (s *eventsStatements) deleteEventsBeforeAndIncludingID( appserviceID string, eventTableID int, ) (err error) { - _, err = s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID) - return + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID) + return err + }) } diff --git a/appservice/storage/sqlite3/txn_id_counter_table.go b/appservice/storage/sqlite3/txn_id_counter_table.go index b1ee60766..501ab5aa7 100644 --- a/appservice/storage/sqlite3/txn_id_counter_table.go +++ b/appservice/storage/sqlite3/txn_id_counter_table.go @@ -18,6 +18,8 @@ package sqlite3 import ( "context" "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const txnIDSchema = ` @@ -35,10 +37,14 @@ const selectTxnIDSQL = ` ` type txnStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter selectTxnIDStmt *sql.Stmt } func (s *txnStatements) prepare(db *sql.DB) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(txnIDSchema) if err != nil { return @@ -55,6 +61,9 @@ func (s *txnStatements) prepare(db *sql.DB) (err error) { func (s *txnStatements) selectTxnID( ctx context.Context, ) (txnID int, err error) { - err = s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID) + err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID) + return err + }) return } diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index 8fac4f352..b95fb4350 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -68,6 +68,7 @@ const selectBulkStateContentWildSQL = "" + type currentRoomStateStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt @@ -76,7 +77,8 @@ type currentRoomStateStatements struct { func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err := db.Exec(currentRoomStateSchema) if err != nil { @@ -125,9 +127,11 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - _, err := stmt.ExecContext(ctx, eventID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + _, err := stmt.ExecContext(ctx, eventID) + return err + }) } func (s *currentRoomStateStatements) UpsertRoomState( @@ -140,18 +144,20 @@ func (s *currentRoomStateStatements) UpsertRoomState( } // upsert state event - stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) - _, err = stmt.ExecContext( - ctx, - event.RoomID(), - event.EventID(), - event.Type(), - event.Sender(), - *event.StateKey(), - headeredJSON, - contentVal, - ) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) + _, err = stmt.ExecContext( + ctx, + event.RoomID(), + event.EventID(), + event.Type(), + event.Sender(), + *event.StateKey(), + headeredJSON, + contentVal, + ) + return err + }) } func (s *currentRoomStateStatements) SelectEventsWithEventIDs( diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index 8e2e6236a..f53f164d4 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -60,11 +61,16 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user ` type mediaStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt } func (s *mediaStatements) prepare(db *sql.DB) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() + _, err = db.Exec(mediaSchema) if err != nil { return @@ -80,18 +86,21 @@ func (s *mediaStatements) insertMedia( ctx context.Context, mediaMetadata *types.MediaMetadata, ) error { mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - _, err := s.insertMediaStmt.ExecContext( - ctx, - mediaMetadata.MediaID, - mediaMetadata.Origin, - mediaMetadata.ContentType, - mediaMetadata.FileSizeBytes, - mediaMetadata.CreationTimestamp, - mediaMetadata.UploadName, - mediaMetadata.Base64Hash, - mediaMetadata.UserID, - ) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertMediaStmt) + _, err := stmt.ExecContext( + ctx, + mediaMetadata.MediaID, + mediaMetadata.Origin, + mediaMetadata.ContentType, + mediaMetadata.FileSizeBytes, + mediaMetadata.CreationTimestamp, + mediaMetadata.UploadName, + mediaMetadata.Base64Hash, + mediaMetadata.UserID, + ) + return err + }) } func (s *mediaStatements) selectMedia( diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index 04538cf69..a63082990 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -18,6 +18,7 @@ package internal import ( "context" + "fmt" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -65,13 +66,13 @@ func (r *RoomserverInternalAPI) processRoomEvent( // Store the event. roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) if err != nil { - return + return "", fmt.Errorf("r.DB.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. if redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, &event) if rerr != nil { - return "", rerr + return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr) } event = *r } @@ -93,7 +94,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( // Lets calculate one. err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event) if err != nil { - return + return "", fmt.Errorf("r.calculateAndSetState: %w", err) } } @@ -105,7 +106,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( input.SendAsServer, // send as server input.TransactionID, // transaction ID ); err != nil { - return + return "", fmt.Errorf("r.updateLatestEvents: %w", err) } // processing this event resulted in an event (which may not be the one we're processing) @@ -123,7 +124,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( }, }) if err != nil { - return + return "", fmt.Errorf("r.WriteOutputEvents: %w", err) } } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e2e5daf95..e858a9b00 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -362,7 +363,7 @@ func (d *Database) StoreEvent( ctx, txn, txnAndSessionID.TransactionID, txnAndSessionID.SessionID, event.Sender(), event.EventID(), ); err != nil { - return err + return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err) } } @@ -377,15 +378,15 @@ func (d *Database) StoreEvent( // room. var roomVersion gomatrixserverlib.RoomVersion if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil { - return err + return fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err) } if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil { - return err + return fmt.Errorf("d.assignRoomNID: %w", err) } if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { - return err + return fmt.Errorf("d.assignEventTypeNID: %w", err) } eventStateKey := event.StateKey() @@ -393,7 +394,7 @@ func (d *Database) StoreEvent( // Otherwise set the numeric ID for the state_key to 0. if eventStateKey != nil { if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { - return err + return fmt.Errorf("d.assignStateKeyNID: %w", err) } } @@ -411,17 +412,20 @@ func (d *Database) StoreEvent( if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID()) + if err != nil { + return fmt.Errorf("d.EventsTable.SelectEvent: %w", err) + } } if err != nil { - return err + return fmt.Errorf("d.EventsTable.InsertEvent: %w", err) } } if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { - return err + return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) - return err + return nil }) if err != nil { return 0, types.StateAtEvent{}, nil, "", err diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 378441c3a..b3cfee07e 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -287,7 +287,8 @@ func (s *eventStatements) UpdateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt) + _, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) return err }) } diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index 965752419..85f1e0a49 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -71,7 +71,8 @@ func (s *publishedStatements) UpsertRoomPublished( ctx context.Context, roomID string, published bool, ) (err error) { return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published) + stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) + _, err := stmt.ExecContext(ctx, roomID, published) return err }) } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 096b73f98..4a5357776 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -87,7 +87,8 @@ func (s *roomAliasesStatements) InsertRoomAlias( ctx context.Context, alias string, roomID string, creatorUserID string, ) (err error) { return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) + stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt) + _, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID) return err }) } @@ -139,7 +140,8 @@ func (s *roomAliasesStatements) DeleteRoomAlias( ctx context.Context, alias string, ) (err error) { return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias) + stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt) + _, err := stmt.ExecContext(ctx, alias) return err }) } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 9eeadea94..bb30a63b3 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -98,17 +99,23 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, -) (types.RoomNID, error) { - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { +) (roomNID types.RoomNID, err error) { + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) - _, err := insertStmt.ExecContext(ctx, roomID, roomVersion) - return err + _, err = insertStmt.ExecContext(ctx, roomID, roomVersion) + if err != nil { + return fmt.Errorf("insertStmt.ExecContext: %w", err) + } + roomNID, err = s.SelectRoomNID(ctx, txn, roomID) + if err != nil { + return fmt.Errorf("s.SelectRoomNID: %w", err) + } + return nil }) - if err == nil { - return s.SelectRoomNID(ctx, txn, roomID) - } else { + if err != nil { return types.RoomNID(0), err } + return } func (s *roomStatements) SelectRoomNID( diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index 4f03dccbb..423292a54 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -63,12 +63,14 @@ const upsertServerKeysSQL = "" + type serverKeyStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter bulkSelectServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt } func (s *serverKeyStatements) prepare(db *sql.DB) (err error) { s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(serverKeysSchema) if err != nil { return @@ -136,16 +138,19 @@ func (s *serverKeyStatements) upsertServerKeys( request gomatrixserverlib.PublicKeyLookupRequest, key gomatrixserverlib.PublicKeyLookupResult, ) error { - _, err := s.upsertServerKeysStmt.ExecContext( - ctx, - string(request.ServerName), - string(request.KeyID), - nameAndKeyID(request), - key.ValidUntilTS, - key.ExpiredTS, - key.Key.Encode(), - ) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt) + _, err := stmt.ExecContext( + ctx, + string(request.ServerName), + string(request.KeyID), + nameAndKeyID(request), + key.ValidUntilTS, + key.ExpiredTS, + key.Key.Encode(), + ) + return err + }) } func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string { diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 320792914..e1312671b 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -281,16 +281,16 @@ func (d *Database) WriteEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ) if err != nil { - return err + return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err) } pduPosition = pos if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { - return err + return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err) } if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { - return err + return fmt.Errorf("d.handleBackwardExtremities: %w", err) } if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { @@ -313,7 +313,7 @@ func (d *Database) updateRoomState( // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removedEventIDs { if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { - return err + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err) } } @@ -326,13 +326,13 @@ func (d *Database) updateRoomState( if event.Type() == "m.room.member" { value, err := event.Membership() if err != nil { - return err + return fmt.Errorf("event.Membership: %w", err) } membership = &value } if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { - return err + return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err) } } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index ae5caa4e5..609cef141 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -49,6 +50,8 @@ const selectMaxAccountDataIDSQL = "" + "SELECT MAX(id) FROM syncapi_account_data_type" type accountDataStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt @@ -57,6 +60,8 @@ type accountDataStatements struct { func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(accountDataSchema) @@ -79,12 +84,15 @@ func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { - pos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return - } - _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) - return + return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + var err error + pos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return err + } + _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) + return err + }) } func (s *accountDataStatements) SelectAccountDataInRange( diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index e16e54a6f..1aeb041f4 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) @@ -47,13 +48,18 @@ const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" type backwardExtremitiesStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { - s := &backwardExtremitiesStatements{} + s := &backwardExtremitiesStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(backwardExtremitiesSchema) if err != nil { return nil, err @@ -73,8 +79,10 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { - _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) + return err + }) } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( @@ -102,6 +110,8 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { - _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) + return err + }) } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 85f212ad8..08b42f5b4 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -84,6 +84,8 @@ const selectEventsWithEventIDsSQL = "" + " FROM syncapi_current_room_state WHERE event_id IN ($1)" type currentRoomStateStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -95,6 +97,8 @@ type currentRoomStateStatements struct { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(currentRoomStateSchema) @@ -196,9 +200,11 @@ func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - _, err := stmt.ExecContext(ctx, eventID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + _, err := stmt.ExecContext(ctx, eventID) + return err + }) } func (s *currentRoomStateStatements) UpsertRoomState( @@ -219,20 +225,22 @@ func (s *currentRoomStateStatements) UpsertRoomState( } // upsert state event - stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) - _, err = stmt.ExecContext( - ctx, - event.RoomID(), - event.EventID(), - event.Type(), - event.Sender(), - containsURL, - *event.StateKey(), - headeredJSON, - membership, - addedAt, - ) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) + _, err := stmt.ExecContext( + ctx, + event.RoomID(), + event.EventID(), + event.Type(), + event.Sender(), + containsURL, + *event.StateKey(), + headeredJSON, + membership, + addedAt, + ) + return err + }) } func (s *currentRoomStateStatements) SelectEventsWithEventIDs( diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 8b26759dc..3e8a46551 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -50,6 +51,8 @@ const insertFilterSQL = "" + "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)" type filterStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter selectFilterStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt insertFilterStmt *sql.Stmt @@ -60,7 +63,10 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { if err != nil { return nil, err } - s := &filterStatements{} + s := &filterStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return nil, err } @@ -108,30 +114,33 @@ func (s *filterStatements) InsertFilter( return "", err } - // Check if filter already exists in the database using its localpart and content - // - // This can result in a race condition when two clients try to insert the - // same filter and localpart at the same time, however this is not a - // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) - if err != nil && err != sql.ErrNoRows { - return "", err - } - // If it does, return the existing ID - if existingFilterID != "" { - return existingFilterID, err - } + err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil && err != sql.ErrNoRows { + return err + } + // If it does, return the existing ID + if existingFilterID != "" { + return nil + } - // Otherwise insert the filter and return the new ID - res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) - if err != nil { - return "", err - } - rowid, err := res.LastInsertId() - if err != nil { - return "", err - } - filterID = fmt.Sprintf("%d", rowid) + // Otherwise insert the filter and return the new ID + res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + if err != nil { + return err + } + rowid, err := res.LastInsertId() + if err != nil { + return err + } + filterID = fmt.Sprintf("%d", rowid) + return nil + }) return } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index aa0513888..19e7a7c68 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -58,6 +58,8 @@ const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" type inviteEventsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt @@ -67,6 +69,8 @@ type inviteEventsStatements struct { func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(inviteEventsSchema) @@ -91,36 +95,45 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv func (s *inviteEventsStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { - streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return - } + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + var err error + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return err + } - var headeredJSON []byte - headeredJSON, err = json.Marshal(inviteEvent) - if err != nil { - return - } + var headeredJSON []byte + headeredJSON, err = json.Marshal(inviteEvent) + if err != nil { + return err + } - _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( - ctx, - streamPos, - inviteEvent.RoomID(), - inviteEvent.EventID(), - *inviteEvent.StateKey(), - headeredJSON, - ) + _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( + ctx, + streamPos, + inviteEvent.RoomID(), + inviteEvent.EventID(), + *inviteEvent.StateKey(), + headeredJSON, + ) + return err + }) return } func (s *inviteEventsStatements) DeleteInviteEvent( ctx context.Context, inviteEventID string, ) (types.StreamPosition, error) { - streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil) - if err != nil { - return streamPos, err - } - _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) + var streamPos types.StreamPosition + err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + var err error + streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil) + if err != nil { + return err + } + _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) + return err + }) return streamPos, err } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index da2ea3f69..12b4dbabe 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -104,6 +104,8 @@ const selectStateInRangeSQL = "" + " LIMIT $8" // limit type outputRoomEventsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -117,6 +119,8 @@ type outputRoomEventsStatements struct { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { s := &outputRoomEventsStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(outputRoomEventsSchema) @@ -155,8 +159,10 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event if err != nil { return err } - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + return err + }) } // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. @@ -267,7 +273,7 @@ func (s *outputRoomEventsStatements) InsertEvent( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool, -) (streamPos types.StreamPosition, err error) { +) (types.StreamPosition, error) { var txnID *string var sessionID *int64 if transactionID != nil { @@ -284,43 +290,47 @@ func (s *outputRoomEventsStatements) InsertEvent( } var headeredJSON []byte - headeredJSON, err = json.Marshal(event) + headeredJSON, err := json.Marshal(event) if err != nil { - return - } - - streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return + return 0, err } addStateJSON, err := json.Marshal(addState) if err != nil { - return + return 0, err } removeStateJSON, err := json.Marshal(removeState) if err != nil { - return + return 0, err } - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - _, err = insertStmt.ExecContext( - ctx, - streamPos, - event.RoomID(), - event.EventID(), - headeredJSON, - event.Type(), - event.Sender(), - containsURL, - string(addStateJSON), - string(removeStateJSON), - sessionID, - txnID, - excludeFromSync, - excludeFromSync, - ) - return + var streamPos types.StreamPosition + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return err + } + + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + _, ierr := insertStmt.ExecContext( + ctx, + streamPos, + event.RoomID(), + event.EventID(), + headeredJSON, + event.Type(), + event.Sender(), + containsURL, + string(addStateJSON), + string(removeStateJSON), + sessionID, + txnID, + excludeFromSync, + excludeFromSync, + ) + return ierr + }) + return streamPos, err } func (s *outputRoomEventsStatements) SelectRecentEvents( diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 811dfa4f3..2e71e8f33 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -66,6 +66,8 @@ const selectMaxPositionInTopologySQL = "" + " WHERE room_id = $1 ORDER BY stream_position DESC" type outputRoomEventsTopologyStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt @@ -74,7 +76,10 @@ type outputRoomEventsTopologyStatements struct { } func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { - s := &outputRoomEventsTopologyStatements{} + s := &outputRoomEventsTopologyStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(outputRoomEventsTopologySchema) if err != nil { return nil, err @@ -102,11 +107,13 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { - stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) - _, err = stmt.ExecContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), pos, - ) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) + _, err := stmt.ExecContext( + ctx, event.EventID(), event.Depth(), event.RoomID(), pos, + ) + return err + }) } func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 42bd3c19a..88b319fb3 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -72,13 +72,18 @@ const deleteSendToDeviceMessagesSQL = ` ` type sendToDeviceStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertSendToDeviceMessageStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt } func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { - s := &sendToDeviceStatements{} + s := &sendToDeviceStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(sendToDeviceSchema) if err != nil { return nil, err @@ -98,8 +103,10 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return err + }) } func (s *sendToDeviceStatements) CountSendToDeviceMessages( @@ -156,8 +163,10 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( for k, v := range nids { params[k+1] = v } - _, err = txn.ExecContext(ctx, query, params...) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.ExecContext(ctx, query, params...) + return err + }) } func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( @@ -168,6 +177,8 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( for k, v := range nids { params[k] = v } - _, err = txn.ExecContext(ctx, query, params...) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.ExecContext(ctx, query, params...) + return err + }) } diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index 57abd9c44..cf3eed5ba 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -27,11 +27,15 @@ const selectStreamIDStmt = "" + "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1" type streamIDStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter increaseStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt } func (s *streamIDStatements) prepare(db *sql.DB) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(streamIDTableSchema) if err != nil { return @@ -48,11 +52,14 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { - return - } - if err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { - return - } + err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil { + return ierr + } + if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { + return serr + } + return nil + }) return } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index feacbc18c..474d3222b 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "encoding/json" "fmt" + "os" "testing" "time" @@ -52,7 +53,13 @@ func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.Head } func MustCreateDatabase(t *testing.T) storage.Database { - db, err := sqlite3.NewDatabase("file::memory:") + dbname := fmt.Sprintf("test_%s.db", t.Name()) + if _, err := os.Stat(dbname); err == nil { + if err = os.Remove(dbname); err != nil { + t.Fatalf("tried to delete stale test database but failed: %s", err) + } + } + db, err := sqlite3.NewDatabase(fmt.Sprintf("file:%s", dbname)) if err != nil { t.Fatalf("NewSyncServerDatasource returned %s", err) } diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index d048dbd19..cb54412ab 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -18,6 +18,8 @@ import ( "context" "database/sql" "encoding/json" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const accountDataSchema = ` @@ -48,12 +50,16 @@ const selectAccountDataByTypeSQL = "" + "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" type accountDataStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertAccountDataStmt *sql.Stmt selectAccountDataStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt } func (s *accountDataStatements) prepare(db *sql.DB) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(accountDataSchema) if err != nil { return @@ -73,8 +79,10 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func (s *accountDataStatements) insertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { - _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + return err + }) } func (s *accountDataStatements) selectAccountData( diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 768f536dd..27c3d845a 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -57,6 +58,8 @@ const selectNewNumericLocalpartSQL = "" + // TODO: Update password type accountsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertAccountStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt @@ -65,6 +68,8 @@ type accountsStatements struct { } func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(accountsSchema) if err != nil { return @@ -94,12 +99,15 @@ func (s *accountsStatements) insertAccount( createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt - var err error - if appserviceID == "" { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) - } else { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) - } + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + var err error + if appserviceID == "" { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + } else { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + } + return err + }) if err != nil { return nil, err } diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index 9b5192a02..68cea516d 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/sqlutil" ) const profilesSchema = ` @@ -46,6 +47,8 @@ const setDisplayNameSQL = "" + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" type profilesStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt @@ -53,6 +56,8 @@ type profilesStatements struct { } func (s *profilesStatements) prepare(db *sql.DB) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(profilesSchema) if err != nil { return @@ -75,8 +80,10 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { func (s *profilesStatements) insertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) (err error) { - _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + return err + }) } func (s *profilesStatements) selectProfileByLocalpart( diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 0200dee7f..0104e8346 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -53,6 +53,8 @@ const deleteThreePIDSQL = "" + "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" type threepidStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter selectLocalpartForThreePIDStmt *sql.Stmt selectThreePIDsForLocalpartStmt *sql.Stmt insertThreePIDStmt *sql.Stmt @@ -60,6 +62,8 @@ type threepidStatements struct { } func (s *threepidStatements) prepare(db *sql.DB) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(threepidSchema) if err != nil { return @@ -118,13 +122,18 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) insertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { - stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) + _, err := stmt.ExecContext(ctx, threepid, medium, localpart) + return err + }) } func (s *threepidStatements) deleteThreePID( ctx context.Context, threepid string, medium string) (err error) { - _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) - return + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err := stmt.ExecContext(ctx, threepid, medium) + return err + }) } diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index 07ea5dca3..ec52c64bc 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -74,6 +74,7 @@ const deleteDevicesSQL = "" + type devicesStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -87,6 +88,7 @@ type devicesStatements struct { func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(devicesSchema) if err != nil { return @@ -128,13 +130,19 @@ func (s *devicesStatements) insertDevice( ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) - insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { - return nil, err - } - sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + return err + } + sessionID++ + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { + return err + } + return nil + }) + if err != nil { return nil, err } return &api.Device{ @@ -148,9 +156,11 @@ func (s *devicesStatements) insertDevice( func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) + return err + }) } func (s *devicesStatements) deleteDevices( @@ -161,31 +171,37 @@ func (s *devicesStatements) deleteDevices( if err != nil { return err } - stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) - params[0] = localpart - for i, v := range devices { - params[i+1] = v - } - params = append(params, params...) - _, err = stmt.ExecContext(ctx, params...) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, prep) + params := make([]interface{}, len(devices)+1) + params[0] = localpart + for i, v := range devices { + params[i+1] = v + } + params = append(params, params...) + _, err = stmt.ExecContext(ctx, params...) + return err + }) } func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart) + return err + }) } func (s *devicesStatements) updateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { - stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err + }) } func (s *devicesStatements) selectDeviceByToken(