diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 5d5ea310d..5f20b2d8e 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" "net/http" + "sync" "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -705,6 +706,20 @@ Event: return nil, nil } +func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + respState *gomatrixserverlib.RespState, err error) { + state, err := t.federation.LookupState(ctx, t.Origin, roomID, eventID, roomVersion) + if err != nil { + return nil, err + } + // Check that the returned state is valid. + if err := state.Check(ctx, t.keys, nil); err != nil { + return nil, err + } + return &state, nil +} + +// nolint:gocyclo func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( *gomatrixserverlib.RespState, error) { util.GetLogger(ctx).Infof("lookupMissingStateViaStateIDs %s", eventID) @@ -742,27 +757,90 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even } } + concurrentRequests := 8 + missingCount := len(missing) + + // If over 50% of the auth/state events from /state_ids are missing + // then we'll just call /state instead, otherwise we'll just end up + // hammering the remote side with /event requests unnecessarily. + if missingCount > concurrentRequests && missingCount > len(wantIDs)/2 { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "missing": missingCount, + "event_id": eventID, + "room_id": roomID, + "total_state": len(stateIDs.StateEventIDs), + "total_auth_events": len(stateIDs.AuthEventIDs), + }).Info("Fetching all state at event") + return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) + } + util.GetLogger(ctx).WithFields(logrus.Fields{ - "missing": len(missing), - "event_id": eventID, - "room_id": roomID, - "total_state": len(stateIDs.StateEventIDs), - "total_auth_events": len(stateIDs.AuthEventIDs), + "missing": missingCount, + "event_id": eventID, + "room_id": roomID, + "total_state": len(stateIDs.StateEventIDs), + "total_auth_events": len(stateIDs.AuthEventIDs), + "concurrent_requests": concurrentRequests, }).Info("Fetching missing state at event") + // Create a queue containing all of the missing event IDs that we want + // to retrieve. + pending := make(chan string, missingCount) for missingEventID := range missing { + pending <- missingEventID + } + close(pending) + + // Define how many workers we should start to do this. + if missingCount < concurrentRequests { + concurrentRequests = missingCount + } + + // Create the wait group. + var fetchgroup sync.WaitGroup + fetchgroup.Add(concurrentRequests) + + // This is the only place where we'll write to t.haveEvents from + // multiple goroutines, and everywhere else is blocked on this + // synchronous function anyway. + var haveEventsMutex sync.Mutex + + // Define what we'll do in order to fetch the missing event ID. + fetch := func(missingEventID string) { var h *gomatrixserverlib.HeaderedEvent h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false) switch err.(type) { case verifySigError: - continue + break case nil: - // do nothing + break default: - return nil, err + util.GetLogger(ctx).WithFields(logrus.Fields{ + "event_id": missingEventID, + "room_id": roomID, + }).Info("Failed to fetch missing event") + return } + haveEventsMutex.Lock() t.haveEvents[h.EventID()] = h + haveEventsMutex.Unlock() } + + // Create the worker. + worker := func(ch <-chan string) { + defer fetchgroup.Done() + for missingEventID := range ch { + fetch(missingEventID) + } + } + + // Start the workers. + for i := 0; i < concurrentRequests; i++ { + go worker(pending) + } + + // Wait for the workers to finish. + fetchgroup.Wait() resp, err := t.createRespStateFromStateIDs(stateIDs) return resp, err }