diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index b84c734ef..48ebcf251 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -115,7 +115,7 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { // InsertEventInTopology inserts the given event in the room's topology, based // on the event's depth. func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( - ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, + ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { _, err = s.insertEventInTopologyStmt.ExecContext( ctx, event.EventID(), event.Depth(), event.RoomID(), pos, @@ -127,7 +127,7 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( // given range in a given room's topological order. // Returns an empty slice if no events match the given range. func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( - ctx context.Context, roomID string, fromPos, toPos, toMicroPos types.StreamPosition, + ctx context.Context, txn *sql.Tx, roomID string, fromPos, toPos, toMicroPos types.StreamPosition, limit int, chronologicalOrder bool, ) (eventIDs []string, err error) { // Decide on the selection's order according to whether chronological order @@ -164,14 +164,14 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( // SelectPositionInTopology returns the position of a given event in the // topology of the room it belongs to. func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( - ctx context.Context, eventID string, + ctx context.Context, txn *sql.Tx, eventID string, ) (pos, spos types.StreamPosition, err error) { err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) return } func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return @@ -180,7 +180,7 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( // SelectEventIDsFromPosition returns the IDs of all events that have a given // position in the topology of a given room. func (s *outputRoomEventsTopologyStatements) SelectEventIDsFromPosition( - ctx context.Context, roomID string, pos types.StreamPosition, + ctx context.Context, txn *sql.Tx, roomID string, pos types.StreamPosition, ) (eventIDs []string, err error) { // Query the event IDs. rows, err := s.selectEventIDsFromPositionStmt.QueryContext(ctx, roomID, pos) diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index f74381fe3..42d6b7e1e 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -260,7 +260,7 @@ func (d *Database) WriteEvent( } pduPosition = pos - if err = d.Topology.InsertEventInTopology(ctx, ev, pos); err != nil { + if err = d.Topology.InsertEventInTopology(ctx, nil, ev, pos); err != nil { return err } @@ -337,7 +337,7 @@ func (d *Database) GetEventsInTopologicalRange( // Select the event IDs from the defined range. var eIDs []string eIDs, err = d.Topology.SelectEventIDsInRange( - ctx, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, + ctx, nil, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, ) if err != nil { return @@ -361,13 +361,13 @@ func (d *Database) BackwardExtremitiesForRoom( func (d *Database) MaxTopologicalPosition( ctx context.Context, roomID string, ) (depth types.StreamPosition, stream types.StreamPosition, err error) { - return d.Topology.SelectMaxPositionInTopology(ctx, roomID) + return d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) } func (d *Database) EventsAtTopologicalPosition( ctx context.Context, roomID string, pos types.StreamPosition, ) ([]types.StreamEvent, error) { - eIDs, err := d.Topology.SelectEventIDsFromPosition(ctx, roomID, pos) + eIDs, err := d.Topology.SelectEventIDsFromPosition(ctx, nil, roomID, pos) if err != nil { return nil, err } @@ -378,7 +378,7 @@ func (d *Database) EventsAtTopologicalPosition( func (d *Database) EventPositionInTopology( ctx context.Context, eventID string, ) (depth types.StreamPosition, stream types.StreamPosition, err error) { - return d.Topology.SelectPositionInTopology(ctx, eventID) + return d.Topology.SelectPositionInTopology(ctx, nil, eventID) } func (d *Database) syncPositionTx( @@ -618,7 +618,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync( var prevBatchStr string if len(recentStreamEvents) > 0 { var backwardTopologyPos, backwardStreamPos types.StreamPosition - backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, recentStreamEvents[0].EventID()) + backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, nil, recentStreamEvents[0].EventID()) if err != nil { return } @@ -704,7 +704,7 @@ func (d *Database) getBackwardTopologyPos( events []types.StreamEvent, ) (pos, spos types.StreamPosition) { if len(events) > 0 { - pos, spos, _ = d.Topology.SelectPositionInTopology(ctx, events[0].EventID()) + pos, spos, _ = d.Topology.SelectPositionInTopology(ctx, nil, events[0].EventID()) } if pos-1 <= 0 { pos = types.StreamPosition(1) diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 0d313d7c6..4469f5b76 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -77,35 +78,36 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsFromPositionStmt *sql.Stmt } -func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(outputRoomEventsTopologySchema) +func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { + s := &outputRoomEventsTopologyStatements{} + _, err := db.Exec(outputRoomEventsTopologySchema) if err != nil { - return + return nil, err } if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return + return nil, err } if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return + return nil, err } if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return + return nil, err } if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return + return nil, err } if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return + return nil, err } if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil { - return + return nil, err } - return + return s, nil } // insertEventInTopology inserts the given event in the room's topology, based // on the event's depth. -func (s *outputRoomEventsTopologyStatements) insertEventInTopology( +func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { stmt := common.TxStmt(txn, s.insertEventInTopologyStmt) @@ -118,7 +120,7 @@ func (s *outputRoomEventsTopologyStatements) insertEventInTopology( // selectEventIDsInRange selects the IDs of events which positions are within a // given range in a given room's topological order. // Returns an empty slice if no events match the given range. -func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( +func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( ctx context.Context, txn *sql.Tx, roomID string, fromPos, toPos, toMicroPos types.StreamPosition, limit int, chronologicalOrder bool, @@ -155,7 +157,7 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( // selectPositionInTopology returns the position of a given event in the // topology of the room it belongs to. -func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( +func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( ctx context.Context, txn *sql.Tx, eventID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { stmt := common.TxStmt(txn, s.selectPositionInTopologyStmt) @@ -163,7 +165,7 @@ func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( return } -func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( +func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( ctx context.Context, txn *sql.Tx, roomID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { stmt := common.TxStmt(txn, s.selectMaxPositionInTopologyStmt) @@ -173,7 +175,7 @@ func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( // selectEventIDsFromPosition returns the IDs of all events that have a given // position in the topology of a given room. -func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition( +func (s *outputRoomEventsTopologyStatements) SelectEventIDsFromPosition( ctx context.Context, txn *sql.Tx, roomID string, pos types.StreamPosition, ) (eventIDs []string, err error) { // Query the event IDs. diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 182cbb2d7..01aaf31da 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -55,7 +55,6 @@ type SyncServerDatasource struct { db *sql.DB common.PartitionOffsetStatements streamID streamIDStatements - topology outputRoomEventsTopologyStatements } // NewSyncServerDatasource creates a new sync server database @@ -106,7 +105,8 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } - if err = d.topology.prepare(d.db); err != nil { + topology, err := NewSqliteTopologyTable(d.db) + if err != nil { return err } bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db) @@ -120,6 +120,7 @@ func (d *SyncServerDatasource) prepare() (err error) { OutputEvents: events, BackwardExtremities: bwExtrem, CurrentRoomState: roomState, + Topology: topology, EDUCache: cache.New(), } return nil @@ -179,7 +180,7 @@ func (d *SyncServerDatasource) WriteEvent( } pduPosition = pos - if err = d.topology.insertEventInTopology(ctx, txn, ev, pos); err != nil { + if err = d.Database.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { return err } @@ -271,7 +272,7 @@ func (d *SyncServerDatasource) GetEventsInTopologicalRange( // Select the event IDs from the defined range. var eIDs []string - eIDs, err = d.topology.selectEventIDsInRange( + eIDs, err = d.Database.Topology.SelectEventIDsInRange( ctx, nil, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering, ) if err != nil { @@ -288,7 +289,7 @@ func (d *SyncServerDatasource) GetEventsInTopologicalRange( func (d *SyncServerDatasource) MaxTopologicalPosition( ctx context.Context, roomID string, ) (types.StreamPosition, types.StreamPosition, error) { - return d.topology.selectMaxPositionInTopology(ctx, nil, roomID) + return d.Database.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) } // EventsAtTopologicalPosition returns all of the events matching a given @@ -296,7 +297,7 @@ func (d *SyncServerDatasource) MaxTopologicalPosition( func (d *SyncServerDatasource) EventsAtTopologicalPosition( ctx context.Context, roomID string, pos types.StreamPosition, ) ([]types.StreamEvent, error) { - eIDs, err := d.topology.selectEventIDsFromPosition(ctx, nil, roomID, pos) + eIDs, err := d.Database.Topology.SelectEventIDsFromPosition(ctx, nil, roomID, pos) if err != nil { return nil, err } @@ -307,7 +308,7 @@ func (d *SyncServerDatasource) EventsAtTopologicalPosition( func (d *SyncServerDatasource) EventPositionInTopology( ctx context.Context, eventID string, ) (depth types.StreamPosition, stream types.StreamPosition, err error) { - return d.topology.selectPositionInTopology(ctx, nil, eventID) + return d.Database.Topology.SelectPositionInTopology(ctx, nil, eventID) } // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. @@ -591,7 +592,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( var prevBatchStr string if len(recentStreamEvents) > 0 { var backwardTopologyPos, backwardStreamPos types.StreamPosition - backwardTopologyPos, backwardStreamPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) + backwardTopologyPos, backwardStreamPos, err = d.Database.Topology.SelectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) if err != nil { return } @@ -678,7 +679,7 @@ func (d *SyncServerDatasource) getBackwardTopologyPos( events []types.StreamEvent, ) (pos, spos types.StreamPosition) { if len(events) > 0 { - pos, spos, _ = d.topology.selectPositionInTopology(ctx, txn, events[0].EventID()) + pos, spos, _ = d.Database.Topology.SelectPositionInTopology(ctx, txn, events[0].EventID()) } // go to the previous position so we don't pull out the same event twice // FIXME: This could be done more nicely by being explicit with inclusive/exclusive rules diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 5c3ef9cc6..8f0b8b895 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -34,18 +34,18 @@ type Events interface { type Topology interface { // InsertEventInTopology inserts the given event in the room's topology, based // on the event's depth. - InsertEventInTopology(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) (err error) + InsertEventInTopology(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) (err error) // SelectEventIDsInRange selects the IDs of events which positions are within a // given range in a given room's topological order. // Returns an empty slice if no events match the given range. - SelectEventIDsInRange(ctx context.Context, roomID string, fromPos, toPos, toMicroPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error) + SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, fromPos, toPos, toMicroPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error) // SelectPositionInTopology returns the position of a given event in the // topology of the room it belongs to. - SelectPositionInTopology(ctx context.Context, eventID string) (pos, spos types.StreamPosition, err error) - SelectMaxPositionInTopology(ctx context.Context, roomID string) (pos types.StreamPosition, spos types.StreamPosition, err error) + SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (pos, spos types.StreamPosition, err error) + SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (pos types.StreamPosition, spos types.StreamPosition, err error) // SelectEventIDsFromPosition returns the IDs of all events that have a given // position in the topology of a given room. - SelectEventIDsFromPosition(ctx context.Context, roomID string, pos types.StreamPosition) (eventIDs []string, err error) + SelectEventIDsFromPosition(ctx context.Context, txn *sql.Tx, roomID string, pos types.StreamPosition) (eventIDs []string, err error) } type CurrentRoomState interface {