Better mapping of stream positions to topological positions in /messages (#2263)

* Convert stream positions into topological positions for both `from` and `to` in `/messages`

* Hopefully it works now

* Remove unnecessary logging

* Return sane values if `StreamToTopologicalPosition` can't work out the right thing to do

* Revert logging change

* tweaks

* Fix `selectEventIDsInRangeASCSQL`

* Test `Getting messages going forward is limited for a departed room (SPEC-216)` was passing incorrectly so un-whitelist it
This commit is contained in:
Neil Alexander 2022-03-18 10:40:01 +00:00 committed by GitHub
parent 191486438c
commit 475d3c1af9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 128 additions and 84 deletions

View file

@ -41,7 +41,6 @@ type messagesReq struct {
roomID string roomID string
from *types.TopologyToken from *types.TopologyToken
to *types.TopologyToken to *types.TopologyToken
fromStream *types.StreamingToken
device *userapi.Device device *userapi.Device
wasToProvided bool wasToProvided bool
backwardOrdering bool backwardOrdering bool
@ -50,7 +49,7 @@ type messagesReq struct {
type messagesResp struct { type messagesResp struct {
Start string `json:"start"` Start string `json:"start"`
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token
End string `json:"end"` End string `json:"end"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
State []gomatrixserverlib.ClientEvent `json:"state"` State []gomatrixserverlib.ClientEvent `json:"state"`
@ -93,6 +92,7 @@ func OnIncomingMessagesRequest(
// Pagination tokens. // Pagination tokens.
var fromStream *types.StreamingToken var fromStream *types.StreamingToken
fromQuery := req.URL.Query().Get("from") fromQuery := req.URL.Query().Get("from")
toQuery := req.URL.Query().Get("to")
emptyFromSupplied := fromQuery == "" emptyFromSupplied := fromQuery == ""
if emptyFromSupplied { if emptyFromSupplied {
// NOTSPEC: We will pretend they used the latest sync token if no ?from= was provided. // NOTSPEC: We will pretend they used the latest sync token if no ?from= was provided.
@ -101,18 +101,6 @@ func OnIncomingMessagesRequest(
fromQuery = currPos.String() fromQuery = currPos.String()
} }
from, err := types.NewTopologyTokenFromString(fromQuery)
if err != nil {
fs, err2 := types.NewStreamTokenFromString(fromQuery)
fromStream = &fs
if err2 != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err2.Error()),
}
}
}
// Direction to return events from. // Direction to return events from.
dir := req.URL.Query().Get("dir") dir := req.URL.Query().Get("dir")
if dir != "b" && dir != "f" { if dir != "b" && dir != "f" {
@ -125,16 +113,43 @@ func OnIncomingMessagesRequest(
// to have one of the two accepted values (so dir == "f" <=> !backwardOrdering). // to have one of the two accepted values (so dir == "f" <=> !backwardOrdering).
backwardOrdering := (dir == "b") backwardOrdering := (dir == "b")
from, err := types.NewTopologyTokenFromString(fromQuery)
if err != nil {
var streamToken types.StreamingToken
if streamToken, err = types.NewStreamTokenFromString(fromQuery); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()),
}
} else {
fromStream = &streamToken
from, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering)
if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError()
}
}
}
// Pagination tokens. To is optional, and its default value depends on the // Pagination tokens. To is optional, and its default value depends on the
// direction ("b" or "f"). // direction ("b" or "f").
var to types.TopologyToken var to types.TopologyToken
wasToProvided := true wasToProvided := true
if s := req.URL.Query().Get("to"); len(s) > 0 { if len(toQuery) > 0 {
to, err = types.NewTopologyTokenFromString(s) to, err = types.NewTopologyTokenFromString(toQuery)
if err != nil { if err != nil {
return util.JSONResponse{ var streamToken types.StreamingToken
Code: http.StatusBadRequest, if streamToken, err = types.NewStreamTokenFromString(toQuery); err != nil {
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()), return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
}
} else {
to, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering)
if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError()
}
} }
} }
} else { } else {
@ -168,7 +183,6 @@ func OnIncomingMessagesRequest(
roomID: roomID, roomID: roomID,
from: &from, from: &from,
to: &to, to: &to,
fromStream: fromStream,
wasToProvided: wasToProvided, wasToProvided: wasToProvided,
filter: filter, filter: filter,
backwardOrdering: backwardOrdering, backwardOrdering: backwardOrdering,
@ -215,7 +229,7 @@ func OnIncomingMessagesRequest(
End: end.String(), End: end.String(),
State: state, State: state,
} }
if emptyFromSupplied { if fromStream != nil {
res.StartStream = fromStream.String() res.StartStream = fromStream.String()
} }
@ -251,17 +265,9 @@ func (r *messagesReq) retrieveEvents() (
eventFilter := r.filter eventFilter := r.filter
// Retrieve the events from the local database. // Retrieve the events from the local database.
var streamEvents []types.StreamEvent streamEvents, err := r.db.GetEventsInTopologicalRange(
if r.fromStream != nil { r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
toStream := r.to.StreamToken() )
streamEvents, err = r.db.GetEventsInStreamingRange(
r.ctx, r.fromStream, &toStream, r.roomID, eventFilter, r.backwardOrdering,
)
} else {
streamEvents, err = r.db.GetEventsInTopologicalRange(
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
)
}
if err != nil { if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err) err = fmt.Errorf("GetEventsInRange: %w", err)
return return

View file

@ -103,8 +103,6 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user // DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, eventFilter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event. // EventPositionInTopology returns the depth and stream position of the given event.
@ -149,4 +147,6 @@ type Database interface {
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)
} }

View file

@ -51,7 +51,7 @@ const selectEventIDsInRangeASCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" + "SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND (" + " WHERE room_id = $1 AND (" +
"(topological_position > $2 AND topological_position < $3) OR" + "(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" + "(topological_position = $4 AND stream_position >= $5)" +
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
const selectEventIDsInRangeDESCSQL = "" + const selectEventIDsInRangeDESCSQL = "" +
@ -76,13 +76,21 @@ const selectMaxPositionInTopologySQL = "" +
const deleteTopologyForRoomSQL = "" + const deleteTopologyForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
const selectStreamToTopologicalPositionAscSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
const selectStreamToTopologicalPositionDescSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;"
type outputRoomEventsTopologyStatements struct { type outputRoomEventsTopologyStatements struct {
insertEventInTopologyStmt *sql.Stmt insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
} }
func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
@ -109,6 +117,12 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
return nil, err return nil, err
} }
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -170,6 +184,19 @@ func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology(
return return
} }
// SelectStreamToTopologicalPosition returns the closest position of a given event
// in the topology of the room it belongs to from the given stream position.
func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (topoPos types.StreamPosition, err error) {
if backwardOrdering {
err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
} else {
err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
}
return
}
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (pos types.StreamPosition, spos types.StreamPosition, err error) { ) (pos types.StreamPosition, spos types.StreamPosition, err error) {

View file

@ -155,37 +155,6 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixse
return d.StreamEventsToEvents(nil, streamEvents), nil return d.StreamEventsToEvents(nil, streamEvents), nil
} }
// GetEventsInStreamingRange retrieves all of the events on a given ordering using the
// given extremities and limit.
func (d *Database) GetEventsInStreamingRange(
ctx context.Context,
from, to *types.StreamingToken,
roomID string, eventFilter *gomatrixserverlib.RoomEventFilter,
backwardOrdering bool,
) (events []types.StreamEvent, err error) {
r := types.Range{
From: from.PDUPosition,
To: to.PDUPosition,
Backwards: backwardOrdering,
}
if backwardOrdering {
// When using backward ordering, we want the most recent events first.
if events, _, err = d.OutputEvents.SelectRecentEvents(
ctx, nil, roomID, r, eventFilter, false, false,
); err != nil {
return
}
} else {
// When using forward ordering, we want the least recent events first.
if events, err = d.OutputEvents.SelectEarlyEvents(
ctx, nil, roomID, r, eventFilter,
); err != nil {
return
}
}
return events, err
}
func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsers(ctx) return d.CurrentRoomState.SelectJoinedUsers(ctx)
} }
@ -513,6 +482,26 @@ func (d *Database) EventPositionInTopology(
return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
} }
func (d *Database) StreamToTopologicalPosition(
ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (types.TopologyToken, error) {
topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, nil, roomID, streamPos, backwardOrdering)
switch {
case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward
return types.TopologyToken{PDUPosition: streamPos}, nil
case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward
topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID)
if err != nil {
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err)
}
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
case err != nil: // some other error happened
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err)
default:
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
}
}
func (d *Database) GetFilter( func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) (*gomatrixserverlib.Filter, error) {

View file

@ -47,7 +47,7 @@ const selectEventIDsInRangeASCSQL = "" +
"SELECT event_id FROM syncapi_output_room_events_topology" + "SELECT event_id FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND (" + " WHERE room_id = $1 AND (" +
"(topological_position > $2 AND topological_position < $3) OR" + "(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" + "(topological_position = $4 AND stream_position >= $5)" +
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
const selectEventIDsInRangeDESCSQL = "" + const selectEventIDsInRangeDESCSQL = "" +
@ -65,17 +65,22 @@ const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 ORDER BY stream_position DESC" " WHERE room_id = $1 ORDER BY stream_position DESC"
const deleteTopologyForRoomSQL = "" + const selectStreamToTopologicalPositionAscSQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
const selectStreamToTopologicalPositionDescSQL = "" +
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;"
type outputRoomEventsTopologyStatements struct { type outputRoomEventsTopologyStatements struct {
db *sql.DB db *sql.DB
insertEventInTopologyStmt *sql.Stmt insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt deleteTopologyForRoomStmt *sql.Stmt
selectStreamToTopologicalPositionAscStmt *sql.Stmt
selectStreamToTopologicalPositionDescStmt *sql.Stmt
} }
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
@ -101,7 +106,10 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err return nil, err
} }
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
return nil, err
}
if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil {
return nil, err return nil, err
} }
return s, nil return s, nil
@ -163,6 +171,19 @@ func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology(
return return
} }
// SelectStreamToTopologicalPosition returns the closest position of a given event
// in the topology of the room it belongs to from the given stream position.
func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
) (topoPos types.StreamPosition, err error) {
if backwardOrdering {
err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
} else {
err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
}
return
}
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (pos types.StreamPosition, spos types.StreamPosition, err error) { ) (pos types.StreamPosition, spos types.StreamPosition, err error) {

View file

@ -87,6 +87,8 @@ type Topology interface {
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
// DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely. // DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely.
DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
} }
type CurrentRoomState interface { type CurrentRoomState interface {

View file

@ -239,7 +239,6 @@ Inbound federation can query room alias directory
Outbound federation can query v2 /send_join Outbound federation can query v2 /send_join
Inbound federation can receive v2 /send_join Inbound federation can receive v2 /send_join
Message history can be paginated Message history can be paginated
Getting messages going forward is limited for a departed room (SPEC-216)
Backfill works correctly with history visibility set to joined Backfill works correctly with history visibility set to joined
Guest user cannot call /events globally Guest user cannot call /events globally
Guest users can join guest_access rooms Guest users can join guest_access rooms