diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index 1cfd7ec21..9f5316aed 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -84,8 +84,29 @@ func (v Visibility) Allowed(eventID string) (allowed bool) { } } -// GetStateForEvents returns a Visibility map containing the state before and at the given events. -func GetStateForEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.ClientEvent, userID string) (Visibility, error) { +// ApplyHistoryVisibilityFilter applies the room history visibility filter on gomatrixserverlib.ClientEvents. +// Returns the filtered events and an error, if any. +func ApplyHistoryVisibilityFilter( + ctx context.Context, + syncDB storage.Database, + clientEvents []gomatrixserverlib.ClientEvent, + userID string, +) ([]gomatrixserverlib.ClientEvent, error) { + clientEventsFiltered := []gomatrixserverlib.ClientEvent{} + stateForEvents, err := getStateForEvents(ctx, syncDB, clientEvents, userID) + if err != nil { + return clientEventsFiltered, err + } + for _, ev := range clientEvents { + if stateForEvents.Allowed(ev.EventID) { + clientEventsFiltered = append(clientEventsFiltered, ev) + } + } + return clientEventsFiltered, nil +} + +// getStateForEvents returns a Visibility map containing the state before and at the given events. +func getStateForEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.ClientEvent, userID string) (Visibility, error) { result := make(map[string]EventVisibility, len(events)) var ( membershipCurrent string diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 10503facd..110232581 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -251,23 +251,6 @@ func OnIncomingMessagesRequest( } } -func (r *messagesReq) applyHistoryVisibilityFilter( - clientEvents []gomatrixserverlib.ClientEvent, - userID string, -) ([]gomatrixserverlib.ClientEvent, error) { - clientEventsFiltered := []gomatrixserverlib.ClientEvent{} - stateForEvents, err := internal.GetStateForEvents(r.ctx, r.db, clientEvents, userID) - if err != nil { - return clientEventsFiltered, err - } - for _, ev := range clientEvents { - if stateForEvents.Allowed(ev.EventID) { - clientEventsFiltered = append(clientEventsFiltered, ev) - } - } - return clientEventsFiltered, nil -} - func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (forgotten bool, exists bool, err error) { req := api.QueryMembershipForUserRequest{ RoomID: roomID, @@ -346,7 +329,7 @@ func (r *messagesReq) retrieveEvents() ( } // Convert all events into client events and filter them. - clientEvents, err = r.applyHistoryVisibilityFilter(gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll), r.device.UserID) + clientEvents, err = internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll), r.device.UserID) return clientEvents, start, end, err }