diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index ff2c8e5d4..c2c2d1874 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -146,28 +146,30 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error { - eventsRes := &api.QueryEventsByIDResponse{} - if len(ore.AddsStateEventIDs) > 0 { - eventsReq := &api.QueryEventsByIDRequest{ - EventIDs: ore.AddsStateEventIDs, - } - if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { - return fmt.Errorf("s.rsAPI.QueryEventsByID: %w", err) - } - - found := false - for _, event := range eventsRes.Events { - if event.EventID() == ore.Event.EventID() { - found = true - break - } - } - if !found { - eventsRes.Events = append(eventsRes.Events, ore.Event) + addsStateEvents := []*gomatrixserverlib.HeaderedEvent{} + missingEventIDs := make([]string, 0, len(ore.AddsStateEventIDs)) + for _, eventID := range ore.AddsStateEventIDs { + if eventID == ore.Event.EventID() { + addsStateEvents = append(addsStateEvents, ore.Event) + } else { + missingEventIDs = append(missingEventIDs, eventID) } } - addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(eventsRes.Events)) + // Ask the roomserver and add in the rest of the results into the set. + // Finally, work out if there are any more events missing. + if len(missingEventIDs) > 0 { + eventsReq := &api.QueryEventsByIDRequest{ + EventIDs: missingEventIDs, + } + eventsRes := &api.QueryEventsByIDResponse{} + if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { + return fmt.Errorf("s.rsAPI.QueryEventsByID: %w", err) + } + addsStateEvents = append(addsStateEvents, eventsRes.Events...) + } + + addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents)) if err != nil { return err }