Correctly generate backpagination tokens for events which have the same depth (#996)

* Correctly generate backpagination tokens for events which have the same depth

With tests. Unfortunately the code around here is hard to understand.
There will be a subsequent PR which fixes this up now that we have a test
harness in place.

* Add postgres impl

* More linting

* Fix psql statement so it actually works
This commit is contained in:
Kegsay 2020-05-01 11:01:34 +01:00 committed by GitHub
parent e15f6676ac
commit b28674435e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 238 additions and 80 deletions

View file

@ -229,14 +229,14 @@ func (r *messagesReq) retrieveEvents() (
// change the way topological positions are defined (as depth isn't the most // change the way topological positions are defined (as depth isn't the most
// reliable way to define it), it would be easier and less troublesome to // reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database. // only have to change it in one place, i.e. the database.
startPos, err := r.db.EventPositionInTopology( startPos, startStreamPos, err := r.db.EventPositionInTopology(
r.ctx, events[0].EventID(), r.ctx, events[0].EventID(),
) )
if err != nil { if err != nil {
err = fmt.Errorf("EventPositionInTopology: for start event %s: %w", events[0].EventID(), err) err = fmt.Errorf("EventPositionInTopology: for start event %s: %w", events[0].EventID(), err)
return return
} }
endPos, err := r.db.EventPositionInTopology( endPos, endStreamPos, err := r.db.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(), r.ctx, events[len(events)-1].EventID(),
) )
if err != nil { if err != nil {
@ -246,10 +246,10 @@ func (r *messagesReq) retrieveEvents() (
// Generate pagination tokens to send to the client using the positions // Generate pagination tokens to send to the client using the positions
// retrieved previously. // retrieved previously.
start = types.NewPaginationTokenFromTypeAndPosition( start = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, startPos, 0, types.PaginationTokenTypeTopology, startPos, startStreamPos,
) )
end = types.NewPaginationTokenFromTypeAndPosition( end = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, endPos, 0, types.PaginationTokenTypeTopology, endPos, endStreamPos,
) )
if r.backwardOrdering { if r.backwardOrdering {
@ -407,13 +407,13 @@ func setToDefault(
// go 1 earlier than the first event so we correctly fetch the earliest event // go 1 earlier than the first event so we correctly fetch the earliest event
to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0)
} else { } else {
var pos types.StreamPosition var pos, stream types.StreamPosition
pos, err = db.MaxTopologicalPosition(ctx, roomID) pos, stream, err = db.MaxTopologicalPosition(ctx, roomID)
if err != nil { if err != nil {
return return
} }
to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, 0) to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, stream)
} }
return return

View file

@ -91,8 +91,8 @@ type Database interface {
// GetEventsInRange retrieves all of the events on a given ordering using the // GetEventsInRange retrieves all of the events on a given ordering using the
// given extremities and limit. // given extremities and limit.
GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth of the given event. // EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error) EventPositionInTopology(ctx context.Context, eventID string) (depth types.StreamPosition, stream types.StreamPosition, err error)
// EventsAtTopologicalPosition returns all of the events matching a given // EventsAtTopologicalPosition returns all of the events matching a given
// position in the topology of a given room. // position in the topology of a given room.
EventsAtTopologicalPosition(ctx context.Context, roomID string, pos types.StreamPosition) ([]types.StreamEvent, error) EventsAtTopologicalPosition(ctx context.Context, roomID string, pos types.StreamPosition) ([]types.StreamEvent, error)
@ -100,7 +100,7 @@ type Database interface {
// extremities we know of for a given room. // extremities we know of for a given room.
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error) BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error)
// MaxTopologicalPosition returns the highest topological position for a given room. // MaxTopologicalPosition returns the highest topological position for a given room.
MaxTopologicalPosition(ctx context.Context, roomID string) (types.StreamPosition, error) MaxTopologicalPosition(ctx context.Context, roomID string) (depth types.StreamPosition, stream types.StreamPosition, err error)
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.

View file

@ -94,6 +94,9 @@ const selectEarlyEventsSQL = "" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" + " WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4" " ORDER BY id ASC LIMIT $4"
const selectStreamPositionForEventIDSQL = "" +
"SELECT id FROM syncapi_output_room_events WHERE event_id = $1"
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events" "SELECT MAX(id) FROM syncapi_output_room_events"
@ -111,13 +114,14 @@ const selectStateInRangeSQL = "" +
" LIMIT $8" " LIMIT $8"
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt
selectStreamPositionForEventIDStmt *sql.Stmt
} }
func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
@ -146,9 +150,18 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
return return
} }
if s.selectStreamPositionForEventIDStmt, err = db.Prepare(selectStreamPositionForEventIDSQL); err != nil {
return
}
return return
} }
func (s *outputRoomEventsStatements) selectStreamPositionForEventID(ctx context.Context, eventID string) (types.StreamPosition, error) {
var id int64
err := s.selectStreamPositionForEventIDStmt.QueryRowContext(ctx, eventID).Scan(&id)
return types.StreamPosition(id), err
}
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.

View file

@ -32,35 +32,44 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology (
-- The place of the event in the room's topology. This can usually be determined -- The place of the event in the room's topology. This can usually be determined
-- from the event's depth. -- from the event's depth.
topological_position BIGINT NOT NULL, topological_position BIGINT NOT NULL,
stream_position BIGINT NOT NULL,
-- The 'room_id' key for the event. -- The 'room_id' key for the event.
room_id TEXT NOT NULL room_id TEXT NOT NULL
); );
-- The topological order will be used in events selection and ordering -- The topological order will be used in events selection and ordering
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id); CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id);
` `
const insertEventInTopologySQL = "" + const insertEventInTopologySQL = "" +
"INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" + "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" +
" VALUES ($1, $2, $3)" + " VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (topological_position, room_id) DO UPDATE SET event_id = $1" " ON CONFLICT (topological_position, stream_position, room_id) DO UPDATE SET event_id = $1"
const selectEventIDsInRangeASCSQL = "" + const selectEventIDsInRangeASCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" + "SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + " WHERE room_id = $1 AND" +
" ORDER BY topological_position ASC LIMIT $4" "(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" +
" ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
const selectEventIDsInRangeDESCSQL = "" + const selectEventIDsInRangeDESCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" + "SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + " WHERE room_id = $1 AND" +
" ORDER BY topological_position DESC LIMIT $4" "(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" +
" ORDER BY topological_position DESC, stream_position DESC LIMIT $6"
const selectPositionInTopologySQL = "" + const selectPositionInTopologySQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology" + "SELECT topological_position FROM syncapi_output_room_events_topology" +
" WHERE event_id = $1" " WHERE event_id = $1"
// Select the max topological position for the room, then sort by stream position and take the highest,
// returning both topological and stream positions.
const selectMaxPositionInTopologySQL = "" + const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1" " WHERE topological_position=(" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
") ORDER BY stream_position DESC LIMIT 1"
const selectEventIDsFromPositionSQL = "" + const selectEventIDsFromPositionSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" + "SELECT event_id FROM syncapi_output_room_events_topology" +
@ -104,10 +113,10 @@ func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) {
// insertEventInTopology inserts the given event in the room's topology, based // insertEventInTopology inserts the given event in the room's topology, based
// on the event's depth. // on the event's depth.
func (s *outputRoomEventsTopologyStatements) insertEventInTopology( func (s *outputRoomEventsTopologyStatements) insertEventInTopology(
ctx context.Context, event *gomatrixserverlib.HeaderedEvent, ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (err error) { ) (err error) {
_, err = s.insertEventInTopologyStmt.ExecContext( _, err = s.insertEventInTopologyStmt.ExecContext(
ctx, event.EventID(), event.Depth(), event.RoomID(), ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
) )
return return
} }
@ -116,7 +125,7 @@ func (s *outputRoomEventsTopologyStatements) insertEventInTopology(
// given range in a given room's topological order. // given range in a given room's topological order.
// Returns an empty slice if no events match the given range. // Returns an empty slice if no events match the given range.
func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
ctx context.Context, roomID string, fromPos, toPos types.StreamPosition, ctx context.Context, roomID string, fromPos, toPos, toMicroPos types.StreamPosition,
limit int, chronologicalOrder bool, limit int, chronologicalOrder bool,
) (eventIDs []string, err error) { ) (eventIDs []string, err error) {
// Decide on the selection's order according to whether chronological order // Decide on the selection's order according to whether chronological order
@ -129,7 +138,7 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
} }
// Query the event IDs. // Query the event IDs.
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, toPos, toMicroPos, limit)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// If no event matched the request, return an empty slice. // If no event matched the request, return an empty slice.
return []string{}, nil return []string{}, nil
@ -161,8 +170,8 @@ func (s *outputRoomEventsTopologyStatements) selectPositionInTopology(
func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (pos types.StreamPosition, err error) { ) (pos types.StreamPosition, spos types.StreamPosition, err error) {
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos) err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return return
} }

View file

@ -159,7 +159,7 @@ func (d *SyncServerDatasource) WriteEvent(
} }
pduPosition = pos pduPosition = pos
if err = d.topology.insertEventInTopology(ctx, ev); err != nil { if err = d.topology.insertEventInTopology(ctx, ev, pos); err != nil {
return err return err
} }
@ -240,12 +240,13 @@ func (d *SyncServerDatasource) GetEventsInRange(
if from.Type == types.PaginationTokenTypeTopology { if from.Type == types.PaginationTokenTypeTopology {
// Determine the backward and forward limit, i.e. the upper and lower // Determine the backward and forward limit, i.e. the upper and lower
// limits to the selection in the room's topology, from the direction. // limits to the selection in the room's topology, from the direction.
var backwardLimit, forwardLimit types.StreamPosition var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition
if backwardOrdering { if backwardOrdering {
// Backward ordering is antichronological (latest event to oldest // Backward ordering is antichronological (latest event to oldest
// one). // one).
backwardLimit = to.PDUPosition backwardLimit = to.PDUPosition
forwardLimit = from.PDUPosition forwardLimit = from.PDUPosition
forwardMicroLimit = from.EDUTypingPosition
} else { } else {
// Forward ordering is chronological (oldest event to latest one). // Forward ordering is chronological (oldest event to latest one).
backwardLimit = from.PDUPosition backwardLimit = from.PDUPosition
@ -255,7 +256,7 @@ func (d *SyncServerDatasource) GetEventsInRange(
// Select the event IDs from the defined range. // Select the event IDs from the defined range.
var eIDs []string var eIDs []string
eIDs, err = d.topology.selectEventIDsInRange( eIDs, err = d.topology.selectEventIDsInRange(
ctx, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering, ctx, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering,
) )
if err != nil { if err != nil {
return return
@ -301,7 +302,7 @@ func (d *SyncServerDatasource) BackwardExtremitiesForRoom(
func (d *SyncServerDatasource) MaxTopologicalPosition( func (d *SyncServerDatasource) MaxTopologicalPosition(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (types.StreamPosition, error) { ) (depth types.StreamPosition, stream types.StreamPosition, err error) {
return d.topology.selectMaxPositionInTopology(ctx, roomID) return d.topology.selectMaxPositionInTopology(ctx, roomID)
} }
@ -318,8 +319,13 @@ func (d *SyncServerDatasource) EventsAtTopologicalPosition(
func (d *SyncServerDatasource) EventPositionInTopology( func (d *SyncServerDatasource) EventPositionInTopology(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (types.StreamPosition, error) { ) (depth types.StreamPosition, stream types.StreamPosition, err error) {
return d.topology.selectPositionInTopology(ctx, eventID) depth, err = d.topology.selectPositionInTopology(ctx, eventID)
if err != nil {
return
}
stream, err = d.events.selectStreamPositionForEventID(ctx, eventID)
return
} }
func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) {

View file

@ -74,6 +74,9 @@ const selectEarlyEventsSQL = "" +
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events" "SELECT MAX(id) FROM syncapi_output_room_events"
const selectStreamPositionForEventIDSQL = "" +
"SELECT id FROM syncapi_output_room_events WHERE event_id = $1"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
/* /*
$1 = oldPos, $1 = oldPos,
@ -99,14 +102,15 @@ const selectStateInRangeSQL = "" +
" LIMIT $8" // limit " LIMIT $8" // limit
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt
selectStreamPositionForEventIDStmt *sql.Stmt
} }
func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
@ -136,9 +140,18 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDState
if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
return return
} }
if s.selectStreamPositionForEventIDStmt, err = db.Prepare(selectStreamPositionForEventIDSQL); err != nil {
return
}
return return
} }
func (s *outputRoomEventsStatements) selectStreamPositionForEventID(ctx context.Context, eventID string) (types.StreamPosition, error) {
var id int64
err := s.selectStreamPositionForEventIDStmt.QueryRowContext(ctx, eventID).Scan(&id)
return types.StreamPosition(id), err
}
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.

View file

@ -27,37 +27,42 @@ const outputRoomEventsTopologySchema = `
-- Stores output room events received from the roomserver. -- Stores output room events received from the roomserver.
CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology (
event_id TEXT PRIMARY KEY, event_id TEXT PRIMARY KEY,
topological_position BIGINT NOT NULL, topological_position BIGINT NOT NULL,
stream_position BIGINT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
UNIQUE(topological_position, room_id) UNIQUE(topological_position, room_id, stream_position)
); );
-- The topological order will be used in events selection and ordering -- The topological order will be used in events selection and ordering
-- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id); -- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id);
` `
const insertEventInTopologySQL = "" + const insertEventInTopologySQL = "" +
"INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" + "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" +
" VALUES ($1, $2, $3)" + " VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (topological_position, room_id) DO UPDATE SET event_id = $1" " ON CONFLICT DO NOTHING"
const selectEventIDsInRangeASCSQL = "" + const selectEventIDsInRangeASCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" + "SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + " WHERE room_id = $1 AND" +
" ORDER BY topological_position ASC LIMIT $4" "(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" +
" ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
const selectEventIDsInRangeDESCSQL = "" + const selectEventIDsInRangeDESCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" + "SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + " WHERE room_id = $1 AND" +
" ORDER BY topological_position DESC LIMIT $4" "(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" +
" ORDER BY topological_position DESC, stream_position DESC LIMIT $6"
const selectPositionInTopologySQL = "" + const selectPositionInTopologySQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology" + "SELECT topological_position FROM syncapi_output_room_events_topology" +
" WHERE event_id = $1" " WHERE event_id = $1"
const selectMaxPositionInTopologySQL = "" + const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" + "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1" " WHERE room_id = $1 ORDER BY stream_position DESC"
const selectEventIDsFromPositionSQL = "" + const selectEventIDsFromPositionSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" + "SELECT event_id FROM syncapi_output_room_events_topology" +
@ -101,11 +106,11 @@ func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) {
// insertEventInTopology inserts the given event in the room's topology, based // insertEventInTopology inserts the given event in the room's topology, based
// on the event's depth. // on the event's depth.
func (s *outputRoomEventsTopologyStatements) insertEventInTopology( func (s *outputRoomEventsTopologyStatements) insertEventInTopology(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (err error) { ) (err error) {
stmt := common.TxStmt(txn, s.insertEventInTopologyStmt) stmt := common.TxStmt(txn, s.insertEventInTopologyStmt)
_, err = stmt.ExecContext( _, err = stmt.ExecContext(
ctx, event.EventID(), event.Depth(), event.RoomID(), ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
) )
return return
} }
@ -115,7 +120,7 @@ func (s *outputRoomEventsTopologyStatements) insertEventInTopology(
// Returns an empty slice if no events match the given range. // Returns an empty slice if no events match the given range.
func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
fromPos, toPos types.StreamPosition, fromPos, toPos, toMicroPos types.StreamPosition,
limit int, chronologicalOrder bool, limit int, chronologicalOrder bool,
) (eventIDs []string, err error) { ) (eventIDs []string, err error) {
// Decide on the selection's order according to whether chronological order // Decide on the selection's order according to whether chronological order
@ -128,7 +133,7 @@ func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
} }
// Query the event IDs. // Query the event IDs.
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, toPos, toMicroPos, limit)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// If no event matched the request, return an empty slice. // If no event matched the request, return an empty slice.
return []string{}, nil return []string{}, nil
@ -160,9 +165,9 @@ func (s *outputRoomEventsTopologyStatements) selectPositionInTopology(
func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (pos types.StreamPosition, err error) { ) (pos types.StreamPosition, spos types.StreamPosition, err error) {
stmt := common.TxStmt(txn, s.selectMaxPositionInTopologyStmt) stmt := common.TxStmt(txn, s.selectMaxPositionInTopologyStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos) err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return return
} }

View file

@ -194,7 +194,7 @@ func (d *SyncServerDatasource) WriteEvent(
} }
pduPosition = pos pduPosition = pos
if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil { if err = d.topology.insertEventInTopology(ctx, txn, ev, pos); err != nil {
return err return err
} }
@ -281,14 +281,16 @@ func (d *SyncServerDatasource) GetEventsInRange(
// events must be retrieved from the rooms' topology table rather than the // events must be retrieved from the rooms' topology table rather than the
// table contaning the syncapi server's whole stream of events. // table contaning the syncapi server's whole stream of events.
if from.Type == types.PaginationTokenTypeTopology { if from.Type == types.PaginationTokenTypeTopology {
// TODO: ARGH CONFUSING
// Determine the backward and forward limit, i.e. the upper and lower // Determine the backward and forward limit, i.e. the upper and lower
// limits to the selection in the room's topology, from the direction. // limits to the selection in the room's topology, from the direction.
var backwardLimit, forwardLimit types.StreamPosition var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition
if backwardOrdering { if backwardOrdering {
// Backward ordering is antichronological (latest event to oldest // Backward ordering is antichronological (latest event to oldest
// one). // one).
backwardLimit = to.PDUPosition backwardLimit = to.PDUPosition
forwardLimit = from.PDUPosition forwardLimit = from.PDUPosition
forwardMicroLimit = from.EDUTypingPosition
} else { } else {
// Forward ordering is chronological (oldest event to latest one). // Forward ordering is chronological (oldest event to latest one).
backwardLimit = from.PDUPosition backwardLimit = from.PDUPosition
@ -298,7 +300,7 @@ func (d *SyncServerDatasource) GetEventsInRange(
// Select the event IDs from the defined range. // Select the event IDs from the defined range.
var eIDs []string var eIDs []string
eIDs, err = d.topology.selectEventIDsInRange( eIDs, err = d.topology.selectEventIDsInRange(
ctx, nil, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering, ctx, nil, roomID, backwardLimit, forwardLimit, forwardMicroLimit, limit, !backwardOrdering,
) )
if err != nil { if err != nil {
return return
@ -328,8 +330,7 @@ func (d *SyncServerDatasource) GetEventsInRange(
return return
} }
} }
return events, err
return
} }
// SyncPosition returns the latest positions for syncing. // SyncPosition returns the latest positions for syncing.
@ -353,7 +354,7 @@ func (d *SyncServerDatasource) BackwardExtremitiesForRoom(
// room. // room.
func (d *SyncServerDatasource) MaxTopologicalPosition( func (d *SyncServerDatasource) MaxTopologicalPosition(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (types.StreamPosition, error) { ) (types.StreamPosition, types.StreamPosition, error) {
return d.topology.selectMaxPositionInTopology(ctx, nil, roomID) return d.topology.selectMaxPositionInTopology(ctx, nil, roomID)
} }
@ -372,8 +373,13 @@ func (d *SyncServerDatasource) EventsAtTopologicalPosition(
func (d *SyncServerDatasource) EventPositionInTopology( func (d *SyncServerDatasource) EventPositionInTopology(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (types.StreamPosition, error) { ) (depth types.StreamPosition, stream types.StreamPosition, err error) {
return d.topology.selectPositionInTopology(ctx, nil, eventID) depth, err = d.topology.selectPositionInTopology(ctx, nil, eventID)
if err != nil {
return
}
stream, err = d.events.selectStreamPositionForEventID(ctx, eventID)
return
} }
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.

View file

@ -182,7 +182,7 @@ func TestSyncResponse(t *testing.T) {
// limit set to 5 // limit set to 5
return db.CompleteSync(ctx, testUserIDA, 5) return db.CompleteSync(ctx, testUserIDA, 5)
}, },
// want the last 5 events, NOT the last 10. // want the last 5 events
WantTimeline: events[len(events)-5:], WantTimeline: events[len(events)-5:],
// want all state for the room // want all state for the room
WantState: state, WantState: state,
@ -248,11 +248,11 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
db := MustCreateDatabase(t) db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
MustWriteEvents(t, db, events) MustWriteEvents(t, db, events)
latest, err := db.MaxTopologicalPosition(ctx, testRoomID) latest, latestStream, err := db.MaxTopologicalPosition(ctx, testRoomID)
if err != nil { if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
} }
from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, 0) from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, latestStream)
// head towards the beginning of time // head towards the beginning of time
to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0)
@ -265,6 +265,105 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
} }
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
// will appear FIRST when going backwards. This test creates a DAG like:
// .-----> Message ---.
// Create -> Membership --------> Message -------> Message
// `-----> Message ---`
// depth 1 2 3 4
//
// With a total depth of 4. It tests that:
// - Backpagination over the whole fork should include all messages and not leave any out.
// - Backpagination from the middle of the fork should not return duplicates (things later than the token).
func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
var events []gomatrixserverlib.HeaderedEvent
events = append(events, MustCreateEvent(t, testRoomID, nil, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, testUserIDA)),
Type: "m.room.create",
StateKey: &emptyStateKey,
Sender: testUserIDA,
Depth: int64(len(events) + 1),
}))
events = append(events, MustCreateEvent(t, testRoomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"membership":"join"}`)),
Type: "m.room.member",
StateKey: &testUserIDA,
Sender: testUserIDA,
Depth: int64(len(events) + 1),
}))
// fork the dag into three, same prev_events and depth
parent := []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}
depth := int64(len(events) + 1)
for i := 0; i < 3; i++ {
events = append(events, MustCreateEvent(t, testRoomID, parent, &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)),
Type: "m.room.message",
Sender: testUserIDA,
Depth: depth,
}))
}
// merge the fork, prev_events are all 3 messages, depth is increased by 1.
events = append(events, MustCreateEvent(t, testRoomID, events[len(events)-3:], &gomatrixserverlib.EventBuilder{
Content: []byte(fmt.Sprintf(`{"body":"Message merge"}`)),
Type: "m.room.message",
Sender: testUserIDA,
Depth: depth + 1,
}))
MustWriteEvents(t, db, events)
latestPos, latestStreamPos, err := db.EventPositionInTopology(ctx, events[len(events)-1].EventID())
if err != nil {
t.Fatalf("failed to get EventPositionInTopology: %s", err)
}
topoPos, streamPos, err := db.EventPositionInTopology(ctx, events[len(events)-3].EventID()) // Message 2
if err != nil {
t.Fatalf("failed to get EventPositionInTopology for event: %s", err)
}
fromLatest := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latestPos, latestStreamPos)
fromFork := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, topoPos, streamPos)
// head towards the beginning of time
to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0)
testCases := []struct {
Name string
From *types.PaginationToken
Limit int
Wants []gomatrixserverlib.HeaderedEvent
}{
{
Name: "Pagination over the whole fork",
From: fromLatest,
Limit: 5,
Wants: reversed(events[len(events)-5:]),
},
{
Name: "Paginating to the middle of the fork",
From: fromLatest,
Limit: 2,
Wants: reversed(events[len(events)-2:]),
},
{
Name: "Pagination FROM the middle of the fork",
From: fromFork,
Limit: 3,
Wants: reversed(events[len(events)-5 : len(events)-2]),
},
}
for _, tc := range testCases {
// backpaginate messages starting at the latest position.
paginatedEvents, err := db.GetEventsInRange(ctx, tc.From, to, testRoomID, tc.Limit, true)
if err != nil {
t.Fatalf("%s GetEventsInRange returned an error: %s", tc.Name, err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, tc.Name, true, gots, tc.Wants)
}
}
func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) {
if len(gots) != len(wants) { if len(gots) != len(wants) {
t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants))
@ -294,7 +393,8 @@ func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatr
t.Errorf("%s event[%d] unsigned mismatch: got %s want %s", msg, i, string(g.Unsigned), string(w.Unsigned())) t.Errorf("%s event[%d] unsigned mismatch: got %s want %s", msg, i, string(g.Unsigned), string(w.Unsigned()))
} }
if (g.StateKey == nil && w.StateKey() != nil) || (g.StateKey != nil && w.StateKey() == nil) { if (g.StateKey == nil && w.StateKey() != nil) || (g.StateKey != nil && w.StateKey() == nil) {
t.Fatalf("%s event[%d] state_key [not] missing: got %v want %v", msg, i, g.StateKey, w.StateKey()) t.Errorf("%s event[%d] state_key [not] missing: got %v want %v", msg, i, g.StateKey, w.StateKey())
continue
} }
if g.StateKey != nil { if g.StateKey != nil {
if !w.StateKeyEquals(*g.StateKey) { if !w.StateKeyEquals(*g.StateKey) {

View file

@ -64,8 +64,14 @@ const (
// /sync or /messages, for example. // /sync or /messages, for example.
type PaginationToken struct { type PaginationToken struct {
//Position StreamPosition //Position StreamPosition
Type PaginationTokenType Type PaginationTokenType
PDUPosition StreamPosition // For /sync, this is the PDU position. For /messages, this is the topological position (depth).
// TODO: Given how different the positions are depending on the token type, they should probably be renamed
// or use different structs altogether.
PDUPosition StreamPosition
// For /sync, this is the EDU position. For /messages, this is the stream (PDU) position.
// TODO: Given how different the positions are depending on the token type, they should probably be renamed
// or use different structs altogether.
EDUTypingPosition StreamPosition EDUTypingPosition StreamPosition
} }