Recursively fetch auth events if needed

This commit is contained in:
Neil Alexander 2020-10-12 11:20:11 +01:00
parent 0804594a61
commit 868f2d5a80
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -112,6 +112,8 @@ type txnReq struct {
haveEvents map[string]*gomatrixserverlib.HeaderedEvent haveEvents map[string]*gomatrixserverlib.HeaderedEvent
// new events which the roomserver does not know about // new events which the roomserver does not know about
newEvents map[string]bool newEvents map[string]bool
// servers which we should fetch missing events from
servers []gomatrixserverlib.ServerName
} }
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
@ -350,7 +352,7 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
} }
var stateResp api.QueryMissingAuthPrevEventsResponse var stateResp api.QueryMissingAuthPrevEventsResponse
if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil { if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil {
return err return fmt.Errorf("t.rsAPI.QueryMissingAuthPrevEvents: %w", err)
} }
if !stateResp.RoomExists { if !stateResp.RoomExists {
@ -366,45 +368,20 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
if len(stateResp.MissingAuthEventIDs) > 0 { if len(stateResp.MissingAuthEventIDs) > 0 {
logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs))
servers := []gomatrixserverlib.ServerName{t.Origin} if len(t.servers) == 0 {
serverReq := &api.QueryServerJoinedToRoomRequest{ t.servers = []gomatrixserverlib.ServerName{t.Origin}
RoomID: e.RoomID(), serverReq := &api.QueryServerJoinedToRoomRequest{
} RoomID: e.RoomID(),
serverRes := &api.QueryServerJoinedToRoomResponse{} }
if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { serverRes := &api.QueryServerJoinedToRoomResponse{}
servers = append(servers, serverRes.ServerNames...) if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil {
logger.Infof("Found %d server(s) to query for missing events", len(servers)) t.servers = append(t.servers, serverRes.ServerNames...)
logger.Infof("Found %d server(s) to query for missing events", len(t.servers))
}
} }
getAuthEvent: if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil {
for _, missingAuthEventID := range stateResp.MissingAuthEventIDs { return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err)
for _, server := range servers {
logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil {
continue // try the next server
}
ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
if err != nil {
logger.WithError(err).Errorf("Failed to unmarshal auth event %q", missingAuthEventID)
continue // try the next server
}
if err = api.SendInputRoomEvents(
context.Background(),
t.rsAPI,
[]api.InputRoomEvent{
{
Kind: api.KindOutlier,
Event: ev.Headered(stateResp.RoomVersion),
AuthEventIDs: ev.AuthEventIDs(),
SendAsServer: api.DoNotSendToOtherServers,
},
},
); err != nil {
logger.WithError(err).Errorf("Failed to send auth event %q to roomserver", missingAuthEventID)
continue getAuthEvent // move onto the next event
}
}
} }
} }
@ -427,6 +404,45 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
) )
} }
func (t *txnReq) retrieveMissingAuthEvents(
ctx context.Context, e gomatrixserverlib.Event, stateResp *api.QueryMissingAuthPrevEventsResponse,
) error {
logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID())
missingAuthEvents := make(map[string]struct{})
for _, missingAuthEventID := range stateResp.MissingAuthEventIDs {
missingAuthEvents[missingAuthEventID] = struct{}{}
}
withNextEvent:
for missingAuthEventID := range missingAuthEvents {
withNextServer:
for _, server := range t.servers {
logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil {
logger.WithError(err).Warnf("Failed to retrieve auth event %q", missingAuthEventID)
continue withNextServer
}
ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion)
if err != nil {
logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID)
continue withNextServer
}
if err = t.processEvent(ctx, ev, false); err != nil {
return fmt.Errorf("recursive t.processEvent: %w", err)
}
delete(missingAuthEvents, missingAuthEventID)
continue withNextEvent
}
}
if missing := len(missingAuthEvents); missing > 0 {
return fmt.Errorf("Event refers to %d auth_events which we failed to fetch", missing)
}
return nil
}
func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error { func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil) authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents { for i := range stateEvents {