diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index ff0968b27..a62d36967 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -128,58 +128,49 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC func (a *KeyInternalAPI) claimRemoteKeys( ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, ) { - resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys)) - // allows us to wait until all federation servers have been poked - var wg sync.WaitGroup - wg.Add(len(domainToDeviceKeys)) - // mutex for failures - var failMu sync.Mutex - util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers") + var wg sync.WaitGroup // Wait for fan-out goroutines to finish + var mu sync.Mutex // Protects the response struct + var claimed int // Number of keys claimed in total + var failures int // Number of servers we failed to ask + + util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers") + wg.Add(len(domainToDeviceKeys)) - // fan out for d, k := range domainToDeviceKeys { go func(domain string, keysToClaim map[string]map[string]string) { - defer wg.Done() fedCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() + defer wg.Done() + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) + + mu.Lock() + defer mu.Unlock() + if err != nil { util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") - failMu.Lock() res.Failures[domain] = map[string]interface{}{ "message": err.Error(), } - failMu.Unlock() + failures++ return } - resultCh <- &claimKeyRes + + for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys { + res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) + for deviceID, keys := range deviceIDToKeys { + res.OneTimeKeys[userID][deviceID] = keys + claimed += len(keys) + } + } }(d, k) } - // Close the result channel when the goroutines have quit so the for .. range exits - go func() { - wg.Wait() - close(resultCh) - }() - - keysClaimed := 0 - for result := range resultCh { - for userID, nest := range result.OneTimeKeys { - res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) - for deviceID, nest2 := range nest { - res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage) - for keyIDWithAlgo, otk := range nest2 { - keyJSON, err := json.Marshal(otk) - if err != nil { - continue - } - res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON - keysClaimed++ - } - } - } - } - util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") + wg.Wait() + util.GetLogger(ctx).WithFields(logrus.Fields{ + "num_keys": claimed, + "num_failures": failures, + }).Info("Claimed remote keys") } func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {