diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index d568101ce..10503facd 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -48,7 +48,7 @@ type messagesReq struct { filter *gomatrixserverlib.RoomEventFilter } -type MessageResp struct { +type messagesResp struct { Start string `json:"start"` StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token End string `json:"end"` @@ -201,25 +201,11 @@ func OnIncomingMessagesRequest( return jsonerror.InternalServerError() } - // apply history_visibility filter - clientEventsNew := []gomatrixserverlib.ClientEvent{} - var stateForEvents internal.Visibility - stateForEvents, err = internal.GetStateForEvents(req.Context(), db, clientEvents, device.UserID) - if err != nil { - logrus.WithError(err).Error("internal.GetStateForEvents failed") - return jsonerror.InternalServerError() - } - for _, ev := range clientEvents { - if stateForEvents.Allowed(ev.EventID) { - clientEventsNew = append(clientEventsNew, ev) - } - } - // at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set state := []gomatrixserverlib.ClientEvent{} if filter.LazyLoadMembers { membershipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent) - for _, evt := range clientEventsNew { + for _, evt := range clientEvents { // Don't add membership events the client should already know about if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, roomID, evt.Sender); cached { continue @@ -239,8 +225,6 @@ func OnIncomingMessagesRequest( } } - logrus.Debugf("Events after filtering: %d vs %d", len(clientEvents), len(clientEventsNew)) - util.GetLogger(req.Context()).WithFields(logrus.Fields{ "from": from.String(), "to": to.String(), @@ -250,8 +234,8 @@ func OnIncomingMessagesRequest( "return_end": end.String(), }).Info("Responding") - res := MessageResp{ - Chunk: clientEventsNew, + res := messagesResp{ + Chunk: clientEvents, Start: start.String(), End: end.String(), State: state, @@ -267,6 +251,23 @@ func OnIncomingMessagesRequest( } } +func (r *messagesReq) applyHistoryVisibilityFilter( + clientEvents []gomatrixserverlib.ClientEvent, + userID string, +) ([]gomatrixserverlib.ClientEvent, error) { + clientEventsFiltered := []gomatrixserverlib.ClientEvent{} + stateForEvents, err := internal.GetStateForEvents(r.ctx, r.db, clientEvents, userID) + if err != nil { + return clientEventsFiltered, err + } + for _, ev := range clientEvents { + if stateForEvents.Allowed(ev.EventID) { + clientEventsFiltered = append(clientEventsFiltered, ev) + } + } + return clientEventsFiltered, nil +} + func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (forgotten bool, exists bool, err error) { req := api.QueryMembershipForUserRequest{ RoomID: roomID, @@ -324,6 +325,9 @@ func (r *messagesReq) retrieveEvents() ( // reliable way to define it), it would be easier and less troublesome to // only have to change it in one place, i.e. the database. start, end, err = r.getStartEnd(events) + if err != nil { + return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, err + } // Sort the events to ensure we send them in the right order. if r.backwardOrdering { @@ -341,8 +345,9 @@ func (r *messagesReq) retrieveEvents() ( return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil } - // Convert all of the events into client events. - clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll) + // Convert all events into client events and filter them. + clientEvents, err = r.applyHistoryVisibilityFilter(gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll), r.device.UserID) + return clientEvents, start, end, err }