From 2c8058636092f56e6006202ee61a7203ac292d6f Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Mon, 21 Feb 2022 09:39:41 +0100 Subject: [PATCH] Simplify getting the required events --- syncapi/routing/context.go | 196 +++++++++----------------------- syncapi/routing/context_test.go | 27 ++--- 2 files changed, 66 insertions(+), 157 deletions(-) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index b8afcd0ba..0b97a774c 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -15,7 +15,6 @@ package routing import ( - "context" "database/sql" "encoding/json" "net/http" @@ -45,7 +44,7 @@ func Context( syncDB storage.Database, roomID, eventID string, ) util.JSONResponse { - limit, filter, err := parseContextParams(req) + filter, err := parseContextParams(req) if err != nil { errMsg := "" switch err.(type) { @@ -60,6 +59,8 @@ func Context( Headers: nil, } } + filter.Rooms = append(filter.Rooms, roomID) + ctx := req.Context() membershipRes := roomserver.QueryMembershipForUserResponse{} membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} @@ -68,66 +69,27 @@ func Context( return jsonerror.InternalServerError() } - state, userAllowed, err := getCurrentState(ctx, rsAPI, roomID, device.UserID) - if err != nil { - return jsonerror.InternalServerError() - } - if !userAllowed { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("User is not allowed to query contenxt"), - } - } - id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) - if err != nil { - logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event") - return jsonerror.InternalServerError() + stateFilter := gomatrixserverlib.StateFilter{ + Limit: filter.Limit, + NotSenders: filter.NotSenders, + NotTypes: filter.NotTypes, + Senders: filter.Senders, + Types: filter.Types, + LazyLoadMembers: filter.LazyLoadMembers, + IncludeRedundantMembers: filter.IncludeRedundantMembers, + NotRooms: filter.NotRooms, + Rooms: filter.Rooms, + ContainsURL: filter.ContainsURL, } - eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, limit/2) - if err != nil && err != sql.ErrNoRows { - logrus.WithError(err).Error("unable to fetch before events") - return jsonerror.InternalServerError() - } - - _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, roomID, limit/2) - if err != nil && err != sql.ErrNoRows { - logrus.WithError(err).Error("unable to fetch after events") - return jsonerror.InternalServerError() - } - - /*excludeEventIDs, err := syncDB.SelectEventIDsAfter(ctx, roomID, lastID) - if err != nil { - logrus.WithError(err).Error("unable to fetch excludeEventIDs") - return jsonerror.InternalServerError() - } - - stateFilter := gomatrixserverlib.StateFilter{Limit: 100, Rooms: []string{roomID}} - if filter != nil { - stateFilter = gomatrixserverlib.StateFilter{ - Limit: filter.Limit, - NotSenders: filter.NotSenders, - NotTypes: filter.NotTypes, - Senders: filter.Senders, - Types: filter.Types, - LazyLoadMembers: filter.LazyLoadMembers, - IncludeRedundantMembers: filter.IncludeRedundantMembers, - NotRooms: filter.NotRooms, - Rooms: filter.Rooms, - ContainsURL: filter.ContainsURL, - } - } - _ = stateFilter - _ = excludeEventIDs - - sstate, _ := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) - for _, x := range sstate { + state, _ := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) + // verify the user is allowed to see the context for this room/event + for _, x := range state { hisVis, err := x.HistoryVisibility() if err != nil { continue } allowed := hisVis != "world_readable" && membershipRes.Membership == "join" - logrus.Debugf("State: %+v %+v %+v", x.Type(), hisVis, allowed) if !allowed { return util.JSONResponse{ Code: http.StatusForbidden, @@ -135,7 +97,24 @@ func Context( } } } - */ + + id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) + if err != nil { + logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event") + return jsonerror.InternalServerError() + } + + eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter.Limit/2) + if err != nil && err != sql.ErrNoRows { + logrus.WithError(err).Error("unable to fetch before events") + return jsonerror.InternalServerError() + } + + _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, roomID, filter.Limit/2) + if err != nil && err != sql.ErrNoRows { + logrus.WithError(err).Error("unable to fetch after events") + return jsonerror.InternalServerError() + } eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatAll) eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatAll) @@ -147,7 +126,7 @@ func Context( EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, Start: "start", - State: newState, + State: gomatrixserverlib.HeaderedToClientEvents(newState, gomatrixserverlib.FormatAll), } return util.JSONResponse{ @@ -156,118 +135,53 @@ func Context( } } -func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter, eventsBefore []gomatrixserverlib.ClientEvent, state []gomatrixserverlib.ClientEvent) []gomatrixserverlib.ClientEvent { +func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter, eventsBefore []gomatrixserverlib.ClientEvent, state []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { if filter == nil || !filter.LazyLoadMembers { - logrus.Debugf("filter is nil or lazyloadmembers is false") return state } allEvents := append(eventsAfter, eventsBefore...) x := make(map[string]bool) // get members who actually send an event for _, e := range allEvents { - if filter.LazyLoadMembers { - x[e.Sender] = true - } + x[e.Sender] = true } - // apply lazy_load_members - if filter.LazyLoadMembers { - newState := []gomatrixserverlib.ClientEvent{} - for _, event := range state { - if event.Type != gomatrixserverlib.MRoomMember { + newState := []*gomatrixserverlib.HeaderedEvent{} + for _, event := range state { + if event.Type() != gomatrixserverlib.MRoomMember { + newState = append(newState, event) + } else { + // did the user send an event? + if x[event.Sender()] { newState = append(newState, event) - } else { - // did the user send an event? - if x[event.Sender] { - newState = append(newState, event) - } } } - return newState } - return state + return newState } -// getCurrentState returns the current state of the requested room -func getCurrentState(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, roomID, userID string) (events []gomatrixserverlib.ClientEvent, userAllowed bool, err error) { +func parseContextParams(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { + // Default room filter + filter := &gomatrixserverlib.RoomEventFilter{Limit: 10} - avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} - nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} - canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} - topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} - guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} - visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} - joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} - - // get the current state - - currentState := &roomserver.QueryCurrentStateResponse{} - if err := rsAPI.QueryCurrentState(ctx, &roomserver.QueryCurrentStateRequest{ - RoomID: roomID, - StateTuples: []gomatrixserverlib.StateKeyTuple{ - avatarTuple, nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, - }, - }, currentState); err != nil { - logrus.WithField("roomID", roomID).WithError(err).Error("unable to fetch current state") - return nil, true, err - } - - // get all room members - roomMembers := roomserver.QueryMembershipsForRoomResponse{} - if err := rsAPI.QueryMembershipsForRoom(ctx, &roomserver.QueryMembershipsForRoomRequest{ - RoomID: roomID, - Sender: userID, - }, &roomMembers); err != nil { - logrus.WithField("roomID", roomID).WithError(err).Error("unable to fetch room members") - return nil, true, err - } - - state := []gomatrixserverlib.ClientEvent{} - for _, ev := range roomMembers.JoinEvents { - state = append(state, ev) - } - - membershipRes := roomserver.QueryMembershipForUserResponse{} - membershipReq := roomserver.QueryMembershipForUserRequest{UserID: userID, RoomID: roomID} - if err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { - return nil, true, err - } - - for tuple, event := range currentState.StateEvents { - // check that the user is allowed to view the context - if tuple == visibilityTuple { - hisVis, err := event.HistoryVisibility() - if err != nil { - return nil, true, err - } - allowed := hisVis != "world_readable" && membershipRes.Membership == "join" - if !allowed { - return nil, false, nil - } - } - state = append(state, gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll)) - } - return state, true, nil -} - -func parseContextParams(req *http.Request) (limit int, filter *gomatrixserverlib.RoomEventFilter, err error) { l := req.URL.Query().Get("limit") f := req.URL.Query().Get("filter") - limit = 10 if l != "" { - limit, err = strconv.Atoi(l) + limit, err := strconv.Atoi(l) if err != nil { - return 0, filter, err + return nil, err } // not in the spec, but feels like a good idea to have an upper bound limit if limit > 100 { limit = 100 } + filter.Limit = limit } if f != "" { if err := json.Unmarshal([]byte(f), &filter); err != nil { - return 0, filter, err + return nil, err } } - return limit, filter, nil + + return filter, nil } diff --git a/syncapi/routing/context_test.go b/syncapi/routing/context_test.go index 72c67ccbc..1b430d83a 100644 --- a/syncapi/routing/context_test.go +++ b/syncapi/routing/context_test.go @@ -19,30 +19,28 @@ func Test_parseContextParams(t *testing.T) { tests := []struct { name string req *http.Request - wantLimit int wantFilter *gomatrixserverlib.RoomEventFilter wantErr bool }{ { - name: "no params set", - req: noParamsReq, - wantLimit: 10, + name: "no params set", + req: noParamsReq, + wantFilter: &gomatrixserverlib.RoomEventFilter{Limit: 10}, }, { - name: "limit 2 param set", - req: limit2Req, - wantLimit: 2, + name: "limit 2 param set", + req: limit2Req, + wantFilter: &gomatrixserverlib.RoomEventFilter{Limit: 2}, }, { - name: "limit 10000 param set", - req: limit10000Req, - wantLimit: 100, + name: "limit 10000 param set", + req: limit10000Req, + wantFilter: &gomatrixserverlib.RoomEventFilter{Limit: 100}, }, { name: "filter lazy_load_members param set", req: lazyLoadReq, - wantLimit: 2, - wantFilter: &gomatrixserverlib.RoomEventFilter{LazyLoadMembers: true}, + wantFilter: &gomatrixserverlib.RoomEventFilter{Limit: 2, LazyLoadMembers: true}, }, { name: "invalid limit req", @@ -57,14 +55,11 @@ func Test_parseContextParams(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotLimit, gotFilter, err := parseContextParams(tt.req) + gotFilter, err := parseContextParams(tt.req) if (err != nil) != tt.wantErr { t.Errorf("parseContextParams() error = %v, wantErr %v", err, tt.wantErr) return } - if gotLimit != tt.wantLimit { - t.Errorf("parseContextParams() gotLimit = %v, want %v", gotLimit, tt.wantLimit) - } if !reflect.DeepEqual(gotFilter, tt.wantFilter) { t.Errorf("parseContextParams() gotFilter = %v, want %v", gotFilter, tt.wantFilter) }