diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 7f2c0bd68..0da32d29b 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -822,10 +822,14 @@ func (v *StateResolution) resolveConflictsV2( key := conflictedEvent.EventID() // Store the newly found auth events in the auth set for this event. - authSets[key], err = v.loadAuthEvents(ctx, conflictedEvent) + var authEventMap map[string]types.StateEntry + authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent) if err != nil { return nil, err } + for k, v := range authEventMap { + eventIDMap[k] = v + } // Only add auth events into the authEvents slice once, otherwise the // check for the auth difference can become expensive and produce @@ -971,13 +975,14 @@ func (v *StateResolution) loadStateEvents( return result, eventIDMap, nil } -// loadAuthEvents loads all of the auth events for a given event recursively. +// loadAuthEvents loads all of the auth events for a given event recursively, +// along with a map that contains state entries for all of the auth events. func (v *StateResolution) loadAuthEvents( ctx context.Context, event *gomatrixserverlib.Event, -) ([]*gomatrixserverlib.Event, error) { +) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { eventMap := map[string]struct{}{} - var getEvents func(eventIDs []string) ([]*gomatrixserverlib.Event, error) - getEvents = func(eventIDs []string) ([]*gomatrixserverlib.Event, error) { + var getEvents func(eventIDs []string) ([]types.Event, error) + getEvents = func(eventIDs []string) ([]types.Event, error) { lookup := make([]string, 0, len(event.AuthEventIDs())) for _, eventID := range eventIDs { if _, ok := eventMap[eventID]; ok { @@ -992,19 +997,54 @@ func (v *StateResolution) loadAuthEvents( if err != nil { return nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) } - result := make([]*gomatrixserverlib.Event, 0, len(events)) - for _, event := range events { - result = append(result, event.Event) - eventMap[event.EventID()] = struct{}{} - next, err := getEvents(event.AuthEventIDs()) - if err != nil { - return nil, err - } - result = append(result, next...) + eventMap[event.EventID()] = struct{}{} + next, err := getEvents(event.AuthEventIDs()) + if err != nil { + return nil, err } - return result, nil + return append(events, next...), nil } - return getEvents(event.AuthEventIDs()) + authEvents, err := getEvents(event.AuthEventIDs()) + if err != nil { + return nil, nil, fmt.Errorf("getEvents: %w", err) + } + authEventTypes := map[string]struct{}{} + authEventStateKeys := map[string]struct{}{} + for _, authEvent := range authEvents { + authEventTypes[authEvent.Type()] = struct{}{} + authEventStateKeys[*authEvent.StateKey()] = struct{}{} + } + lookupAuthEventTypes := make([]string, 0, len(authEventTypes)) + lookupAuthEventStateKeys := make([]string, 0, len(authEventStateKeys)) + for eventType := range authEventTypes { + lookupAuthEventTypes = append(lookupAuthEventTypes, eventType) + } + for eventStateKey := range authEventStateKeys { + lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey) + } + eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err) + } + eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err) + } + stateEntryMap := map[string]types.StateEntry{} + for _, authEvent := range authEvents { + stateEntryMap[authEvent.EventID()] = types.StateEntry{ + EventNID: authEvent.EventNID, + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypes[authEvent.Type()], + EventStateKeyNID: eventStateKeys[*authEvent.StateKey()], + }, + } + } + nakedEvents := make([]*gomatrixserverlib.Event, 0, len(authEvents)) + for _, authEvent := range authEvents { + nakedEvents = append(nakedEvents, authEvent.Event) + } + return nakedEvents, stateEntryMap, nil } // findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.