From e24dcaa205b632295c2462bb13ce0eeb1afe41c7 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 22 Jul 2022 07:15:34 +0200 Subject: [PATCH] Add method to fetch the membership at a given event --- roomserver/api/api.go | 6 ++++ roomserver/api/api_trace.go | 10 ++++++ roomserver/api/query.go | 10 ++++++ roomserver/internal/helpers/helpers.go | 6 ++++ roomserver/internal/query/query.go | 46 ++++++++++++++++++++++++++ roomserver/inthttp/client.go | 10 +++++- roomserver/inthttp/server.go | 13 ++++++++ roomserver/state/state.go | 46 ++++++++++++++++++++++++-- 8 files changed, 143 insertions(+), 4 deletions(-) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 38baa617f..3ca2c565f 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -97,6 +97,12 @@ type SyncRoomserverAPI interface { req *PerformBackfillRequest, res *PerformBackfillResponse, ) error + + QueryMembershipAtEvent( + ctx context.Context, + request *QueryMembersipAtEventRequest, + response *QueryMembersipAtEventResponse, + ) error } type AppserviceRoomserverAPI interface { diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 211f320ff..81532e91a 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -373,6 +373,16 @@ func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed( return err } +func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent( + ctx context.Context, + request *QueryMembersipAtEventRequest, + response *QueryMembersipAtEventResponse, +) error { + err := t.Impl.QueryMembershipAtEvent(ctx, request, response) + util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index f157a9025..8c91592e7 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -427,3 +427,13 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { } return nil } + +type QueryMembersipAtEventRequest struct { + RoomID string + EventIDs []string + UserID string +} + +type QueryMembersipAtEventResponse struct { + Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` +} diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index e67bbfcaa..039f33f27 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -197,6 +197,12 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } +func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, stateKeyNID types.EventStateKeyNID) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info) + // Fetch the state as it was when this event was fired + return roomState.LoadMembershipAtEvent(ctx, eventID, stateKeyNID) +} + func LoadEvents( ctx context.Context, db storage.Database, eventNIDs []types.EventNID, ) ([]*gomatrixserverlib.Event, error) { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index da1b32530..ce3dfe2a2 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -203,6 +203,52 @@ func (r *Queryer) QueryMembershipForUser( return err } +func (r *Queryer) QueryMembershipAtEvent( + ctx context.Context, + request *api.QueryMembersipAtEventRequest, + response *api.QueryMembersipAtEventResponse, +) error { + response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent) + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return fmt.Errorf("unable to get roomInfo: %w", err) + } + if info == nil { + return fmt.Errorf("no roomInfo found") + } + + // get the users stateKeyNID + stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID}) + if err != nil { + return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err) + } + if _, ok := stateKeyNIDs[request.UserID]; !ok { + return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) + } + + for _, eventID := range request.EventIDs { + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, eventID, stateKeyNIDs[request.UserID]) + if err != nil { + return fmt.Errorf("unable to get state before event: %w", err) + } + memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, false) + if err != nil { + return fmt.Errorf("unable to get memberships at state: %w", err) + } + res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) + + for i := range memberships { + ev := memberships[i] + if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { + res = append(res, ev.Headered(info.RoomVersion)) + } + } + response.Memberships[eventID] = res + } + + return nil +} + // QueryMembershipsForRoom implements api.RoomserverInternalAPI func (r *Queryer) QueryMembershipsForRoom( ctx context.Context, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 2fa8afc49..465d00f47 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -12,7 +12,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/opentracing/opentracing-go" ) @@ -63,6 +62,7 @@ const ( RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" + RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" ) type httpRoomserverInternalAPI struct { @@ -594,3 +594,11 @@ func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api. return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, req *api.QueryMembersipAtEventRequest, res *api.QueryMembersipAtEventResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershiptAtEvent") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryMembershipAtEventPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 993381585..f9b9a0bd8 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -496,4 +496,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverQueryMembershipAtEventPath, + httputil.MakeInternalAPI("queryMembershipAtEventPath", func(req *http.Request) util.JSONResponse { + request := api.QueryMembersipAtEventRequest{} + response := api.QueryMembersipAtEventResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryMembershipAtEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index d1d24b099..ca79cd28d 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -23,12 +23,11 @@ import ( "sync" "time" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" - - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) type StateResolutionStorage interface { @@ -124,6 +123,47 @@ func (v *StateResolution) LoadStateAtEvent( return stateEntries, nil } +func (v *StateResolution) LoadMembershipAtEvent( + ctx context.Context, eventID string, stateKeyNID types.EventStateKeyNID, +) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent") + defer span.Finish() + + snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) + if err != nil { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err) + } + if snapshotNID == 0 { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) + } + + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{snapshotNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + // Query the membership event for the user at the given stateblocks + stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{ + { + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }, + }) + if err != nil { + return nil, err + } + + var result []types.StateEntry + for _, x := range stateEntryLists { + if len(x.StateEntries) > 0 { + result = append(result, x.StateEntries...) + } + } + return result, nil +} + // LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events // and combines those snapshots together into a single list. At this point it is // possible to run into duplicate (type, state key) tuples.