diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 24e29a18d..fa2a7bbb6 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -508,13 +508,12 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. var states []*gomatrixserverlib.RespState - needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples() for _, prevEventID := range backwardsExtremity.PrevEventIDs() { // Look up what the state is after the backward extremity. This will either // come from the roomserver, if we know all the required events, or it will // come from a remote server via /state_ids if not. var prevState *gomatrixserverlib.RespState - prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) + prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) return err @@ -573,9 +572,9 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // added into the mix. -func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) { +func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, error) { // try doing all this locally before we resort to querying federation - respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID, needed) + respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) if respState != nil { return respState, nil } @@ -619,12 +618,11 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix return respState, nil } -func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState { +func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState { var res api.QueryStateAfterEventsResponse err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ RoomID: roomID, PrevEventIDs: []string{eventID}, - StateToFetch: needed, }, &res) if err != nil || !res.PrevEventsExist { util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID) diff --git a/roomserver/api/query.go b/roomserver/api/query.go index aff6ee07a..3afca7e81 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -63,7 +63,8 @@ type QueryStateAfterEventsRequest struct { RoomID string `json:"room_id"` // The list of previous events to return the events after. PrevEventIDs []string `json:"prev_event_ids"` - // The state key tuples to fetch from the state + // The state key tuples to fetch from the state. If none are specified then + // the entire resolved room state will be returned. StateToFetch []gomatrixserverlib.StateKeyTuple `json:"state_to_fetch"` } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 229665a0b..ca5d214d7 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -133,8 +133,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // If the event has already been written to the output log then we // don't need to do anything, as we've handled it already. - hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID) - if err != nil { + if hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID); err != nil { return fmt.Errorf("u.updater.HasEventBeenSent: %w", err) } else if hasBeenSent { return nil @@ -142,17 +141,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // Work out what the latest events are. This will include the new // event if it is not already referenced. - u.calculateLatest( + if err := u.calculateLatest( oldLatest, types.StateAtEventAndReference{ EventReference: u.event.EventReference(), StateAtEvent: u.stateAtEvent, }, - ) + ); err != nil { + return fmt.Errorf("u.calculateLatest: %w", err) + } // Now that we know what the latest events are, it's time to get the // latest state. - if err = u.latestState(); err != nil { + if err := u.latestState(); err != nil { return fmt.Errorf("u.latestState: %w", err) } @@ -261,7 +262,7 @@ func (u *latestEventsUpdater) latestState() error { func (u *latestEventsUpdater) calculateLatest( oldLatest []types.StateAtEventAndReference, newEvent types.StateAtEventAndReference, -) { +) error { var newLatest []types.StateAtEventAndReference // First of all, let's see if any of the existing forward extremities @@ -271,6 +272,7 @@ func (u *latestEventsUpdater) calculateLatest( referenced, err := u.updater.IsReferenced(l.EventReference) if err != nil { logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", l.EventID) + return fmt.Errorf("u.updater.IsReferenced (old): %w", err) } else if !referenced { newLatest = append(newLatest, l) } @@ -285,7 +287,7 @@ func (u *latestEventsUpdater) calculateLatest( // We've already referenced this new event so we can just return // the newly completed extremities at this point. u.latest = newLatest - return + return nil } } @@ -296,11 +298,13 @@ func (u *latestEventsUpdater) calculateLatest( referenced, err := u.updater.IsReferenced(newEvent.EventReference) if err != nil { logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", newEvent.EventReference.EventID) + return fmt.Errorf("u.updater.IsReferenced (new): %w", err) } else if !referenced || len(newLatest) == 0 { newLatest = append(newLatest, newEvent) } u.latest = newLatest + return nil } func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 810511505..ecfb580f2 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -49,6 +49,7 @@ func (r *Queryer) QueryLatestEventsAndState( } // QueryStateAfterEvents implements api.RoomserverInternalAPI +// nolint:gocyclo func (r *Queryer) QueryStateAfterEvents( ctx context.Context, request *api.QueryStateAfterEventsRequest, @@ -78,10 +79,18 @@ func (r *Queryer) QueryStateAfterEvents( } response.PrevEventsExist = true - // Look up the currrent state for the requested tuples. - stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( - ctx, prevStates, request.StateToFetch, - ) + var stateEntries []types.StateEntry + if len(request.StateToFetch) == 0 { + // Look up all of the current room state. + stateEntries, err = roomState.LoadCombinedStateAfterEvents( + ctx, prevStates, + ) + } else { + // Look up the current state for the requested tuples. + stateEntries, err = roomState.LoadStateAfterEventsForStringTuples( + ctx, prevStates, request.StateToFetch, + ) + } if err != nil { return err } @@ -91,6 +100,24 @@ func (r *Queryer) QueryStateAfterEvents( return err } + if len(request.PrevEventIDs) > 1 && len(request.StateToFetch) == 0 { + var authEventIDs []string + for _, e := range stateEvents { + authEventIDs = append(authEventIDs, e.AuthEventIDs()...) + } + authEventIDs = util.UniqueStrings(authEventIDs) + + authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + if err != nil { + return fmt.Errorf("getAuthChain: %w", err) + } + + stateEvents, err = state.ResolveConflictsAdhoc(info.RoomVersion, stateEvents, authEvents) + if err != nil { + return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) + } + } + for _, event := range stateEvents { response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) }