Try to find missing auth events

This commit is contained in:
Neil Alexander 2020-09-29 09:35:56 +01:00
parent f8c42b89d2
commit 3b8da64592
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
2 changed files with 69 additions and 84 deletions

View file

@ -368,19 +368,35 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
} }
// TODO: Make this less bad // TODO: Make this less bad
for _, missingAuthEventID := range stateResp.MissingAuthEventIDs { if len(stateResp.MissingAuthEventIDs) > 0 {
logrus.WithContext(ctx).Infof("Retrieving missing auth event %q", missingAuthEventID) servers := []gomatrixserverlib.ServerName{t.Origin}
if tx, err := t.federation.GetEvent(ctx, e.Origin(), missingAuthEventID); err == nil { serverReq := &api.QueryServerJoinedToRoomRequest{
ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion) RoomID: e.RoomID(),
if err != nil { }
logrus.WithContext(ctx).WithError(err).Warnf("Failed to unmarshal auth event %d", missingAuthEventID) serverRes := &api.QueryServerJoinedToRoomResponse{}
continue if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
servers = append(servers, serverRes.ServerNames...)
logrus.WithContext(ctx).Infof("Found %d server(s) to query for missing events", len(servers))
}
getAuthEvent:
for _, missingAuthEventID := range stateResp.MissingAuthEventIDs {
for _, server := range servers {
logrus.WithContext(ctx).Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil {
continue
}
ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
if err != nil {
logrus.WithContext(ctx).WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID)
continue getAuthEvent
}
if err = t.processEvent(ctx, ev, false); err != nil {
logrus.WithContext(ctx).WithError(err).Warnf("Failed to process auth event %q", missingAuthEventID)
}
continue getAuthEvent
} }
if err = t.processEvent(ctx, ev, false); err != nil {
logrus.WithContext(ctx).WithError(err).Warnf("Failed to process auth event %d", missingAuthEventID)
}
} else {
logrus.WithContext(ctx).WithError(err).Warnf("Failed to retrieve unknown auth event %q", missingAuthEventID)
} }
} }

View file

@ -77,11 +77,10 @@ func (p *testEDUProducer) InputSendToDeviceEvent(
} }
type testRoomserverAPI struct { type testRoomserverAPI struct {
inputRoomEvents []api.InputRoomEvent inputRoomEvents []api.InputRoomEvent
queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse
queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse
queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse
queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse
} }
func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {}
@ -163,20 +162,6 @@ func (t *testRoomserverAPI) QueryStateAfterEvents(
return nil return nil
} }
// Query the state after a list of events in a room from the room server.
func (t *testRoomserverAPI) QueryMissingAuthPrevEvents(
ctx context.Context,
request *api.QueryMissingAuthPrevEventsRequest,
response *api.QueryMissingAuthPrevEventsResponse,
) error {
response.RoomVersion = testRoomVersion
res := t.queryMissingAuthPrevEvents(request)
response.RoomExists = res.RoomExists
response.MissingAuthEventIDs = res.MissingAuthEventIDs
response.MissingPrevEventIDs = res.MissingPrevEventIDs
return nil
}
// Query a list of events by event ID. // Query a list of events by event ID.
func (t *testRoomserverAPI) QueryEventsByID( func (t *testRoomserverAPI) QueryEventsByID(
ctx context.Context, ctx context.Context,
@ -468,11 +453,11 @@ func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []gomatr
// to the roomserver. It's the most basic test possible. // to the roomserver. It's the most basic test possible.
func TestBasicTransaction(t *testing.T) { func TestBasicTransaction(t *testing.T) {
rsAPI := &testRoomserverAPI{ rsAPI := &testRoomserverAPI{
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
return api.QueryMissingAuthPrevEventsResponse{ return api.QueryStateAfterEventsResponse{
RoomExists: true, PrevEventsExist: true,
MissingAuthEventIDs: []string{}, RoomExists: true,
MissingPrevEventIDs: []string{}, StateEvents: fromStateTuples(req.StateToFetch, nil),
} }
}, },
} }
@ -488,11 +473,14 @@ func TestBasicTransaction(t *testing.T) {
// as it does the auth check. // as it does the auth check.
func TestTransactionFailAuthChecks(t *testing.T) { func TestTransactionFailAuthChecks(t *testing.T) {
rsAPI := &testRoomserverAPI{ rsAPI := &testRoomserverAPI{
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
return api.QueryMissingAuthPrevEventsResponse{ return api.QueryStateAfterEventsResponse{
RoomExists: true, PrevEventsExist: true,
MissingAuthEventIDs: []string{"create_event"}, RoomExists: true,
MissingPrevEventIDs: []string{}, // omit the create event so auth checks fail
StateEvents: fromStateTuples(req.StateToFetch, []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomCreate, StateKey: ""},
}),
} }
}, },
} }
@ -516,6 +504,30 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) {
var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions
rsAPI = &testRoomserverAPI{ rsAPI = &testRoomserverAPI{
queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
// we expect this to be called three times:
// - first with input event to realise there's a gap
// - second with the prevEvent to realise there is no gap
// - third with the input event to realise there is no longer a gap
prevEventsExist := false
if len(req.PrevEventIDs) == 1 {
switch req.PrevEventIDs[0] {
case haveEvent.EventID():
prevEventsExist = true
case prevEvent.EventID():
// we only have this event if we've been send prevEvent
if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() {
prevEventsExist = true
}
}
}
return api.QueryStateAfterEventsResponse{
PrevEventsExist: prevEventsExist,
RoomExists: true,
StateEvents: fromStateTuples(req.StateToFetch, nil),
}
},
queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse {
return api.QueryLatestEventsAndStateResponse{ return api.QueryLatestEventsAndStateResponse{
RoomExists: true, RoomExists: true,
@ -526,30 +538,6 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) {
StateEvents: fromStateTuples(req.StateToFetch, nil), StateEvents: fromStateTuples(req.StateToFetch, nil),
} }
}, },
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
// we expect this to be called three times:
// - first with input event to realise there's a gap
// - second with the prevEvent to realise there is no gap
// - third with the input event to realise there is no longer a gap
missingPrevEvent := []string{"missing_prev_event"}
if len(req.PrevEventIDs) == 1 {
switch req.PrevEventIDs[0] {
case haveEvent.EventID():
missingPrevEvent = []string{}
case prevEvent.EventID():
// we only have this event if we've been send prevEvent
if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() {
missingPrevEvent = []string{}
}
}
}
return api.QueryMissingAuthPrevEventsResponse{
RoomExists: true,
MissingAuthEventIDs: []string{},
MissingPrevEventIDs: missingPrevEvent,
}
},
} }
cli := &txnFedClient{ cli := &txnFedClient{
@ -588,9 +576,6 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) {
// - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event. // - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event.
// - state resolution is done to check C is allowed. // - state resolution is done to check C is allowed.
// This results in B being sent as an outlier FIRST, then C,D. // This results in B being sent as an outlier FIRST, then C,D.
/*
TODO: Fix this test!
func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { func TestTransactionFetchMissingStateByStateIDs(t *testing.T) {
eventA := testEvents[len(testEvents)-5] eventA := testEvents[len(testEvents)-5]
// this is also len(testEvents)-4 // this is also len(testEvents)-4
@ -654,21 +639,6 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) {
StateEvents: fromStateTuples(req.StateToFetch, omitTuples), StateEvents: fromStateTuples(req.StateToFetch, omitTuples),
} }
}, },
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
askingForEvent := req.PrevEventIDs[0]
missingPrevEvents := []string{"missing_prev_event"}
if askingForEvent == eventC.EventID() {
missingPrevEvents = []string{}
} else if askingForEvent == eventB.EventID() {
missingPrevEvents = []string{}
}
return api.QueryMissingAuthPrevEventsResponse{
RoomExists: true,
MissingAuthEventIDs: []string{},
MissingPrevEventIDs: missingPrevEvents,
}
},
queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse {
var res api.QueryEventsByIDResponse var res api.QueryEventsByIDResponse
fmt.Println("queryEventsByID ", req.EventIDs) fmt.Println("queryEventsByID ", req.EventIDs)
@ -746,4 +716,3 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) {
mustProcessTransaction(t, txn, nil) mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD})
} }
*/