mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-15 18:13:09 -06:00
Add method to fetch the membership at a given event
This commit is contained in:
parent
bde6f67932
commit
e24dcaa205
|
|
@ -97,6 +97,12 @@ type SyncRoomserverAPI interface {
|
||||||
req *PerformBackfillRequest,
|
req *PerformBackfillRequest,
|
||||||
res *PerformBackfillResponse,
|
res *PerformBackfillResponse,
|
||||||
) error
|
) error
|
||||||
|
|
||||||
|
QueryMembershipAtEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
request *QueryMembersipAtEventRequest,
|
||||||
|
response *QueryMembersipAtEventResponse,
|
||||||
|
) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type AppserviceRoomserverAPI interface {
|
type AppserviceRoomserverAPI interface {
|
||||||
|
|
|
||||||
|
|
@ -373,6 +373,16 @@ func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
request *QueryMembersipAtEventRequest,
|
||||||
|
response *QueryMembersipAtEventResponse,
|
||||||
|
) error {
|
||||||
|
err := t.Impl.QueryMembershipAtEvent(ctx, request, response)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func js(thing interface{}) string {
|
func js(thing interface{}) string {
|
||||||
b, err := json.Marshal(thing)
|
b, err := json.Marshal(thing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -427,3 +427,13 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryMembersipAtEventRequest struct {
|
||||||
|
RoomID string
|
||||||
|
EventIDs []string
|
||||||
|
UserID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryMembersipAtEventResponse struct {
|
||||||
|
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -197,6 +197,12 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
|
||||||
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, stateKeyNID types.EventStateKeyNID) ([]types.StateEntry, error) {
|
||||||
|
roomState := state.NewStateResolution(db, info)
|
||||||
|
// Fetch the state as it was when this event was fired
|
||||||
|
return roomState.LoadMembershipAtEvent(ctx, eventID, stateKeyNID)
|
||||||
|
}
|
||||||
|
|
||||||
func LoadEvents(
|
func LoadEvents(
|
||||||
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
|
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
|
||||||
) ([]*gomatrixserverlib.Event, error) {
|
) ([]*gomatrixserverlib.Event, error) {
|
||||||
|
|
|
||||||
|
|
@ -203,6 +203,52 @@ func (r *Queryer) QueryMembershipForUser(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryMembershipAtEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryMembersipAtEventRequest,
|
||||||
|
response *api.QueryMembersipAtEventResponse,
|
||||||
|
) error {
|
||||||
|
response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent)
|
||||||
|
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to get roomInfo: %w", err)
|
||||||
|
}
|
||||||
|
if info == nil {
|
||||||
|
return fmt.Errorf("no roomInfo found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the users stateKeyNID
|
||||||
|
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err)
|
||||||
|
}
|
||||||
|
if _, ok := stateKeyNIDs[request.UserID]; !ok {
|
||||||
|
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, eventID := range request.EventIDs {
|
||||||
|
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, eventID, stateKeyNIDs[request.UserID])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to get state before event: %w", err)
|
||||||
|
}
|
||||||
|
memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to get memberships at state: %w", err)
|
||||||
|
}
|
||||||
|
res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
|
||||||
|
|
||||||
|
for i := range memberships {
|
||||||
|
ev := memberships[i]
|
||||||
|
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) {
|
||||||
|
res = append(res, ev.Headered(info.RoomVersion))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response.Memberships[eventID] = res
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
|
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
|
||||||
func (r *Queryer) QueryMembershipsForRoom(
|
func (r *Queryer) QueryMembershipsForRoom(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/opentracing/opentracing-go"
|
"github.com/opentracing/opentracing-go"
|
||||||
)
|
)
|
||||||
|
|
@ -63,6 +62,7 @@ const (
|
||||||
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
|
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
|
||||||
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
|
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
|
||||||
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
|
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
|
||||||
|
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
|
||||||
)
|
)
|
||||||
|
|
||||||
type httpRoomserverInternalAPI struct {
|
type httpRoomserverInternalAPI struct {
|
||||||
|
|
@ -594,3 +594,11 @@ func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, req *api.QueryMembersipAtEventRequest, res *api.QueryMembersipAtEventResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershiptAtEvent")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryMembershipAtEventPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -496,4 +496,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQueryMembershipAtEventPath,
|
||||||
|
httputil.MakeInternalAPI("queryMembershipAtEventPath", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryMembersipAtEventRequest{}
|
||||||
|
response := api.QueryMembersipAtEventResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryMembershipAtEvent(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,12 +23,11 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/opentracing/opentracing-go"
|
"github.com/opentracing/opentracing-go"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type StateResolutionStorage interface {
|
type StateResolutionStorage interface {
|
||||||
|
|
@ -124,6 +123,47 @@ func (v *StateResolution) LoadStateAtEvent(
|
||||||
return stateEntries, nil
|
return stateEntries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *StateResolution) LoadMembershipAtEvent(
|
||||||
|
ctx context.Context, eventID string, stateKeyNID types.EventStateKeyNID,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
|
||||||
|
}
|
||||||
|
if snapshotNID == 0 {
|
||||||
|
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{snapshotNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
||||||
|
stateBlockNIDList := stateBlockNIDLists[0]
|
||||||
|
|
||||||
|
// Query the membership event for the user at the given stateblocks
|
||||||
|
stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{
|
||||||
|
{
|
||||||
|
EventTypeNID: types.MRoomMemberNID,
|
||||||
|
EventStateKeyNID: stateKeyNID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var result []types.StateEntry
|
||||||
|
for _, x := range stateEntryLists {
|
||||||
|
if len(x.StateEntries) > 0 {
|
||||||
|
result = append(result, x.StateEntries...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
|
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
|
||||||
// and combines those snapshots together into a single list. At this point it is
|
// and combines those snapshots together into a single list. At this point it is
|
||||||
// possible to run into duplicate (type, state key) tuples.
|
// possible to run into duplicate (type, state key) tuples.
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue