diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index a1151e304..ddd7c4744 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -370,19 +370,22 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) er } // We will need to know this when fetching missing auth or prev events. - t.servers = []gomatrixserverlib.ServerName{t.Origin} - serverReq := &api.QueryServerJoinedToRoomRequest{ - RoomID: e.RoomID(), - } - serverRes := &api.QueryServerJoinedToRoomResponse{} - if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { - t.servers = append(t.servers, serverRes.ServerNames...) - logger.Infof("Found %d server(s) to query for missing events", len(t.servers)) + getServers := func() { + t.servers = []gomatrixserverlib.ServerName{} + serverReq := &api.QueryServerJoinedToRoomRequest{ + RoomID: e.RoomID(), + } + serverRes := &api.QueryServerJoinedToRoomResponse{} + if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { + t.servers = append(t.servers, serverRes.ServerNames...) + logger.Infof("Found %d server(s) to query for missing events", len(t.servers)) + } } if len(stateResp.MissingAuthEventIDs) > 0 { logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) + getServers() if err := t.retrieveMissingAuthEvents(ctx, e, &stateResp); err != nil { return fmt.Errorf("t.retrieveMissingAuthEvents: %w", err) } @@ -390,6 +393,7 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) er if len(stateResp.MissingPrevEventIDs) > 0 { logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs)) + getServers() return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion) } @@ -420,7 +424,7 @@ func (t *txnReq) retrieveMissingAuthEvents( withNextEvent: for missingAuthEventID := range missingAuthEvents { withNextServer: - for _, server := range t.servers { + for _, server := range append([]gomatrixserverlib.ServerName{t.Origin}, t.servers...) { logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) if err != nil { @@ -994,7 +998,7 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib. } var event gomatrixserverlib.Event found := false - for _, serverName := range t.servers { + for _, serverName := range append([]gomatrixserverlib.ServerName{t.Origin}, t.servers...) { txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) if err != nil || len(txn.PDUs) == 0 { util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID")