diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 70a84a9c5..b5f8d65ff 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -20,7 +20,6 @@ import ( "encoding/json" "errors" "fmt" - "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -240,17 +239,42 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("unable to get state before event: %w", err) } - start := time.Now() + // If we only have one or less state entries, we can short circuit the below + // loop and avoid hitting the database + allStateEventNIDs := make(map[types.EventNID]types.StateEntry) for _, eventID := range request.EventIDs { + stateEntry := stateEntries[eventID] + for _, s := range stateEntry { + allStateEventNIDs[s.EventNID] = s + } + } + + var canShortCircuit bool + if len(allStateEventNIDs) <= 1 { + canShortCircuit = true + } + + var memberships []types.Event + for i, eventID := range request.EventIDs { stateEntry, ok := stateEntries[eventID] if !ok { response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{} continue } - memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + + // If we can short circuit, e.g. we only have 0 or 1 membership events, we only get the memberships + // once. If we have more than one membership event, we need to get the state for each state entry. + if canShortCircuit { + if i == 0 { + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + } + } else { + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + } if err != nil { return fmt.Errorf("unable to get memberships at state: %w", err) } + res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) for i := range memberships { @@ -261,7 +285,6 @@ func (r *Queryer) QueryMembershipAtEvent( } response.Memberships[eventID] = res } - logrus.Debugf("XXX: GetMembershipsAtState duration: %s", time.Since(start)) return nil } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index f6ee2e437..83e048325 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -27,7 +27,6 @@ import ( "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -132,13 +131,11 @@ func (v *StateResolution) LoadMembershipAtEvent( span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent") defer span.Finish() - // De-dupe snapshotNIDs - start := time.Now() + // Get a mapping from snapshotNID -> eventIDs snapshotNIDMap, err := v.db.BulkSelectSnapshotsFromEventIDs(ctx, eventIDs) if err != nil { return nil, err } - logrus.Debugf("XXX: duration to lookup snapshot nids: %s", time.Since(start)) snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap)) for nid := range snapshotNIDMap { @@ -155,7 +152,6 @@ func (v *StateResolution) LoadMembershipAtEvent( wantStateBlocks = append(wantStateBlocks, x.StateBlockNIDs...) } - start = time.Now() stateEntryLists, err := v.db.StateEntriesForTuples(ctx, uniqueStateBlockNIDs(wantStateBlocks), []types.StateKeyTuple{ { EventTypeNID: types.MRoomMemberNID, @@ -165,13 +161,11 @@ func (v *StateResolution) LoadMembershipAtEvent( if err != nil { return nil, err } - logrus.Debugf("XXX: duration to lookup StateEntriesForTuples: %s", time.Since(start)) stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) stateEntriesMap := stateEntryListMap(stateEntryLists) result := make(map[string][]types.StateEntry) - start = time.Now() for _, stateBlockNIDList := range stateBlockNIDLists { stateBlockNIDs, ok := stateBlockNIDsMap.lookup(stateBlockNIDList.StateSnapshotNID) if !ok { @@ -195,7 +189,6 @@ func (v *StateResolution) LoadMembershipAtEvent( } } } - logrus.Debugf("XXX: duration to generate list: %s", time.Since(start)) return result, nil }