diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 769823e92..0b53dfa9e 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -94,9 +94,6 @@ const selectEarlyEventsSQL = "" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + " ORDER BY id ASC LIMIT $4" -const selectStreamPositionForEventIDSQL = "" + - "SELECT id FROM syncapi_output_room_events WHERE event_id = $1" - const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" @@ -114,14 +111,13 @@ const selectStateInRangeSQL = "" + " LIMIT $8" type outputRoomEventsStatements struct { - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - selectRecentEventsStmt *sql.Stmt - selectRecentEventsForSyncStmt *sql.Stmt - selectEarlyEventsStmt *sql.Stmt - selectStateInRangeStmt *sql.Stmt - selectStreamPositionForEventIDStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + selectRecentEventsStmt *sql.Stmt + selectRecentEventsForSyncStmt *sql.Stmt + selectEarlyEventsStmt *sql.Stmt + selectStateInRangeStmt *sql.Stmt } func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { @@ -150,18 +146,9 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { return } - if s.selectStreamPositionForEventIDStmt, err = db.Prepare(selectStreamPositionForEventIDSQL); err != nil { - return - } return } -func (s *outputRoomEventsStatements) selectStreamPositionForEventID(ctx context.Context, eventID string) (types.StreamPosition, error) { - var id int64 - err := s.selectStreamPositionForEventIDStmt.QueryRowContext(ctx, eventID).Scan(&id) - return types.StreamPosition(id), err -} - // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index f77365c8d..51cbd50d0 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -60,7 +60,7 @@ const selectEventIDsInRangeDESCSQL = "" + " ORDER BY topological_position DESC, stream_position DESC LIMIT $6" const selectPositionInTopologySQL = "" + - "SELECT topological_position FROM syncapi_output_room_events_topology" + + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" // Select the max topological position for the room, then sort by stream position and take the highest, @@ -163,8 +163,8 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( // topology of the room it belongs to. func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( ctx context.Context, eventID string, -) (pos types.StreamPosition, err error) { - err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos) +) (pos, spos types.StreamPosition, err error) { + err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) return } diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 744ae7b8d..a6de15178 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -320,12 +320,7 @@ func (d *SyncServerDatasource) EventsAtTopologicalPosition( func (d *SyncServerDatasource) EventPositionInTopology( ctx context.Context, eventID string, ) (depth types.StreamPosition, stream types.StreamPosition, err error) { - depth, err = d.topology.selectPositionInTopology(ctx, eventID) - if err != nil { - return - } - stream, err = d.events.selectStreamPositionForEventID(ctx, eventID) - return + return d.topology.selectPositionInTopology(ctx, eventID) } func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { @@ -591,8 +586,8 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // Retrieve the backward topology position, i.e. the position of the // oldest event in the room's topology. - var backwardTopologyPos types.StreamPosition - backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) + var backwardTopologyPos, backwardStreamPos types.StreamPosition + backwardTopologyPos, backwardStreamPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) if backwardTopologyPos-1 <= 0 { backwardTopologyPos = types.StreamPosition(1) } else { @@ -605,7 +600,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, ).String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true @@ -720,9 +715,9 @@ func (d *SyncServerDatasource) addInvitesToResponse( func (d *SyncServerDatasource) getBackwardTopologyPos( ctx context.Context, events []types.StreamEvent, -) (pos types.StreamPosition) { +) (pos, spos types.StreamPosition) { if len(events) > 0 { - pos, _ = d.topology.selectPositionInTopology(ctx, events[0].EventID()) + pos, spos, _ = d.topology.selectPositionInTopology(ctx, events[0].EventID()) } if pos-1 <= 0 { pos = types.StreamPosition(1) @@ -761,14 +756,14 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( } recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back - backwardTopologyPos := d.getBackwardTopologyPos(ctx, recentStreamEvents) + backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, recentStreamEvents) switch delta.membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, ).String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true @@ -781,7 +776,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // no longer in the room. lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, ).String() lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 83d7940ad..08299f64b 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -74,9 +74,6 @@ const selectEarlyEventsSQL = "" + const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" -const selectStreamPositionForEventIDSQL = "" + - "SELECT id FROM syncapi_output_room_events WHERE event_id = $1" - // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). /* $1 = oldPos, @@ -102,15 +99,14 @@ const selectStateInRangeSQL = "" + " LIMIT $8" // limit type outputRoomEventsStatements struct { - streamIDStatements *streamIDStatements - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - selectRecentEventsStmt *sql.Stmt - selectRecentEventsForSyncStmt *sql.Stmt - selectEarlyEventsStmt *sql.Stmt - selectStateInRangeStmt *sql.Stmt - selectStreamPositionForEventIDStmt *sql.Stmt + streamIDStatements *streamIDStatements + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + selectRecentEventsStmt *sql.Stmt + selectRecentEventsForSyncStmt *sql.Stmt + selectEarlyEventsStmt *sql.Stmt + selectStateInRangeStmt *sql.Stmt } func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { @@ -140,18 +136,9 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDState if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { return } - if s.selectStreamPositionForEventIDStmt, err = db.Prepare(selectStreamPositionForEventIDSQL); err != nil { - return - } return } -func (s *outputRoomEventsStatements) selectStreamPositionForEventID(ctx context.Context, eventID string) (types.StreamPosition, error) { - var id int64 - err := s.selectStreamPositionForEventIDStmt.QueryRowContext(ctx, eventID).Scan(&id) - return types.StreamPosition(id), err -} - // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index d64894663..0d313d7c6 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -57,7 +57,7 @@ const selectEventIDsInRangeDESCSQL = "" + " ORDER BY topological_position DESC, stream_position DESC LIMIT $6" const selectPositionInTopologySQL = "" + - "SELECT topological_position FROM syncapi_output_room_events_topology" + + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" const selectMaxPositionInTopologySQL = "" + @@ -157,9 +157,9 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( // topology of the room it belongs to. func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( ctx context.Context, txn *sql.Tx, eventID string, -) (pos types.StreamPosition, err error) { +) (pos types.StreamPosition, spos types.StreamPosition, err error) { stmt := common.TxStmt(txn, s.selectPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, eventID).Scan(&pos) + err = stmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) return } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index bdf943e08..7e8e4ff5d 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -374,12 +374,7 @@ func (d *SyncServerDatasource) EventsAtTopologicalPosition( func (d *SyncServerDatasource) EventPositionInTopology( ctx context.Context, eventID string, ) (depth types.StreamPosition, stream types.StreamPosition, err error) { - depth, err = d.topology.selectPositionInTopology(ctx, nil, eventID) - if err != nil { - return - } - stream, err = d.events.selectStreamPositionForEventID(ctx, eventID) - return + return d.topology.selectPositionInTopology(ctx, nil, eventID) } // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. @@ -657,8 +652,8 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // Retrieve the backward topology position, i.e. the position of the // oldest event in the room's topology. - var backwardTopologyPos types.StreamPosition - backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) + var backwardTopologyPos, backwardTopologyStreamPos types.StreamPosition + backwardTopologyPos, backwardTopologyStreamPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) if backwardTopologyPos-1 <= 0 { backwardTopologyPos = types.StreamPosition(1) } else { @@ -671,7 +666,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + types.PaginationTokenTypeTopology, backwardTopologyPos, backwardTopologyStreamPos, ).String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true @@ -818,10 +813,11 @@ func (d *SyncServerDatasource) addInvitesToResponse( func (d *SyncServerDatasource) getBackwardTopologyPos( ctx context.Context, txn *sql.Tx, events []types.StreamEvent, -) (pos types.StreamPosition) { +) (pos, spos types.StreamPosition) { if len(events) > 0 { - pos, _ = d.topology.selectPositionInTopology(ctx, txn, events[0].EventID()) + pos, spos, _ = d.topology.selectPositionInTopology(ctx, txn, events[0].EventID()) } + // TODO: I have no idea what this is doing. if pos-1 <= 0 { pos = types.StreamPosition(1) } else { @@ -859,14 +855,14 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( } recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) - backwardTopologyPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents) + backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents) switch delta.membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, ).String() jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true @@ -879,7 +875,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // no longer in the room. lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( - types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos, ).String() lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index a57d59176..378c1fe35 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -220,6 +220,49 @@ func TestSyncResponse(t *testing.T) { } } +func TestGetEventsInRangeWithPrevBatch(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) + positions := MustWriteEvents(t, db, events) + latest, err := db.SyncPosition(ctx) + if err != nil { + t.Fatalf("failed to get SyncPosition: %s", err) + } + from := types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeStream, positions[len(positions)-2], types.StreamPosition(0), + ) + + res, err := db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) + if err != nil { + t.Fatalf("failed to IncrementalSync with latest token") + } + roomRes, ok := res.Rooms.Join[testRoomID] + if !ok { + t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res) + } + // returns the last event "Message 10" + assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:])) + + prev := roomRes.Timeline.PrevBatch + if prev == "" { + t.Fatalf("IncrementalSync expected prev_batch token") + } + prevBatchToken, err := types.NewPaginationTokenFromString(prev) + if err != nil { + t.Fatalf("failed to NewPaginationTokenFromString : %s", err) + } + // backpaginate 5 messages starting at the latest position. + // head towards the beginning of time + to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + paginatedEvents, err := db.GetEventsInRange(ctx, prevBatchToken, to, testRoomID, 5, true) + if err != nil { + t.Fatalf("GetEventsInRange returned an error: %s", err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1])) +} + // The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token. func TestGetEventsInRangeWithStreamToken(t *testing.T) { t.Parallel()