diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 2ee303d68..2e8a74dc5 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -48,6 +48,8 @@ func Send( eduProducer: eduProducer, keys: keys, federation: federation, + haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), + newEvents: make(map[string]bool), } var txnEvents struct { @@ -105,6 +107,11 @@ type txnReq struct { eduProducer *producers.EDUServerProducer keys gomatrixserverlib.JSONVerifier federation txnFederationClient + // local cache of events for auth checks, etc - this may include events + // which the roomserver is unaware of. + haveEvents map[string]*gomatrixserverlib.HeaderedEvent + // new events which the roomserver does not know about + newEvents map[string]bool } // A subset of FederationClient functionality that txn requires. Useful for testing. @@ -214,6 +221,17 @@ func (e missingPrevEventsError) Error() string { return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err) } +func (t *txnReq) haveEventIDs() map[string]bool { + result := make(map[string]bool, len(t.haveEvents)) + for eventID := range t.haveEvents { + if t.newEvents[eventID] { + continue + } + result[eventID] = true + } + return result +} + func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { for _, e := range edus { switch e.Type { @@ -239,7 +257,6 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) error { prevEventIDs := e.PrevEventIDs() - util.GetLogger(t.context).Infof("processEvent %s with prev_events %v", e.EventID(), prevEventIDs) // Fetch the state needed to authenticate the event. needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) @@ -252,7 +269,6 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro if err := t.rsAPI.QueryStateAfterEvents(t.context, &stateReq, &stateResp); err != nil { return err } - util.GetLogger(t.context).Infof("processEvent %s stateResp.PrevEventsExist: %v", e.EventID(), stateResp.PrevEventsExist) if !stateResp.RoomExists { // TODO: When synapse receives a message for a room it is not in it @@ -334,142 +350,148 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer return nil } - // fetch the state BEFORE the event then check that the event is allowed - respState, haveEventIDs, err := t.lookupStateAfterEvent(roomVersion, *backwardsExtremity) + // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. + // security: we have to do state resolution on the new backwards extremity (TODO: WHY) + // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query + // the state AFTER all the prev_events for this event, then mix in our current room state and apply state resolution + // to that to get the state before the event. + var states []*gomatrixserverlib.RespState + needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples() + for _, prevEventID := range backwardsExtremity.PrevEventIDs() { + prevState, err := t.lookupStateAfterEvent(roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) + if err != nil { + util.GetLogger(t.context).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) + return err + } + states = append(states, prevState) + } + // mix in the current room state + currState, err := t.lookupCurrentState(backwardsExtremity) if err != nil { + util.GetLogger(t.context).WithError(err).Errorf("Failed to lookup current room state") + return err + } + states = append(states, currState) + resolvedState, err := t.resolveStatesAndCheck(roomVersion, states, backwardsExtremity) + if err != nil { + util.GetLogger(t.context).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) return err } - fmt.Println("Calcuated lookupStateAfterEvent") // pass the event along with the state to the roomserver using a background context so we don't // needlessly expire - return t.producer.SendEventWithState(context.Background(), respState, e.Headered(roomVersion), haveEventIDs) + return t.producer.SendEventWithState(context.Background(), resolvedState, e.Headered(roomVersion), t.haveEventIDs()) } -// lookupStateAfterEvent returns the room state after the event e, which is all the states before e resolved via state resolution -// then having e applied to the resulting state. -func (t *txnReq) lookupStateAfterEvent(roomVersion gomatrixserverlib.RoomVersion, e gomatrixserverlib.Event) (*gomatrixserverlib.RespState, map[string]bool, error) { - // de-dupe all the events - authEvents := make(map[string]*gomatrixserverlib.Event) - stateEvents := make(map[string]*gomatrixserverlib.Event) - haveEventIDs := make(map[string]bool) - for _, prevEventID := range e.PrevEventIDs() { - // don't do auth checks on this RespState as we're just interested in grabbing state/auth events and putting it into the pot - respState, haveIDs, err := t.lookupStateBeforeEvent(roomVersion, false, e.RoomID(), prevEventID) - if err != nil { - return nil, nil, err +// lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) +// added into the mix. +func (t *txnReq) lookupStateAfterEvent(roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) { + var res api.QueryStateAfterEventsResponse + err := t.rsAPI.QueryStateAfterEvents(t.context, &api.QueryStateAfterEventsRequest{ + RoomID: roomID, + PrevEventIDs: []string{eventID}, + StateToFetch: needed, + }, &res) + if err != nil || !res.PrevEventsExist { + util.GetLogger(t.context).WithError(err).Warnf("failed to query state after %s locally; trying remotely", eventID) + } else { + for i, ev := range res.StateEvents { + t.haveEvents[ev.EventID()] = &res.StateEvents[i] } - for i := range respState.StateEvents { - stateEvents[respState.StateEvents[i].EventID()] = &respState.StateEvents[i] + var authEvents []gomatrixserverlib.Event + missingAuthEvents := make(map[string]bool) + for _, ev := range res.StateEvents { + for _, ae := range ev.AuthEventIDs() { + aev, ok := t.haveEvents[ae] + if ok { + authEvents = append(authEvents, aev.Unwrap()) + } else { + missingAuthEvents[ae] = true + } + } } - for i := range respState.AuthEvents { - authEvents[respState.AuthEvents[i].EventID()] = &respState.AuthEvents[i] + // QueryStateAfterEvents does not return the auth events, so fetch them now. We know the roomserver has them else it wouldn't + // have stored the event. + var missingEventList []string + for evID := range missingAuthEvents { + missingEventList = append(missingEventList, evID) } - for id := range haveIDs { - haveEventIDs[id] = true + queryReq := api.QueryEventsByIDRequest{ + EventIDs: missingEventList, } - // fetch the event we're missing and add it to the pile - h, err := t.lookupEvent(roomVersion, prevEventID) - if err != nil { - return nil, nil, err + util.GetLogger(t.context).Infof("Fetching missing auth events: %v", missingEventList) + var queryRes api.QueryEventsByIDResponse + if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + return nil, err } - if h.StateKey() != nil { - he := h.Unwrap() - stateEvents[h.EventID()] = &he + for i := range queryRes.Events { + evID := queryRes.Events[i].EventID() + t.haveEvents[evID] = &queryRes.Events[i] + authEvents = append(authEvents, queryRes.Events[i].Unwrap()) } + + evs := gomatrixserverlib.UnwrapEventHeaders(res.StateEvents) + return &gomatrixserverlib.RespState{ + StateEvents: evs, + AuthEvents: authEvents, + }, nil } - authEventList := make([]gomatrixserverlib.Event, len(authEvents)) - i := 0 - for _, ev := range authEvents { - authEventList[i] = *ev - i++ - } - stateEventList := make([]gomatrixserverlib.Event, len(stateEvents)) - i = 0 - for _, ev := range stateEvents { - stateEventList[i] = *ev - i++ - } - resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) + + // don't do auth checks on this RespState as we're just interested in grabbing state/auth events and putting it into the pot + respState, err := t.lookupStateBeforeEvent(roomVersion, false, roomID, eventID) if err != nil { - return nil, nil, err + return nil, err } - // apply the current event - if err = checkAllowedByState(e, resolvedStateEvents); err != nil { - return nil, nil, err + + // fetch the event we're missing and add it to the pile + h, err := t.lookupEvent(roomVersion, eventID) + if err != nil { + return nil, err } - // roll forward state if this event is a state event - if e.StateKey() != nil { - for i := range resolvedStateEvents { - if resolvedStateEvents[i].Type() == e.Type() && resolvedStateEvents[i].StateKeyEquals(*e.StateKey()) { - resolvedStateEvents[i] = e + t.haveEvents[h.EventID()] = h + if h.StateKey() != nil { + addedToState := false + for i := range respState.StateEvents { + se := respState.StateEvents[i] + if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) { + respState.StateEvents[i] = h.Unwrap() + addedToState = true break } } - } - for _, s := range resolvedStateEvents { - util.GetLogger(t.context).Infof("resolved: %s -> %s", s.Type(), string(s.Content())) - } - for _, s := range authEventList { - util.GetLogger(t.context).Infof("authEventList: %s -> %s", s.Type(), string(s.Content())) + if !addedToState { + respState.StateEvents = append(respState.StateEvents, h.Unwrap()) + } } - resp := &gomatrixserverlib.RespState{ - AuthEvents: authEventList, - StateEvents: resolvedStateEvents, - } - if err = resp.Check(t.context, t.keys); err != nil { - return nil, nil, fmt.Errorf("lookupStateAfterEvent: resolved state is not valid: %w", err) - } + return respState, nil +} - return resp, haveEventIDs, nil +func (t *txnReq) lookupCurrentState(newEvent *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { + // Ask the roomserver for information about this room + queryReq := api.QueryLatestEventsAndStateRequest{ + RoomID: newEvent.RoomID(), + StateToFetch: gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*newEvent}).Tuples(), + } + var queryRes api.QueryLatestEventsAndStateResponse + if err := t.rsAPI.QueryLatestEventsAndState(t.context, &queryReq, &queryRes); err != nil { + return nil, fmt.Errorf("lookupCurrentState rsAPI.QueryLatestEventsAndState: %w", err) + } + return &gomatrixserverlib.RespState{ + StateEvents: gomatrixserverlib.UnwrapEventHeaders(queryRes.StateEvents), + // TODO: Auth events? + }, nil } // lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what // the server supports. func (t *txnReq) lookupStateBeforeEvent(roomVersion gomatrixserverlib.RoomVersion, doAuthCheck bool, roomID, eventID string) ( - respState *gomatrixserverlib.RespState, haveEventIDs map[string]bool, err error) { + respState *gomatrixserverlib.RespState, err error) { util.GetLogger(t.context).Infof("lookupStateBeforeEvent %s", eventID) - // It's entirely possible that we know this state, as QueryStateAfterEventsRequest only returns success if ALL prev_events - // exist, so query the roomserver for the state with just this prev event - stateReq := api.QueryStateAfterEventsRequest{ - RoomID: roomID, - StateToFetch: nil, // TODO: do we need everything? - PrevEventIDs: []string{eventID}, - } - var stateResp api.QueryStateAfterEventsResponse - if err = t.rsAPI.QueryStateAfterEvents(t.context, &stateReq, &stateResp); err != nil || stateResp.StateEvents == nil { - util.GetLogger(t.context).WithError(err).Warnf("Failed to lookup state before event %s via roomserver - asking remote", eventID) - // fallthrough to remote lookup - } else { - util.GetLogger(t.context).Infof("lookupStateBeforeEvent %s returned locally", eventID) - // we have all the events - haveEvents := make(map[string]*gomatrixserverlib.HeaderedEvent) - haveEventIDs = make(map[string]bool) - for i, ev := range stateResp.StateEvents { - haveEventIDs[ev.EventID()] = true - haveEvents[ev.EventID()] = &stateResp.StateEvents[i] - } - var authEvents []gomatrixserverlib.Event - for _, ev := range stateResp.StateEvents { - for _, ae := range ev.AuthEventIDs() { - aev, ok := haveEvents[ae] - if ok { - authEvents = append(authEvents, aev.Unwrap()) - } - } - } - - respState = &gomatrixserverlib.RespState{ - AuthEvents: authEvents, - StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateResp.StateEvents), - } - return - } // Attempt to fetch the missing state using /state_ids and /events - respState, haveEventIDs, err = t.lookupMissingStateViaStateIDs(roomID, eventID, doAuthCheck, roomVersion) + respState, err = t.lookupMissingStateViaStateIDs(roomID, eventID, doAuthCheck, roomVersion) if err != nil { // Fallback to /state util.GetLogger(t.context).WithError(err).Warn("lookupStateBeforeEvent failed to /state_ids, falling back to /state") @@ -478,6 +500,31 @@ func (t *txnReq) lookupStateBeforeEvent(roomVersion gomatrixserverlib.RoomVersio return } +func (t *txnReq) resolveStatesAndCheck(roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { + var authEventList []gomatrixserverlib.Event + var stateEventList []gomatrixserverlib.Event + for _, state := range states { + for _, ae := range state.AuthEvents { + authEventList = append(authEventList, ae) + } + for _, se := range state.StateEvents { + stateEventList = append(stateEventList, se) + } + } + resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(roomVersion, stateEventList, authEventList) + if err != nil { + return nil, err + } + // apply the current event + if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil { + return nil, err + } + return &gomatrixserverlib.RespState{ + AuthEvents: authEventList, + StateEvents: resolvedStateEvents, + }, nil +} + // getMissingEvents returns a nil backwardsExtremity if missing events were fetched and handled, else returns the new backwards extremity which we should // begin from. Returns an error only if we should terminate the transaction which initiated /get_missing_events // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. @@ -485,7 +532,6 @@ func (t *txnReq) lookupStateBeforeEvent(roomVersion gomatrixserverlib.RoomVersio func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) { if !isInboundTxn { // we've recursed here, so just take a state snapshot please! - fmt.Println("backwards extremity is now ", e.EventID()) return &e, nil } logger := util.GetLogger(t.context).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) @@ -588,46 +634,46 @@ func (t *txnReq) lookupMissingStateViaState(roomID, eventID string, roomVersion } func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, doAuthCheck bool, roomVersion gomatrixserverlib.RoomVersion) ( - *gomatrixserverlib.RespState, map[string]bool, error) { + *gomatrixserverlib.RespState, error) { util.GetLogger(t.context).Infof("lookupMissingStateViaStateIDs %s", eventID) // fetch the state event IDs at the time of the event stateIDs, err := t.federation.LookupStateIDs(t.context, t.Origin, roomID, eventID) if err != nil { - return nil, nil, err + return nil, err } - - // fetch as many as we can from the roomserver, do them as 2 calls rather than - // 1 to try to reduce the number of parameters in the bulk query this will use - haveEventMap := make(map[string]*gomatrixserverlib.HeaderedEvent, len(stateIDs.StateEventIDs)) - haveEventIDs := make(map[string]bool) - for _, eventList := range [][]string{stateIDs.StateEventIDs, stateIDs.AuthEventIDs} { - queryReq := api.QueryEventsByIDRequest{ - EventIDs: eventList, - } - var queryRes api.QueryEventsByIDResponse - if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { - return nil, nil, err - } - // allow indexing of current state by event ID - for i := range queryRes.Events { - haveEventMap[queryRes.Events[i].EventID()] = &queryRes.Events[i] - haveEventIDs[queryRes.Events[i].EventID()] = true - } - } - // work out which auth/state IDs are missing wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...) missing := make(map[string]bool) + var missingEventList []string for _, sid := range wantIDs { - if _, ok := haveEventMap[sid]; !ok { - missing[sid] = true + if _, ok := t.haveEvents[sid]; !ok { + if !missing[sid] { + missing[sid] = true + missingEventList = append(missingEventList, sid) + } } } + + // fetch as many as we can from the roomserver + queryReq := api.QueryEventsByIDRequest{ + EventIDs: missingEventList, + } + var queryRes api.QueryEventsByIDResponse + if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + return nil, err + } + for i := range queryRes.Events { + evID := queryRes.Events[i].EventID() + t.haveEvents[evID] = &queryRes.Events[i] + if missing[evID] { + delete(missing, evID) + } + } + util.GetLogger(t.context).WithFields(logrus.Fields{ "missing": len(missing), "event_id": eventID, "room_id": roomID, - "already_have": len(haveEventMap), "total_state": len(stateIDs.StateEventIDs), "total_auth_events": len(stateIDs.AuthEventIDs), }).Info("Fetching missing state at event") @@ -636,15 +682,15 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, doAuthChe var h *gomatrixserverlib.HeaderedEvent h, err = t.lookupEvent(roomVersion, missingEventID) if err != nil { - return nil, nil, err + return nil, err } - haveEventMap[h.EventID()] = h + t.haveEvents[h.EventID()] = h } - resp, err := t.createRespStateFromStateIDs(stateIDs, doAuthCheck, haveEventMap) - return resp, haveEventIDs, err + resp, err := t.createRespStateFromStateIDs(stateIDs, doAuthCheck) + return resp, err } -func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs, doAuthCheck bool, haveEventMap map[string]*gomatrixserverlib.HeaderedEvent) ( +func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs, doAuthCheck bool) ( *gomatrixserverlib.RespState, error) { // create a RespState response using the response to /state_ids as a guide respState := gomatrixserverlib.RespState{ @@ -654,7 +700,7 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat var roomVer gomatrixserverlib.RoomVersion for i := range stateIDs.StateEventIDs { - ev, ok := haveEventMap[stateIDs.StateEventIDs[i]] + ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] if !ok { return nil, fmt.Errorf("missing state event %s", stateIDs.StateEventIDs[i]) } @@ -662,12 +708,16 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat roomVer = ev.RoomVersion } for i := range stateIDs.AuthEventIDs { - ev, ok := haveEventMap[stateIDs.AuthEventIDs[i]] + ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] if !ok { return nil, fmt.Errorf("missing auth event %s", stateIDs.AuthEventIDs[i]) } respState.AuthEvents[i] = ev.Unwrap() } + + if !doAuthCheck { + return &respState, nil + } // Check that the returned state is valid. retryCheck: if err := respState.Check(t.context, t.keys); err != nil { @@ -683,11 +733,7 @@ retryCheck: respState.AuthEvents = append(respState.AuthEvents, newEv.Unwrap()) goto retryCheck } - if doAuthCheck { - return nil, err - } else { - return &respState, nil - } + return nil, err } return &respState, nil } @@ -710,5 +756,6 @@ func (t *txnReq) lookupEvent(roomVersion gomatrixserverlib.RoomVersion, missingE return nil, verifySigError{event.EventID(), err} } h := event.Headered(roomVersion) + t.newEvents[h.EventID()] = true return &h, nil } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 14fe0a820..cb8aec6f5 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -92,6 +92,9 @@ func (t *testRoomserverAPI) InputRoomEvents( response *api.InputRoomEventsResponse, ) error { t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) + for _, ire := range request.InputRoomEvents { + fmt.Println("InputRoomEvents: ", ire.Event.EventID()) + } return nil } @@ -292,6 +295,7 @@ type txnFedClient struct { func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) { + fmt.Println("testFederationClient.LookupState", eventID) r, ok := c.state[eventID] if !ok { err = fmt.Errorf("txnFedClient: no /state for event %s", eventID) @@ -301,6 +305,7 @@ func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.Serv return } func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { + fmt.Println("testFederationClient.LookupStateIDs", eventID) r, ok := c.stateIDs[eventID] if !ok { err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID) @@ -310,6 +315,7 @@ func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.S return } func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { + fmt.Println("testFederationClient.GetEvent", eventID) r, ok := c.getEvent[eventID] if !ok { err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID) @@ -331,6 +337,8 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat eduProducer: producers.NewEDUServerProducer(&testEDUProducer{}), keys: &testNopJSONVerifier{}, federation: fedClient, + haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), + newEvents: make(map[string]bool), } t.PDUs = pdus t.Origin = testOrigin @@ -538,30 +546,35 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { var rsAPI *testRoomserverAPI rsAPI = &testRoomserverAPI{ queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - // if we have event C from GME, then PrevEventsExist: True, else it is false - prevEventExists := false omitTuples := []gomatrixserverlib.StateKeyTuple{ - gomatrixserverlib.StateKeyTuple{ + { EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: "", }, } + askingForEvent := req.PrevEventIDs[0] + haveEventB := false + haveEventC := false for _, ev := range rsAPI.inputRoomEvents { - if ev.Event.EventID() == eventC.EventID() && len(req.PrevEventIDs) == 1 && req.PrevEventIDs[0] == eventC.EventID() { - prevEventExists = true - } - if ev.Event.EventID() == eventB.EventID() { - omitTuples = nil + switch ev.Event.EventID() { + case eventB.EventID(): + haveEventB = true + omitTuples = nil // include event B now + case eventC.EventID(): + haveEventC = true } } + prevEventExists := false + if askingForEvent == eventC.EventID() { + prevEventExists = haveEventC + } else if askingForEvent == eventB.EventID() { + prevEventExists = haveEventB + } var stateEvents []gomatrixserverlib.HeaderedEvent if prevEventExists { stateEvents = fromStateTuples(req.StateToFetch, omitTuples) } return api.QueryStateAfterEventsResponse{ - // setting this to false should trigger a call to /get_missing_events or /state_ids depending - // on far back we've gone. The first time should trigger /get_missing_events but we should - // give up on subsequent calls and just use the /state_ids PrevEventsExist: prevEventExists, RoomExists: true, StateEvents: stateEvents, @@ -621,14 +634,14 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { } cli := &txnFedClient{ stateIDs: map[string]gomatrixserverlib.RespStateIDs{ - eventB.EventID(): gomatrixserverlib.RespStateIDs{ + eventB.EventID(): { StateEventIDs: stateEventIDs, AuthEventIDs: authEventIDs, }, }, // /event for event B returns it getEvent: map[string]gomatrixserverlib.Transaction{ - eventB.EventID(): gomatrixserverlib.Transaction{ + eventB.EventID(): { PDUs: []json.RawMessage{ eventB.JSON(), },