Add method to fetch the membership at a given event

This commit is contained in:
Till Faelligen 2022-07-22 07:15:34 +02:00
parent bde6f67932
commit e24dcaa205
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
8 changed files with 143 additions and 4 deletions

View file

@ -97,6 +97,12 @@ type SyncRoomserverAPI interface {
req *PerformBackfillRequest,
res *PerformBackfillResponse,
) error
QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembersipAtEventRequest,
response *QueryMembersipAtEventResponse,
) error
}
type AppserviceRoomserverAPI interface {

View file

@ -373,6 +373,16 @@ func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed(
return err
}
func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembersipAtEventRequest,
response *QueryMembersipAtEventResponse,
) error {
err := t.Impl.QueryMembershipAtEvent(ctx, request, response)
util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response))
return err
}
func js(thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {

View file

@ -427,3 +427,13 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
}
return nil
}
type QueryMembersipAtEventRequest struct {
RoomID string
EventIDs []string
UserID string
}
type QueryMembersipAtEventResponse struct {
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
}

View file

@ -197,6 +197,12 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, stateKeyNID types.EventStateKeyNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventID, stateKeyNID)
}
func LoadEvents(
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) {

View file

@ -203,6 +203,52 @@ func (r *Queryer) QueryMembershipForUser(
return err
}
func (r *Queryer) QueryMembershipAtEvent(
ctx context.Context,
request *api.QueryMembersipAtEventRequest,
response *api.QueryMembersipAtEventResponse,
) error {
response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err)
}
if info == nil {
return fmt.Errorf("no roomInfo found")
}
// get the users stateKeyNID
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID})
if err != nil {
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err)
}
if _, ok := stateKeyNIDs[request.UserID]; !ok {
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
}
for _, eventID := range request.EventIDs {
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, eventID, stateKeyNIDs[request.UserID])
if err != nil {
return fmt.Errorf("unable to get state before event: %w", err)
}
memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, false)
if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err)
}
res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
for i := range memberships {
ev := memberships[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) {
res = append(res, ev.Headered(info.RoomVersion))
}
}
response.Memberships[eventID] = res
}
return nil
}
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipsForRoom(
ctx context.Context,

View file

@ -12,7 +12,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
)
@ -63,6 +62,7 @@ const (
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
)
type httpRoomserverInternalAPI struct {
@ -594,3 +594,11 @@ func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, req *api.QueryMembersipAtEventRequest, res *api.QueryMembersipAtEventResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershiptAtEvent")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverQueryMembershipAtEventPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -496,4 +496,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(RoomserverQueryMembershipAtEventPath,
httputil.MakeInternalAPI("queryMembershipAtEventPath", func(req *http.Request) util.JSONResponse {
request := api.QueryMembersipAtEventRequest{}
response := api.QueryMembersipAtEventResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.QueryMembershipAtEvent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View file

@ -23,12 +23,11 @@ import (
"sync"
"time"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type StateResolutionStorage interface {
@ -124,6 +123,47 @@ func (v *StateResolution) LoadStateAtEvent(
return stateEntries, nil
}
func (v *StateResolution) LoadMembershipAtEvent(
ctx context.Context, eventID string, stateKeyNID types.EventStateKeyNID,
) ([]types.StateEntry, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent")
defer span.Finish()
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
}
if snapshotNID == 0 {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
}
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{snapshotNID})
if err != nil {
return nil, err
}
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0]
// Query the membership event for the user at the given stateblocks
stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{
{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
},
})
if err != nil {
return nil, err
}
var result []types.StateEntry
for _, x := range stateEntryLists {
if len(x.StateEntries) > 0 {
result = append(result, x.StateEntries...)
}
}
return result, nil
}
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
// and combines those snapshots together into a single list. At this point it is
// possible to run into duplicate (type, state key) tuples.