Ensure we have state entries for the auth events
This commit is contained in:
parent
8034de544b
commit
9b13b7ed37
|
@ -822,10 +822,14 @@ func (v *StateResolution) resolveConflictsV2(
|
||||||
key := conflictedEvent.EventID()
|
key := conflictedEvent.EventID()
|
||||||
|
|
||||||
// Store the newly found auth events in the auth set for this event.
|
// Store the newly found auth events in the auth set for this event.
|
||||||
authSets[key], err = v.loadAuthEvents(ctx, conflictedEvent)
|
var authEventMap map[string]types.StateEntry
|
||||||
|
authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
for k, v := range authEventMap {
|
||||||
|
eventIDMap[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
// Only add auth events into the authEvents slice once, otherwise the
|
// Only add auth events into the authEvents slice once, otherwise the
|
||||||
// check for the auth difference can become expensive and produce
|
// check for the auth difference can become expensive and produce
|
||||||
|
@ -971,13 +975,14 @@ func (v *StateResolution) loadStateEvents(
|
||||||
return result, eventIDMap, nil
|
return result, eventIDMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadAuthEvents loads all of the auth events for a given event recursively.
|
// loadAuthEvents loads all of the auth events for a given event recursively,
|
||||||
|
// along with a map that contains state entries for all of the auth events.
|
||||||
func (v *StateResolution) loadAuthEvents(
|
func (v *StateResolution) loadAuthEvents(
|
||||||
ctx context.Context, event *gomatrixserverlib.Event,
|
ctx context.Context, event *gomatrixserverlib.Event,
|
||||||
) ([]*gomatrixserverlib.Event, error) {
|
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
|
||||||
eventMap := map[string]struct{}{}
|
eventMap := map[string]struct{}{}
|
||||||
var getEvents func(eventIDs []string) ([]*gomatrixserverlib.Event, error)
|
var getEvents func(eventIDs []string) ([]types.Event, error)
|
||||||
getEvents = func(eventIDs []string) ([]*gomatrixserverlib.Event, error) {
|
getEvents = func(eventIDs []string) ([]types.Event, error) {
|
||||||
lookup := make([]string, 0, len(event.AuthEventIDs()))
|
lookup := make([]string, 0, len(event.AuthEventIDs()))
|
||||||
for _, eventID := range eventIDs {
|
for _, eventID := range eventIDs {
|
||||||
if _, ok := eventMap[eventID]; ok {
|
if _, ok := eventMap[eventID]; ok {
|
||||||
|
@ -992,19 +997,54 @@ func (v *StateResolution) loadAuthEvents(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
|
return nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
|
||||||
}
|
}
|
||||||
result := make([]*gomatrixserverlib.Event, 0, len(events))
|
eventMap[event.EventID()] = struct{}{}
|
||||||
for _, event := range events {
|
next, err := getEvents(event.AuthEventIDs())
|
||||||
result = append(result, event.Event)
|
if err != nil {
|
||||||
eventMap[event.EventID()] = struct{}{}
|
return nil, err
|
||||||
next, err := getEvents(event.AuthEventIDs())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
result = append(result, next...)
|
|
||||||
}
|
}
|
||||||
return result, nil
|
return append(events, next...), nil
|
||||||
}
|
}
|
||||||
return getEvents(event.AuthEventIDs())
|
authEvents, err := getEvents(event.AuthEventIDs())
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("getEvents: %w", err)
|
||||||
|
}
|
||||||
|
authEventTypes := map[string]struct{}{}
|
||||||
|
authEventStateKeys := map[string]struct{}{}
|
||||||
|
for _, authEvent := range authEvents {
|
||||||
|
authEventTypes[authEvent.Type()] = struct{}{}
|
||||||
|
authEventStateKeys[*authEvent.StateKey()] = struct{}{}
|
||||||
|
}
|
||||||
|
lookupAuthEventTypes := make([]string, 0, len(authEventTypes))
|
||||||
|
lookupAuthEventStateKeys := make([]string, 0, len(authEventStateKeys))
|
||||||
|
for eventType := range authEventTypes {
|
||||||
|
lookupAuthEventTypes = append(lookupAuthEventTypes, eventType)
|
||||||
|
}
|
||||||
|
for eventStateKey := range authEventStateKeys {
|
||||||
|
lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey)
|
||||||
|
}
|
||||||
|
eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err)
|
||||||
|
}
|
||||||
|
eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err)
|
||||||
|
}
|
||||||
|
stateEntryMap := map[string]types.StateEntry{}
|
||||||
|
for _, authEvent := range authEvents {
|
||||||
|
stateEntryMap[authEvent.EventID()] = types.StateEntry{
|
||||||
|
EventNID: authEvent.EventNID,
|
||||||
|
StateKeyTuple: types.StateKeyTuple{
|
||||||
|
EventTypeNID: eventTypes[authEvent.Type()],
|
||||||
|
EventStateKeyNID: eventStateKeys[*authEvent.StateKey()],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nakedEvents := make([]*gomatrixserverlib.Event, 0, len(authEvents))
|
||||||
|
for _, authEvent := range authEvents {
|
||||||
|
nakedEvents = append(nakedEvents, authEvent.Event)
|
||||||
|
}
|
||||||
|
return nakedEvents, stateEntryMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.
|
// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.
|
||||||
|
|
Loading…
Reference in a new issue