diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 7b59e3704..9d723bed1 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -88,7 +88,16 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) } events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} - events = append(events, output.NewRoomEvent.AddStateEvents...) + if len(output.NewRoomEvent.AddsStateEventIDs) > 0 { + eventsReq := &api.QueryEventsByIDRequest{ + EventIDs: output.NewRoomEvent.AddsStateEventIDs, + } + eventsRes := &api.QueryEventsByIDResponse{} + if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { + return false + } + events = append(events, eventsRes.Events...) + } // Send event to any relevant application services if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 173dcff01..989f7cf49 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -146,7 +146,28 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error { - addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(ore.AddsState())) + eventsRes := &api.QueryEventsByIDResponse{} + if len(ore.AddsStateEventIDs) > 0 { + eventsReq := &api.QueryEventsByIDRequest{ + EventIDs: ore.AddsStateEventIDs, + } + if err := s.rsAPI.QueryEventsByID(s.ctx, eventsReq, eventsRes); err != nil { + return fmt.Errorf("s.rsAPI.QueryEventsByID: %w", err) + } + + found := false + for _, event := range eventsRes.Events { + if event.EventID() == ore.Event.EventID() { + found = true + break + } + } + if !found { + eventsRes.Events = append(eventsRes.Events, ore.Event) + } + } + + addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(eventsRes.Events)) if err != nil { return err } diff --git a/roomserver/api/output.go b/roomserver/api/output.go index d60d1cc86..f3ddb83b0 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -114,13 +114,6 @@ type OutputNewRoomEvent struct { // Together with RemovesStateEventIDs this allows the receiver to keep an up to date // view of the current state of the room. AddsStateEventIDs []string `json:"adds_state_event_ids"` - // All extra newly added state events. This is only set if there are *extra* events - // other than `Event`. This can happen when forks get merged because state resolution - // may decide a bunch of state events on one branch are now valid, so they will be - // present in this list. This is useful when trying to maintain the current state of a room - // as to do so you need to include both these events and `Event`. - AddStateEvents []*gomatrixserverlib.HeaderedEvent `json:"adds_state_events"` - // The state event IDs that were removed from the state of the room by this event. RemovesStateEventIDs []string `json:"removes_state_event_ids"` // The ID of the event that was output before this event. @@ -170,26 +163,6 @@ type OutputNewRoomEvent struct { TransactionID *TransactionID `json:"transaction_id"` } -// AddsState returns all added state events from this event. -// -// This function is needed because `AddStateEvents` will not include a copy of -// the original event to save space, so you cannot use that slice alone. -// Instead, use this function which will add the original event if it is present -// in `AddsStateEventIDs`. -func (ore *OutputNewRoomEvent) AddsState() []*gomatrixserverlib.HeaderedEvent { - includeOutputEvent := false - for _, id := range ore.AddsStateEventIDs { - if id == ore.Event.EventID() { - includeOutputEvent = true - break - } - } - if !includeOutputEvent { - return ore.AddStateEvents - } - return append(ore.AddStateEvents, ore.Event) -} - // An OutputOldRoomEvent is written when the roomserver receives an old event. // This will typically happen as a result of getting either missing events // or backfilling. Downstream components may wish to send these events to diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index f4a52031a..7e58ef9d0 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -365,6 +365,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) LastSentEventID: u.lastEventIDSent, LatestEventIDs: latestEventIDs, TransactionID: u.transactionID, + SendAsServer: u.sendAsServer, } eventIDMap, err := u.stateEventMap() @@ -384,51 +385,17 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID]) } - ore.SendAsServer = u.sendAsServer - - // include extra state events if they were added as nearly every downstream component will care about it - // and we'd rather not have them all hit QueryEventsByID at the same time! - if len(ore.AddsStateEventIDs) > 0 { - var err error - if ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs); err != nil { - return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) - } - } - return &api.OutputEvent{ Type: api.OutputTypeNewRoomEvent, NewRoomEvent: &ore, }, nil } -// extraEventsForIDs returns the full events for the event IDs given, but does not include the current event being -// updated. -func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - var extraEventIDs []string - for _, e := range eventIDs { - if e == u.event.EventID() { - continue - } - extraEventIDs = append(extraEventIDs, e) - } - if len(extraEventIDs) == 0 { - return nil, nil - } - extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs) - if err != nil { - return nil, err - } - var h []*gomatrixserverlib.HeaderedEvent - for _, e := range extraEvents { - h = append(h, e.Headered(roomVersion)) - } - return h, nil -} - // retrieve an event nid -> event ID map for all events that need updating func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) { - var stateEventNIDs []types.EventNID - var allStateEntries []types.StateEntry + cap := len(u.added) + len(u.removed) + len(u.stateBeforeEventRemoves) + len(u.stateBeforeEventAdds) + stateEventNIDs := make(types.EventNIDs, 0, cap) + allStateEntries := make([]types.StateEntry, 0, cap) allStateEntries = append(allStateEntries, u.added...) allStateEntries = append(allStateEntries, u.removed...) allStateEntries = append(allStateEntries, u.stateBeforeEventRemoves...) @@ -436,12 +403,6 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) for _, entry := range allStateEntries { stateEventNIDs = append(stateEventNIDs, entry.EventNID) } - stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] + stateEventNIDs = stateEventNIDs[:util.SortAndUnique(stateEventNIDs)] return u.updater.EventIDs(u.ctx, stateEventNIDs) } - -type eventNIDSorter []types.EventNID - -func (s eventNIDSorter) Len() int { return len(s) } -func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 159657f9f..640c505c2 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -154,7 +154,42 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( ctx context.Context, msg api.OutputNewRoomEvent, ) error { ev := msg.Event - addsStateEvents := msg.AddsState() + + addsStateEvents := []*gomatrixserverlib.HeaderedEvent{} + foundEventIDs := map[string]bool{} + if len(msg.AddsStateEventIDs) > 0 { + for _, eventID := range msg.AddsStateEventIDs { + foundEventIDs[eventID] = false + } + foundEvents, err := s.db.Events(ctx, msg.AddsStateEventIDs) + if err != nil { + return fmt.Errorf("s.db.Events: %w", err) + } + for _, event := range foundEvents { + foundEventIDs[event.EventID()] = true + } + eventsReq := &api.QueryEventsByIDRequest{} + eventsRes := &api.QueryEventsByIDResponse{} + for eventID, found := range foundEventIDs { + if !found { + eventsReq.EventIDs = append(eventsReq.EventIDs, eventID) + } + } + if err = s.rsAPI.QueryEventsByID(ctx, eventsReq, eventsRes); err != nil { + return fmt.Errorf("s.rsAPI.QueryEventsByID: %w", err) + } + for _, event := range eventsRes.Events { + eventID := event.EventID() + foundEvents = append(foundEvents, event) + foundEventIDs[eventID] = true + } + for eventID, found := range foundEventIDs { + if !found { + return fmt.Errorf("event %s is missing", eventID) + } + } + addsStateEvents = foundEvents + } ev, err := s.updateStateEvent(ev) if err != nil {