From 868f2d5a806b5d363ce04b2033f12668cdbd8d71 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 12 Oct 2020 11:20:11 +0100 Subject: [PATCH] Recursively fetch auth events if needed --- federationapi/routing/send.go | 92 ++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 38 deletions(-) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index fe4295213..49a6a6728 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -112,6 +112,8 @@ type txnReq struct { haveEvents map[string]*gomatrixserverlib.HeaderedEvent // new events which the roomserver does not know about newEvents map[string]bool + // servers which we should fetch missing events from + servers []gomatrixserverlib.ServerName } // A subset of FederationClient functionality that txn requires. Useful for testing. @@ -350,7 +352,7 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is } var stateResp api.QueryMissingAuthPrevEventsResponse if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil { - return err + return fmt.Errorf("t.rsAPI.QueryMissingAuthPrevEvents: %w", err) } if !stateResp.RoomExists { @@ -366,45 +368,20 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is if len(stateResp.MissingAuthEventIDs) > 0 { logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) - servers := []gomatrixserverlib.ServerName{t.Origin} - serverReq := &api.QueryServerJoinedToRoomRequest{ - RoomID: e.RoomID(), - } - serverRes := &api.QueryServerJoinedToRoomResponse{} - if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { - servers = append(servers, serverRes.ServerNames...) - logger.Infof("Found %d server(s) to query for missing events", len(servers)) + if len(t.servers) == 0 { + t.servers = []gomatrixserverlib.ServerName{t.Origin} + serverReq := &api.QueryServerJoinedToRoomRequest{ + RoomID: e.RoomID(), + } + serverRes := &api.QueryServerJoinedToRoomResponse{} + if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { + t.servers = append(t.servers, serverRes.ServerNames...) + logger.Infof("Found %d server(s) to query for missing events", len(t.servers)) + } } - getAuthEvent: - for _, missingAuthEventID := range stateResp.MissingAuthEventIDs { - for _, server := range servers { - logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) - tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) - if err != nil { - continue // try the next server - } - ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion) - if err != nil { - logger.WithError(err).Errorf("Failed to unmarshal auth event %q", missingAuthEventID) - continue // try the next server - } - if err = api.SendInputRoomEvents( - context.Background(), - t.rsAPI, - []api.InputRoomEvent{ - { - Kind: api.KindOutlier, - Event: ev.Headered(stateResp.RoomVersion), - AuthEventIDs: ev.AuthEventIDs(), - SendAsServer: api.DoNotSendToOtherServers, - }, - }, - ); err != nil { - logger.WithError(err).Errorf("Failed to send auth event %q to roomserver", missingAuthEventID) - continue getAuthEvent // move onto the next event - } - } + if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil { + return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err) } } @@ -427,6 +404,45 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is ) } +func (t *txnReq) retrieveMissingAuthEvents( + ctx context.Context, e gomatrixserverlib.Event, stateResp *api.QueryMissingAuthPrevEventsResponse, +) error { + logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) + + missingAuthEvents := make(map[string]struct{}) + for _, missingAuthEventID := range stateResp.MissingAuthEventIDs { + missingAuthEvents[missingAuthEventID] = struct{}{} + } + +withNextEvent: + for missingAuthEventID := range missingAuthEvents { + withNextServer: + for _, server := range t.servers { + logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) + tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) + if err != nil { + logger.WithError(err).Warnf("Failed to retrieve auth event %q", missingAuthEventID) + continue withNextServer + } + ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion) + if err != nil { + logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID) + continue withNextServer + } + if err = t.processEvent(ctx, ev, false); err != nil { + return fmt.Errorf("recursive t.processEvent: %w", err) + } + delete(missingAuthEvents, missingAuthEventID) + continue withNextEvent + } + } + + if missing := len(missingAuthEvents); missing > 0 { + return fmt.Errorf("Event refers to %d auth_events which we failed to fetch", missing) + } + return nil +} + func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error { authUsingState := gomatrixserverlib.NewAuthEvents(nil) for i := range stateEvents {