diff --git a/clientapi/producers/roomserver.go b/clientapi/producers/roomserver.go index 7eee83f5e..d5add903a 100644 --- a/clientapi/producers/roomserver.go +++ b/clientapi/producers/roomserver.go @@ -51,6 +51,31 @@ func (c *RoomserverProducer) SendEvents( return c.SendInputRoomEvents(ctx, ires) } +// SendEventWithKnownMissingState sends the missing state events followed by the new event to the roomserver with the given stateEventIDs. +func (c *RoomserverProducer) SendEventWithKnownMissingState( + ctx context.Context, stateEventIDs []string, missingStateEvents []gomatrixserverlib.HeaderedEvent, event gomatrixserverlib.HeaderedEvent, +) error { + var ires []api.InputRoomEvent + for _, outlier := range missingStateEvents { + ires = append(ires, api.InputRoomEvent{ + Kind: api.KindOutlier, + Event: outlier.Headered(event.RoomVersion), + AuthEventIDs: outlier.AuthEventIDs(), + }) + } + + ires = append(ires, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + AuthEventIDs: event.AuthEventIDs(), + HasState: true, + StateEventIDs: stateEventIDs, + }) + + _, err := c.SendInputRoomEvents(ctx, ires) + return err +} + // SendEventWithState writes an event with KindNew to the roomserver input log // with the state at the event as KindOutlier before it. func (c *RoomserverProducer) SendEventWithState( diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 10210db64..8251e9ff4 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -309,9 +309,7 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer // TODO: Attempt to fill in the gap using /get_missing_events // Attempt to fetch the missing state using /state_ids and /events - var respState *gomatrixserverlib.RespState - var err error - respState, err = t.lookupMissingStateViaStateIDs(e, roomVersion) + respState, newEvents, err := t.lookupMissingStateViaStateIDs(e, roomVersion) if err != nil { // Fallback to /state util.GetLogger(t.context).WithError(err).Warn("processEventWithMissingState failed to /state_ids, falling back to /state") @@ -343,8 +341,16 @@ retryAllowedState: return err } - // pass the event along with the state to the roomserver - return t.producer.SendEventWithState(t.context, respState, e.Headered(roomVersion)) + if len(newEvents) > 0 { + stateEventIDs := make([]string, len(respState.StateEvents)) + for i := range respState.StateEvents { + stateEventIDs[i] = respState.StateEvents[i].EventID() + } + return t.producer.SendEventWithKnownMissingState(context.Background(), stateEventIDs, newEvents, e.Headered(roomVersion)) + } + // pass the event along with the state to the roomserver using a background context so we don't + // needlessly expire + return t.producer.SendEventWithState(context.Background(), respState, e.Headered(roomVersion)) } func (t *txnReq) lookupMissingStateViaState(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) ( @@ -361,12 +367,12 @@ func (t *txnReq) lookupMissingStateViaState(e gomatrixserverlib.Event, roomVersi } func (t *txnReq) lookupMissingStateViaStateIDs(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) ( - *gomatrixserverlib.RespState, error) { + *gomatrixserverlib.RespState, []gomatrixserverlib.HeaderedEvent, error) { // fetch the state event IDs at the time of the event stateIDs, err := t.federation.LookupStateIDs(t.context, t.Origin, e.RoomID(), e.EventID()) if err != nil { - return nil, err + return nil, nil, err } // fetch as many as we can from the roomserver, do them as 2 calls rather than @@ -377,8 +383,8 @@ func (t *txnReq) lookupMissingStateViaStateIDs(e gomatrixserverlib.Event, roomVe EventIDs: eventList, } var queryRes api.QueryEventsByIDResponse - if err := t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { - return nil, err + if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + return nil, nil, err } // allow indexing of current state by event ID for i := range queryRes.Events { @@ -402,28 +408,33 @@ func (t *txnReq) lookupMissingStateViaStateIDs(e gomatrixserverlib.Event, roomVe "total_state": len(stateIDs.StateEventIDs), "total_auth_events": len(stateIDs.AuthEventIDs), }).Info("Fetching missing state at event") + var newEvents []gomatrixserverlib.HeaderedEvent for missingEventID := range missing { - txn, err := t.federation.GetEvent(t.context, t.Origin, missingEventID) + var txn gomatrixserverlib.Transaction + txn, err = t.federation.GetEvent(t.context, t.Origin, missingEventID) if err != nil { util.GetLogger(t.context).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID") - return nil, err + return nil, nil, err } for _, pdu := range txn.PDUs { - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + var event gomatrixserverlib.Event + event, err = gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) if err != nil { util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID()) - return nil, unmarshalError{err} + return nil, nil, unmarshalError{err} } - if err := gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { + if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) - return nil, verifySigError{event.EventID(), err} + return nil, nil, verifySigError{event.EventID(), err} } h := event.Headered(roomVersion) haveEventMap[event.EventID()] = &h + newEvents = append(newEvents, h) } } - return t.createRespStateFromStateIDs(stateIDs, haveEventMap) + resp, err := t.createRespStateFromStateIDs(stateIDs, haveEventMap) + return resp, newEvents, err } func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs, haveEventMap map[string]*gomatrixserverlib.HeaderedEvent) ( diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 5e8e503a8..3649aee6d 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -79,6 +79,7 @@ func (p *testEDUProducer) InputTypingEvent( type testRoomserverAPI struct { inputRoomEvents []api.InputRoomEvent queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse + queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse } func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} @@ -138,6 +139,8 @@ func (t *testRoomserverAPI) QueryEventsByID( request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { + res := t.queryEventsByID(request) + response.Events = res.Events return nil } @@ -270,21 +273,43 @@ func (t *testRoomserverAPI) RemoveRoomAlias( return nil } -type txnFedClient struct{} +type txnFedClient struct { + state map[string]gomatrixserverlib.RespState // event_id to response + stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response + getEvent map[string]gomatrixserverlib.Transaction // event_id to response +} func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) { + r, ok := c.state[eventID] + if !ok { + err = fmt.Errorf("txnFedClient: no /state for event %s", eventID) + return + } + res = r return } func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { + r, ok := c.stateIDs[eventID] + if !ok { + err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID) + return + } + res = r return } func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { + r, ok := c.getEvent[eventID] + if !ok { + err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID) + return + } + res = r return } -func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage, edus []gomatrixserverlib.EDU) *txnReq { +func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { t := &txnReq{ context: context.Background(), rsAPI: rsAPI, @@ -294,7 +319,6 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat federation: fedClient, } t.PDUs = pdus - t.EDUs = edus t.Origin = testOrigin t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) t.Destination = testDestination @@ -368,7 +392,7 @@ func TestBasicTransaction(t *testing.T) { pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus, nil) + txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) mustProcessTransaction(t, txn, nil) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) } @@ -390,10 +414,119 @@ func TestTransactionFailAuthChecks(t *testing.T) { pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus, nil) + txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) mustProcessTransaction(t, txn, []string{ // expect the event to have an error testEvents[len(testEvents)-1].EventID(), }) assertInputRoomEvents(t, rsAPI.inputRoomEvents, nil) // expect no messages to be sent to the roomserver } + +// The purpose of this test is to check that when there are missing prev_events that state is fetched via /state_ids +// and /event and not /state. It works by setting PrevEventsExist=false in the roomserver query response, resulting in +// a call to /state_ids which returns the whole room state. It should attempt to fetch as many of these events from the +// roomserver FIRST, resulting in a call to QueryEventsByID. However, this will be missing the m.room.power_levels event which +// should then be requested via /event. The net result is that the transaction should succeed and there should be 2 +// new events, first the m.room.power_levels event we were missing, then the transaction PDU. +func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { + missingStateEvent := testStateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomPowerLevels, + StateKey: "", + }] + rsAPI := &testRoomserverAPI{ + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + // setting this to false should trigger a call to /state_ids + PrevEventsExist: false, + RoomExists: true, + StateEvents: nil, + } + }, + queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { + var res api.QueryEventsByIDResponse + for _, wantEventID := range req.EventIDs { + for _, ev := range testStateEvents { + // roomserver is missing the power levels event + if wantEventID == missingStateEvent.EventID() { + continue + } + if ev.EventID() == wantEventID { + res.Events = append(res.Events, ev) + } + } + } + res.QueryEventsByIDRequest = *req + return res + }, + } + inputEvent := testEvents[len(testEvents)-1] + var stateEventIDs []string + for _, ev := range testStateEvents { + stateEventIDs = append(stateEventIDs, ev.EventID()) + } + cli := &txnFedClient{ + // /state_ids returns all the state events + stateIDs: map[string]gomatrixserverlib.RespStateIDs{ + inputEvent.EventID(): gomatrixserverlib.RespStateIDs{ + StateEventIDs: stateEventIDs, + AuthEventIDs: stateEventIDs, + }, + }, + // /event for the missing state event returns it + getEvent: map[string]gomatrixserverlib.Transaction{ + missingStateEvent.EventID(): gomatrixserverlib.Transaction{ + PDUs: []json.RawMessage{ + missingStateEvent.JSON(), + }, + }, + }, + } + + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, cli, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{missingStateEvent, inputEvent}) +} + +// The purpose of this test is to check that when there are missing prev_events and /state_ids fails, that we fallback to +// calling /state which returns the entire room state at that event. It works by setting PrevEventsExist=false in the +// roomserver query response, resulting in a call to /state_ids which fails (unset). It should then fetch via /state. +func TestTransactionFetchMissingStateByFallbackState(t *testing.T) { + rsAPI := &testRoomserverAPI{ + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + // setting this to false should trigger a call to /state_ids + PrevEventsExist: false, + RoomExists: true, + StateEvents: nil, + } + }, + } + inputEvent := testEvents[len(testEvents)-1] + var stateEvents []gomatrixserverlib.HeaderedEvent + for _, ev := range testStateEvents { + stateEvents = append(stateEvents, ev) + } + cli := &txnFedClient{ + // /state_ids purposefully unset + stateIDs: nil, + // /state returns the state at that event (which is the current state) + state: map[string]gomatrixserverlib.RespState{ + inputEvent.EventID(): gomatrixserverlib.RespState{ + AuthEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), + StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), + }, + }, + } + + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, cli, pdus) + mustProcessTransaction(t, txn, nil) + // the roomserver should get all state events and the new input event + // TODO: it should really be only giving the missing ones + assertInputRoomEvents(t, rsAPI.inputRoomEvents, append(stateEvents, inputEvent)) +}