diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index f3b192c1c..b73c95d2f 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -368,19 +368,35 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is } // TODO: Make this less bad - for _, missingAuthEventID := range stateResp.MissingAuthEventIDs { - logrus.WithContext(ctx).Infof("Retrieving missing auth event %q", missingAuthEventID) - if tx, err := t.federation.GetEvent(ctx, e.Origin(), missingAuthEventID); err == nil { - ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion) - if err != nil { - logrus.WithContext(ctx).WithError(err).Warnf("Failed to unmarshal auth event %d", missingAuthEventID) - continue + if len(stateResp.MissingAuthEventIDs) > 0 { + servers := []gomatrixserverlib.ServerName{t.Origin} + serverReq := &api.QueryServerJoinedToRoomRequest{ + RoomID: e.RoomID(), + } + serverRes := &api.QueryServerJoinedToRoomResponse{} + if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { + servers = append(servers, serverRes.ServerNames...) + logrus.WithContext(ctx).Infof("Found %d server(s) to query for missing events", len(servers)) + } + + getAuthEvent: + for _, missingAuthEventID := range stateResp.MissingAuthEventIDs { + for _, server := range servers { + logrus.WithContext(ctx).Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) + tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) + if err != nil { + continue + } + ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion) + if err != nil { + logrus.WithContext(ctx).WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID) + continue getAuthEvent + } + if err = t.processEvent(ctx, ev, false); err != nil { + logrus.WithContext(ctx).WithError(err).Warnf("Failed to process auth event %q", missingAuthEventID) + } + continue getAuthEvent } - if err = t.processEvent(ctx, ev, false); err != nil { - logrus.WithContext(ctx).WithError(err).Warnf("Failed to process auth event %d", missingAuthEventID) - } - } else { - logrus.WithContext(ctx).WithError(err).Warnf("Failed to retrieve unknown auth event %q", missingAuthEventID) } } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 17723fc5b..e1211ffe9 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -77,11 +77,10 @@ func (p *testEDUProducer) InputSendToDeviceEvent( } type testRoomserverAPI struct { - inputRoomEvents []api.InputRoomEvent - queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse + inputRoomEvents []api.InputRoomEvent + queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse + queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse + queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse } func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} @@ -163,20 +162,6 @@ func (t *testRoomserverAPI) QueryStateAfterEvents( return nil } -// Query the state after a list of events in a room from the room server. -func (t *testRoomserverAPI) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - response.RoomVersion = testRoomVersion - res := t.queryMissingAuthPrevEvents(request) - response.RoomExists = res.RoomExists - response.MissingAuthEventIDs = res.MissingAuthEventIDs - response.MissingPrevEventIDs = res.MissingPrevEventIDs - return nil -} - // Query a list of events by event ID. func (t *testRoomserverAPI) QueryEventsByID( ctx context.Context, @@ -468,11 +453,11 @@ func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []gomatr // to the roomserver. It's the most basic test possible. func TestBasicTransaction(t *testing.T) { rsAPI := &testRoomserverAPI{ - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: []string{}, + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + PrevEventsExist: true, + RoomExists: true, + StateEvents: fromStateTuples(req.StateToFetch, nil), } }, } @@ -488,11 +473,14 @@ func TestBasicTransaction(t *testing.T) { // as it does the auth check. func TestTransactionFailAuthChecks(t *testing.T) { rsAPI := &testRoomserverAPI{ - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{"create_event"}, - MissingPrevEventIDs: []string{}, + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + PrevEventsExist: true, + RoomExists: true, + // omit the create event so auth checks fail + StateEvents: fromStateTuples(req.StateToFetch, []gomatrixserverlib.StateKeyTuple{ + {EventType: gomatrixserverlib.MRoomCreate, StateKey: ""}, + }), } }, } @@ -516,6 +504,30 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions rsAPI = &testRoomserverAPI{ + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + // we expect this to be called three times: + // - first with input event to realise there's a gap + // - second with the prevEvent to realise there is no gap + // - third with the input event to realise there is no longer a gap + prevEventsExist := false + if len(req.PrevEventIDs) == 1 { + switch req.PrevEventIDs[0] { + case haveEvent.EventID(): + prevEventsExist = true + case prevEvent.EventID(): + // we only have this event if we've been send prevEvent + if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { + prevEventsExist = true + } + } + } + + return api.QueryStateAfterEventsResponse{ + PrevEventsExist: prevEventsExist, + RoomExists: true, + StateEvents: fromStateTuples(req.StateToFetch, nil), + } + }, queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { return api.QueryLatestEventsAndStateResponse{ RoomExists: true, @@ -526,30 +538,6 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { StateEvents: fromStateTuples(req.StateToFetch, nil), } }, - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - // we expect this to be called three times: - // - first with input event to realise there's a gap - // - second with the prevEvent to realise there is no gap - // - third with the input event to realise there is no longer a gap - missingPrevEvent := []string{"missing_prev_event"} - if len(req.PrevEventIDs) == 1 { - switch req.PrevEventIDs[0] { - case haveEvent.EventID(): - missingPrevEvent = []string{} - case prevEvent.EventID(): - // we only have this event if we've been send prevEvent - if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { - missingPrevEvent = []string{} - } - } - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, } cli := &txnFedClient{ @@ -588,9 +576,6 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { // - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event. // - state resolution is done to check C is allowed. // This results in B being sent as an outlier FIRST, then C,D. -/* -TODO: Fix this test! - func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { eventA := testEvents[len(testEvents)-5] // this is also len(testEvents)-4 @@ -654,21 +639,6 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { StateEvents: fromStateTuples(req.StateToFetch, omitTuples), } }, - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - askingForEvent := req.PrevEventIDs[0] - missingPrevEvents := []string{"missing_prev_event"} - if askingForEvent == eventC.EventID() { - missingPrevEvents = []string{} - } else if askingForEvent == eventB.EventID() { - missingPrevEvents = []string{} - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvents, - } - }, queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { var res api.QueryEventsByIDResponse fmt.Println("queryEventsByID ", req.EventIDs) @@ -746,4 +716,3 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { mustProcessTransaction(t, txn, nil) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) } -*/