diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 320792914..e1312671b 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -281,16 +281,16 @@ func (d *Database) WriteEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ) if err != nil { - return err + return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err) } pduPosition = pos if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { - return err + return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err) } if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { - return err + return fmt.Errorf("d.handleBackwardExtremities: %w", err) } if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { @@ -313,7 +313,7 @@ func (d *Database) updateRoomState( // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removedEventIDs { if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { - return err + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err) } } @@ -326,13 +326,13 @@ func (d *Database) updateRoomState( if event.Type() == "m.room.member" { value, err := event.Membership() if err != nil { - return err + return fmt.Errorf("event.Membership: %w", err) } membership = &value } if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { - return err + return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err) } } diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 25a3047f1..1aeb041f4 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -79,7 +79,7 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) return err }) @@ -110,7 +110,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err }) diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 7e579df26..08b42f5b4 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -200,7 +200,7 @@ func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) _, err := stmt.ExecContext(ctx, eventID) return err @@ -225,7 +225,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( } // upsert state event - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) _, err := stmt.ExecContext( ctx, diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 0ad038cf8..19e7a7c68 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -95,7 +95,7 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv func (s *inviteEventsStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { var err error streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) if err != nil { diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 16d155834..12b4dbabe 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -295,11 +295,6 @@ func (s *outputRoomEventsStatements) InsertEvent( return 0, err } - streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return 0, err - } - addStateJSON, err := json.Marshal(addState) if err != nil { return 0, err @@ -309,7 +304,13 @@ func (s *outputRoomEventsStatements) InsertEvent( return 0, err } - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + var streamPos types.StreamPosition + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return err + } + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) _, ierr := insertStmt.ExecContext( ctx, diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index eb3dcaca6..2e71e8f33 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -107,7 +107,7 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) _, err := stmt.ExecContext( ctx, event.EventID(), event.Depth(), event.RoomID(), pos, diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 317511db6..88b319fb3 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -103,7 +103,7 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) return err }) @@ -163,7 +163,7 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( for k, v := range nids { params[k+1] = v } - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err := txn.ExecContext(ctx, query, params...) return err }) @@ -177,7 +177,7 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( for k, v := range nids { params[k] = v } - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { _, err := txn.ExecContext(ctx, query, params...) return err }) diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index feacbc18c..474d3222b 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "encoding/json" "fmt" + "os" "testing" "time" @@ -52,7 +53,13 @@ func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.Head } func MustCreateDatabase(t *testing.T) storage.Database { - db, err := sqlite3.NewDatabase("file::memory:") + dbname := fmt.Sprintf("test_%s.db", t.Name()) + if _, err := os.Stat(dbname); err == nil { + if err = os.Remove(dbname); err != nil { + t.Fatalf("tried to delete stale test database but failed: %s", err) + } + } + db, err := sqlite3.NewDatabase(fmt.Sprintf("file:%s", dbname)) if err != nil { t.Fatalf("NewSyncServerDatasource returned %s", err) }