diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index e1211ffe9..ba653c1e8 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -77,10 +77,11 @@ func (p *testEDUProducer) InputSendToDeviceEvent( } type testRoomserverAPI struct { - inputRoomEvents []api.InputRoomEvent - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse + inputRoomEvents []api.InputRoomEvent + queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse + queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse + queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse + queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse } func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} @@ -162,6 +163,20 @@ func (t *testRoomserverAPI) QueryStateAfterEvents( return nil } +// Query the state after a list of events in a room from the room server. +func (t *testRoomserverAPI) QueryMissingAuthPrevEvents( + ctx context.Context, + request *api.QueryMissingAuthPrevEventsRequest, + response *api.QueryMissingAuthPrevEventsResponse, +) error { + response.RoomVersion = testRoomVersion + res := t.queryMissingAuthPrevEvents(request) + response.RoomExists = res.RoomExists + response.MissingAuthEventIDs = res.MissingAuthEventIDs + response.MissingPrevEventIDs = res.MissingPrevEventIDs + return nil +} + // Query a list of events by event ID. func (t *testRoomserverAPI) QueryEventsByID( ctx context.Context, @@ -453,11 +468,11 @@ func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []gomatr // to the roomserver. It's the most basic test possible. func TestBasicTransaction(t *testing.T) { rsAPI := &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: true, - RoomExists: true, - StateEvents: fromStateTuples(req.StateToFetch, nil), + queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { + return api.QueryMissingAuthPrevEventsResponse{ + RoomExists: true, + MissingAuthEventIDs: []string{}, + MissingPrevEventIDs: []string{}, } }, } @@ -473,14 +488,11 @@ func TestBasicTransaction(t *testing.T) { // as it does the auth check. func TestTransactionFailAuthChecks(t *testing.T) { rsAPI := &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: true, - RoomExists: true, - // omit the create event so auth checks fail - StateEvents: fromStateTuples(req.StateToFetch, []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomCreate, StateKey: ""}, - }), + queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { + return api.QueryMissingAuthPrevEventsResponse{ + RoomExists: true, + MissingAuthEventIDs: []string{"create_event"}, + MissingPrevEventIDs: []string{}, } }, } @@ -504,28 +516,24 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions rsAPI = &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - // we expect this to be called three times: - // - first with input event to realise there's a gap - // - second with the prevEvent to realise there is no gap - // - third with the input event to realise there is no longer a gap - prevEventsExist := false + queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { + missingPrevEvent := []string{"missing_prev_event"} if len(req.PrevEventIDs) == 1 { switch req.PrevEventIDs[0] { case haveEvent.EventID(): - prevEventsExist = true + missingPrevEvent = []string{} case prevEvent.EventID(): // we only have this event if we've been send prevEvent if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { - prevEventsExist = true + missingPrevEvent = []string{} } } } - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: prevEventsExist, - RoomExists: true, - StateEvents: fromStateTuples(req.StateToFetch, nil), + return api.QueryMissingAuthPrevEventsResponse{ + RoomExists: true, + MissingAuthEventIDs: []string{}, + MissingPrevEventIDs: missingPrevEvent, } }, queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { @@ -626,6 +634,38 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { StateEvents: stateEvents, } }, + + queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { + askingForEvent := req.PrevEventIDs[0] + haveEventB := false + haveEventC := false + for _, ev := range rsAPI.inputRoomEvents { + switch ev.Event.EventID() { + case eventB.EventID(): + haveEventB = true + case eventC.EventID(): + haveEventC = true + } + } + prevEventExists := false + if askingForEvent == eventC.EventID() { + prevEventExists = haveEventC + } else if askingForEvent == eventB.EventID() { + prevEventExists = haveEventB + } + + var missingPrevEvent []string + if !prevEventExists { + missingPrevEvent = []string{"test"} + } + + return api.QueryMissingAuthPrevEventsResponse{ + RoomExists: true, + MissingAuthEventIDs: []string{}, + MissingPrevEventIDs: missingPrevEvent, + } + }, + queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { omitTuples := []gomatrixserverlib.StateKeyTuple{ {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""},