From 54fbf5fdc906fe28edb25e86968f3e7375845293 Mon Sep 17 00:00:00 2001
From: Till Faelligen <davidf@element.io>
Date: Thu, 2 Jun 2022 12:57:49 +0200
Subject: [PATCH] Cleanup

---
 syncapi/routing/messages.go | 47 ++++++++++++++++++++-----------------
 1 file changed, 26 insertions(+), 21 deletions(-)

diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go
index d568101ce..a733f877e 100644
--- a/syncapi/routing/messages.go
+++ b/syncapi/routing/messages.go
@@ -48,7 +48,7 @@ type messagesReq struct {
 	filter           *gomatrixserverlib.RoomEventFilter
 }
 
-type MessageResp struct {
+type messagesResp struct {
 	Start       string                          `json:"start"`
 	StartStream string                          `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token
 	End         string                          `json:"end"`
@@ -201,25 +201,11 @@ func OnIncomingMessagesRequest(
 		return jsonerror.InternalServerError()
 	}
 
-	// apply history_visibility filter
-	clientEventsNew := []gomatrixserverlib.ClientEvent{}
-	var stateForEvents internal.Visibility
-	stateForEvents, err = internal.GetStateForEvents(req.Context(), db, clientEvents, device.UserID)
-	if err != nil {
-		logrus.WithError(err).Error("internal.GetStateForEvents failed")
-		return jsonerror.InternalServerError()
-	}
-	for _, ev := range clientEvents {
-		if stateForEvents.Allowed(ev.EventID) {
-			clientEventsNew = append(clientEventsNew, ev)
-		}
-	}
-
 	// at least fetch the membership events for the users returned in chunk if LazyLoadMembers is set
 	state := []gomatrixserverlib.ClientEvent{}
 	if filter.LazyLoadMembers {
 		membershipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent)
-		for _, evt := range clientEventsNew {
+		for _, evt := range clientEvents {
 			// Don't add membership events the client should already know about
 			if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, roomID, evt.Sender); cached {
 				continue
@@ -239,8 +225,6 @@ func OnIncomingMessagesRequest(
 		}
 	}
 
-	logrus.Debugf("Events after filtering: %d vs %d", len(clientEvents), len(clientEventsNew))
-
 	util.GetLogger(req.Context()).WithFields(logrus.Fields{
 		"from":         from.String(),
 		"to":           to.String(),
@@ -250,8 +234,8 @@ func OnIncomingMessagesRequest(
 		"return_end":   end.String(),
 	}).Info("Responding")
 
-	res := MessageResp{
-		Chunk: clientEventsNew,
+	res := messagesResp{
+		Chunk: clientEvents,
 		Start: start.String(),
 		End:   end.String(),
 		State: state,
@@ -267,6 +251,23 @@ func OnIncomingMessagesRequest(
 	}
 }
 
+func (r *messagesReq) applyHistoryVisibilityFilter(
+	clientEvents []gomatrixserverlib.ClientEvent,
+	userID string,
+) ([]gomatrixserverlib.ClientEvent, error) {
+	clientEventsNew := []gomatrixserverlib.ClientEvent{}
+	stateForEvents, err := internal.GetStateForEvents(r.ctx, r.db, clientEvents, userID)
+	if err != nil {
+		return clientEventsNew, err
+	}
+	for _, ev := range clientEvents {
+		if stateForEvents.Allowed(ev.EventID) {
+			clientEventsNew = append(clientEventsNew, ev)
+		}
+	}
+	return clientEventsNew, nil
+}
+
 func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (forgotten bool, exists bool, err error) {
 	req := api.QueryMembershipForUserRequest{
 		RoomID: roomID,
@@ -324,6 +325,9 @@ func (r *messagesReq) retrieveEvents() (
 	// reliable way to define it), it would be easier and less troublesome to
 	// only have to change it in one place, i.e. the database.
 	start, end, err = r.getStartEnd(events)
+	if err != nil {
+		return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, err
+	}
 
 	// Sort the events to ensure we send them in the right order.
 	if r.backwardOrdering {
@@ -342,7 +346,8 @@ func (r *messagesReq) retrieveEvents() (
 	}
 
 	// Convert all of the events into client events.
-	clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll)
+	clientEvents, err = r.applyHistoryVisibilityFilter(gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll), r.device.UserID)
+
 	return clientEvents, start, end, err
 }