diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 2ef25e032..3250b0610 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -16,6 +16,7 @@ package routing import ( "context" + "encoding/json" "fmt" "net/http" "sort" @@ -45,7 +46,7 @@ type messagesReq struct { fromStream *types.StreamingToken device *userapi.Device wasToProvided bool - limit int + filter gomatrixserverlib.RoomEventFilter backwardOrdering bool } @@ -143,9 +144,21 @@ func OnIncomingMessagesRequest( wasToProvided = false } - // Maximum number of events to return; defaults to 10. - limit := defaultMessagesLimit + // RoomEventFilter + filter := gomatrixserverlib.DefaultRoomEventFilter() + filter.Limit = defaultMessagesLimit + if s := req.URL.Query().Get("filter"); len(s) > 0 { + if err = json.Unmarshal([]byte(s), &filter); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The filter object could not be decoded into valid JSON: " + err.Error()), + } + } + } + + // Maximum number of events to return. if len(req.URL.Query().Get("limit")) > 0 { + var limit int limit, err = strconv.Atoi(req.URL.Query().Get("limit")) if err != nil { @@ -154,8 +167,10 @@ func OnIncomingMessagesRequest( JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()), } } + + // Overwrite the limit set by the filter + filter.Limit = limit } - // TODO: Implement filtering (#587) // Check the room ID's format. if _, _, err = gomatrixserverlib.SplitID('!', roomID); err != nil { @@ -176,7 +191,7 @@ func OnIncomingMessagesRequest( to: &to, fromStream: fromStream, wasToProvided: wasToProvided, - limit: limit, + filter: filter, backwardOrdering: backwardOrdering, device: device, } @@ -190,7 +205,7 @@ func OnIncomingMessagesRequest( util.GetLogger(req.Context()).WithFields(logrus.Fields{ "from": from.String(), "to": to.String(), - "limit": limit, + "limit": filter.Limit, "backwards": backwardOrdering, "return_start": start.String(), "return_end": end.String(), @@ -234,19 +249,17 @@ func (r *messagesReq) retrieveEvents() ( clientEvents []gomatrixserverlib.ClientEvent, start, end types.TopologyToken, err error, ) { - eventFilter := gomatrixserverlib.DefaultRoomEventFilter() - eventFilter.Limit = r.limit // Retrieve the events from the local database. var streamEvents []types.StreamEvent if r.fromStream != nil { toStream := r.to.StreamToken() streamEvents, err = r.db.GetEventsInStreamingRange( - r.ctx, r.fromStream, &toStream, r.roomID, &eventFilter, r.backwardOrdering, + r.ctx, r.fromStream, &toStream, r.roomID, &r.filter, r.backwardOrdering, ) } else { streamEvents, err = r.db.GetEventsInTopologicalRange( - r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, + r.ctx, r.from, r.to, r.roomID, r.filter.Limit, r.backwardOrdering, ) } if err != nil { @@ -434,7 +447,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( // Check if we have backward extremities for this room. if len(backwardExtremities) > 0 { // If so, retrieve as much events as needed through backfilling. - events, err = r.backfill(r.roomID, backwardExtremities, r.limit) + events, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit) if err != nil { return } @@ -456,7 +469,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent events []*gomatrixserverlib.HeaderedEvent, err error, ) { // Check if we have enough events. - isSetLargeEnough := len(streamEvents) >= r.limit + isSetLargeEnough := len(streamEvents) >= r.filter.Limit if !isSetLargeEnough { // it might be fine we don't have up to 'limit' events, let's find out if r.backwardOrdering { @@ -483,7 +496,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { var pdus []*gomatrixserverlib.HeaderedEvent // Only ask the remote server for enough events to reach the limit. - pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents)) + pdus, err = r.backfill(r.roomID, backwardExtremities, r.filter.Limit-len(streamEvents)) if err != nil { return }