diff --git a/roomserver/query/query.go b/roomserver/query/query.go index a62a1f706..19207b09d 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -617,42 +617,42 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( func getAuthChain( ctx context.Context, dB RoomserverQueryAPIEventDB, authEventIDs []string, ) ([]gomatrixserverlib.Event, error) { - var authEvents []gomatrixserverlib.Event - - // List of event ids to fetch. These will be added to the result and - // their auth events will be fetched (if they haven't been previously) eventsToFetch := authEventIDs + authEventsMap := make(map[string]gomatrixserverlib.Event) - // Set of events we've already fetched. - fetchedEventMap := make(map[string]bool) - - // Check if there's anything left to do for len(eventsToFetch) > 0 { - // Convert eventIDs to events. First need to fetch NIDs + // Try to retrieve the events from the database events, err := dB.EventsFromIDs(ctx, eventsToFetch) if err != nil { return nil, err } - // Work out a) which events we should add to the returned list of - // events and b) which of the auth events we haven't seen yet and - // add them to the list of events to fetch. + // Clear out the events to fetch, since we have already requested them, and + // we may now get new ones eventsToFetch = eventsToFetch[:0] - for _, event := range events { - fetchedEventMap[event.EventID()] = true - authEvents = append(authEvents, event.Event) - // Now we need to fetch any auth events that we haven't - // previously seen. - for _, authEventID := range event.AuthEventIDs() { - if !fetchedEventMap[authEventID] { - fetchedEventMap[authEventID] = true - eventsToFetch = append(eventsToFetch, authEventID) + for _, event := range events { + // Store the event in the event map - this prevents us from requesting it + // again, and we'll return this later + authEventsMap[event.EventID()] = event.Event + + // Extract all of the auth events from the newly obtained event. If we + // don't already have a record of the event, record it in the list of + // events we want to request + for _, authEvent := range event.AuthEvents() { + if _, ok := authEventsMap[authEvent.EventID]; !ok { + eventsToFetch = append(eventsToFetch, authEvent.EventID) } } } } + // Flatten the event map down into an array + var authEvents []gomatrixserverlib.Event + for _, event := range authEventsMap { + authEvents = append(authEvents, event) + } + return authEvents, nil }