Ensure we have state entries for the auth events

This commit is contained in:
Neil Alexander 2022-05-31 17:21:26 +01:00
parent 8034de544b
commit 9b13b7ed37
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -822,10 +822,14 @@ func (v *StateResolution) resolveConflictsV2(
key := conflictedEvent.EventID() key := conflictedEvent.EventID()
// Store the newly found auth events in the auth set for this event. // Store the newly found auth events in the auth set for this event.
authSets[key], err = v.loadAuthEvents(ctx, conflictedEvent) var authEventMap map[string]types.StateEntry
authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for k, v := range authEventMap {
eventIDMap[k] = v
}
// Only add auth events into the authEvents slice once, otherwise the // Only add auth events into the authEvents slice once, otherwise the
// check for the auth difference can become expensive and produce // check for the auth difference can become expensive and produce
@ -971,13 +975,14 @@ func (v *StateResolution) loadStateEvents(
return result, eventIDMap, nil return result, eventIDMap, nil
} }
// loadAuthEvents loads all of the auth events for a given event recursively. // loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events.
func (v *StateResolution) loadAuthEvents( func (v *StateResolution) loadAuthEvents(
ctx context.Context, event *gomatrixserverlib.Event, ctx context.Context, event *gomatrixserverlib.Event,
) ([]*gomatrixserverlib.Event, error) { ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
eventMap := map[string]struct{}{} eventMap := map[string]struct{}{}
var getEvents func(eventIDs []string) ([]*gomatrixserverlib.Event, error) var getEvents func(eventIDs []string) ([]types.Event, error)
getEvents = func(eventIDs []string) ([]*gomatrixserverlib.Event, error) { getEvents = func(eventIDs []string) ([]types.Event, error) {
lookup := make([]string, 0, len(event.AuthEventIDs())) lookup := make([]string, 0, len(event.AuthEventIDs()))
for _, eventID := range eventIDs { for _, eventID := range eventIDs {
if _, ok := eventMap[eventID]; ok { if _, ok := eventMap[eventID]; ok {
@ -992,19 +997,54 @@ func (v *StateResolution) loadAuthEvents(
if err != nil { if err != nil {
return nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) return nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
} }
result := make([]*gomatrixserverlib.Event, 0, len(events))
for _, event := range events {
result = append(result, event.Event)
eventMap[event.EventID()] = struct{}{} eventMap[event.EventID()] = struct{}{}
next, err := getEvents(event.AuthEventIDs()) next, err := getEvents(event.AuthEventIDs())
if err != nil { if err != nil {
return nil, err return nil, err
} }
result = append(result, next...) return append(events, next...), nil
} }
return result, nil authEvents, err := getEvents(event.AuthEventIDs())
if err != nil {
return nil, nil, fmt.Errorf("getEvents: %w", err)
} }
return getEvents(event.AuthEventIDs()) authEventTypes := map[string]struct{}{}
authEventStateKeys := map[string]struct{}{}
for _, authEvent := range authEvents {
authEventTypes[authEvent.Type()] = struct{}{}
authEventStateKeys[*authEvent.StateKey()] = struct{}{}
}
lookupAuthEventTypes := make([]string, 0, len(authEventTypes))
lookupAuthEventStateKeys := make([]string, 0, len(authEventStateKeys))
for eventType := range authEventTypes {
lookupAuthEventTypes = append(lookupAuthEventTypes, eventType)
}
for eventStateKey := range authEventStateKeys {
lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey)
}
eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err)
}
eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err)
}
stateEntryMap := map[string]types.StateEntry{}
for _, authEvent := range authEvents {
stateEntryMap[authEvent.EventID()] = types.StateEntry{
EventNID: authEvent.EventNID,
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: eventTypes[authEvent.Type()],
EventStateKeyNID: eventStateKeys[*authEvent.StateKey()],
},
}
}
nakedEvents := make([]*gomatrixserverlib.Event, 0, len(authEvents))
for _, authEvent := range authEvents {
nakedEvents = append(nakedEvents, authEvent.Event)
}
return nakedEvents, stateEntryMap, nil
} }
// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. // findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.