diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index e5a8d0b81..63dcaa413 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -1118,18 +1118,4 @@ func Setup( return SetReceipt(req, eduAPI, device, vars["roomId"], vars["receiptType"], vars["eventId"]) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomId}/context/{eventId}", - httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) - if err != nil { - return util.ErrorResponse(err) - } - - return Context( - req, device, - rsAPI, userAPI, - vars["roomId"], vars["eventId"], - ) - }), - ).Methods(http.MethodGet, http.MethodOptions) } diff --git a/clientapi/routing/context.go b/syncapi/routing/context.go similarity index 69% rename from clientapi/routing/context.go rename to syncapi/routing/context.go index 7ba7713d9..b8afcd0ba 100644 --- a/clientapi/routing/context.go +++ b/syncapi/routing/context.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -35,13 +36,13 @@ type ContextRespsonse struct { EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after,omitempty"` EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before,omitempty"` Start string `json:"start,omitempty"` - State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` + State []gomatrixserverlib.ClientEvent `json:"state"` } func Context( req *http.Request, device *userapi.Device, rsAPI roomserver.RoomserverInternalAPI, - userAPI userapi.UserInternalAPI, + syncDB storage.Database, roomID, eventID string, ) util.JSONResponse { limit, filter, err := parseContextParams(req) @@ -60,6 +61,12 @@ func Context( } } ctx := req.Context() + membershipRes := roomserver.QueryMembershipForUserResponse{} + membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID} + if err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil { + logrus.WithError(err).Error("unable to fo membership") + return jsonerror.InternalServerError() + } state, userAllowed, err := getCurrentState(ctx, rsAPI, roomID, device.UserID) if err != nil { @@ -71,41 +78,76 @@ func Context( JSON: jsonerror.Forbidden("User is not allowed to query contenxt"), } } - - requestedEvent := &roomserver.QueryEventsByIDResponse{} - if err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{ - EventIDs: []string{eventID}, - }, requestedEvent); err != nil { + id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) + if err != nil { + logrus.WithError(err).WithField("eventID", eventID).Error("unable to find requested event") return jsonerror.InternalServerError() } - if requestedEvent.Events == nil || len(requestedEvent.Events) == 0 { - logrus.WithField("eventID", eventID).Error("unable to find requested event") - return jsonerror.InternalServerError() - } - // this should be safe now - event := requestedEvent.Events[0] - eventsBefore, err := queryEventsBefore(rsAPI, ctx, event.PrevEventIDs(), limit) + eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, limit/2) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch before events") return jsonerror.InternalServerError() } - eventsAfter, err := queryEventsAfter(rsAPI, ctx, event.EventID(), limit) - if err != nil { + _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, roomID, limit/2) + if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch after events") return jsonerror.InternalServerError() } - state = applyLazyLoadMembers(filter, eventsAfter, eventsBefore, state) + /*excludeEventIDs, err := syncDB.SelectEventIDsAfter(ctx, roomID, lastID) + if err != nil { + logrus.WithError(err).Error("unable to fetch excludeEventIDs") + return jsonerror.InternalServerError() + } + + stateFilter := gomatrixserverlib.StateFilter{Limit: 100, Rooms: []string{roomID}} + if filter != nil { + stateFilter = gomatrixserverlib.StateFilter{ + Limit: filter.Limit, + NotSenders: filter.NotSenders, + NotTypes: filter.NotTypes, + Senders: filter.Senders, + Types: filter.Types, + LazyLoadMembers: filter.LazyLoadMembers, + IncludeRedundantMembers: filter.IncludeRedundantMembers, + NotRooms: filter.NotRooms, + Rooms: filter.Rooms, + ContainsURL: filter.ContainsURL, + } + } + _ = stateFilter + _ = excludeEventIDs + + sstate, _ := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) + for _, x := range sstate { + hisVis, err := x.HistoryVisibility() + if err != nil { + continue + } + allowed := hisVis != "world_readable" && membershipRes.Membership == "join" + logrus.Debugf("State: %+v %+v %+v", x.Type(), hisVis, allowed) + if !allowed { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("User is not allowed to query context"), + } + } + } + */ + + eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatAll) + eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatAll) + newState := applyLazyLoadMembers(filter, eventsAfterClient, eventsBeforeClient, state) response := ContextRespsonse{ End: "end", - Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll), - EventsAfter: eventsAfter, - EventsBefore: eventsBefore, + Event: gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll), + EventsAfter: eventsAfterClient, + EventsBefore: eventsBeforeClient, Start: "start", - State: state, + State: newState, } return util.JSONResponse{ @@ -116,6 +158,7 @@ func Context( func applyLazyLoadMembers(filter *gomatrixserverlib.RoomEventFilter, eventsAfter, eventsBefore []gomatrixserverlib.ClientEvent, state []gomatrixserverlib.ClientEvent) []gomatrixserverlib.ClientEvent { if filter == nil || !filter.LazyLoadMembers { + logrus.Debugf("filter is nil or lazyloadmembers is false") return state } allEvents := append(eventsAfter, eventsBefore...) @@ -157,6 +200,7 @@ func getCurrentState(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} // get the current state + currentState := &roomserver.QueryCurrentStateResponse{} if err := rsAPI.QueryCurrentState(ctx, &roomserver.QueryCurrentStateRequest{ RoomID: roomID, @@ -206,72 +250,6 @@ func getCurrentState(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI return state, true, nil } -// queryEventsAfter retrieves events that happened after a list of events. -// The function returns once the limit is reached or no new events can be found. -// TODO: inefficient -func queryEventsAfter( - rsAPI roomserver.RoomserverInternalAPI, - ctx context.Context, - eventID string, - limit int, -) ([]gomatrixserverlib.ClientEvent, error) { - result := []gomatrixserverlib.ClientEvent{} - for { - res := &roomserver.QueryEventsAfterEventIDesponse{} - if err := rsAPI.QueryEventsAfter(ctx, &roomserver.QueryEventsAfterEventIDRequest{EventIDs: eventID}, res); err != nil { - if err == sql.ErrNoRows { - return result, nil - } - return nil, err - } - if len(res.Events) > 0 { - for _, ev := range res.Events { - result = append(result, *ev) - eventID = ev.EventID - } - } - } -} - -// queryEventsBefore retrieves events that happened before a list of events. -// The function returns once the limit is reached or no new prevEvents can be found. -// TODO: inefficient -func queryEventsBefore( - rsAPI roomserver.RoomserverInternalAPI, - ctx context.Context, - prevEventIDs []string, - limit int, -) ([]gomatrixserverlib.ClientEvent, error) { - // query prev events - eventIDs := prevEventIDs - result := []*gomatrixserverlib.HeaderedEvent{} - for len(eventIDs) > 0 { - prevEvents := &roomserver.QueryEventsByIDResponse{} - if err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{ - EventIDs: eventIDs, - }, prevEvents); err != nil { - return gomatrixserverlib.HeaderedToClientEvents(result, gomatrixserverlib.FormatAll), err - } - // we didn't receive any events, return - if len(prevEvents.Events) == 0 { - return gomatrixserverlib.HeaderedToClientEvents(result, gomatrixserverlib.FormatAll), nil - } - // clear eventIDs to search for - eventIDs = []string{} - // append found events to result - for _, ev := range prevEvents.Events { - result = append(result, ev) - if len(result) >= limit { - return gomatrixserverlib.HeaderedToClientEvents(result, gomatrixserverlib.FormatAll), nil - } - // add prev to new eventIDs - eventIDs = append(eventIDs, ev.PrevEventIDs()...) - } - } - - return gomatrixserverlib.HeaderedToClientEvents(result, gomatrixserverlib.FormatAll), nil -} - func parseContextParams(req *http.Request) (limit int, filter *gomatrixserverlib.RoomEventFilter, err error) { l := req.URL.Query().Get("limit") f := req.URL.Query().Get("filter") diff --git a/clientapi/routing/context_test.go b/syncapi/routing/context_test.go similarity index 100% rename from clientapi/routing/context_test.go rename to syncapi/routing/context_test.go diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 005a33555..be366ba10 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -77,4 +77,19 @@ func Setup( v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingKeyChangeRequest(req, device) })).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/rooms/{roomId}/context/{eventId}", + httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + + return Context( + req, device, + rsAPI, syncDB, + vars["roomId"], vars["eventId"], + ) + }), + ).Methods(http.MethodGet, http.MethodOptions) }