From a2706e6498287a5b052ef47413175bf7551b36b1 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 27 Oct 2022 15:34:26 +0100 Subject: [PATCH] Refactor `claimRemoteKeys` --- federationapi/internal/federationclient.go | 2 +- keyserver/internal/internal.go | 63 ++++++++++------------ 2 files changed, 28 insertions(+), 37 deletions(-) diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index b8bd5beda..2636b7fa0 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -44,7 +44,7 @@ func (a *FederationInternalAPI) ClaimKeys( ) (gomatrixserverlib.RespClaimKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { + ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { return a.federation.ClaimKeys(ctx, s, oneTimeKeys) }) if err != nil { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index ff0968b27..92ee80d81 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -128,58 +128,49 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC func (a *KeyInternalAPI) claimRemoteKeys( ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, ) { - resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys)) - // allows us to wait until all federation servers have been poked - var wg sync.WaitGroup - wg.Add(len(domainToDeviceKeys)) - // mutex for failures - var failMu sync.Mutex - util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers") + var wg sync.WaitGroup // Wait for fan-out goroutines to finish + var mu sync.Mutex // Protects the response struct + var claimed int // Number of keys claimed in total + var failures int // Number of servers we failed to ask + + util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys)) + wg.Add(len(domainToDeviceKeys)) - // fan out for d, k := range domainToDeviceKeys { go func(domain string, keysToClaim map[string]map[string]string) { - defer wg.Done() fedCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() + defer wg.Done() + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) + + mu.Lock() + defer mu.Unlock() + if err != nil { util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") - failMu.Lock() res.Failures[domain] = map[string]interface{}{ "message": err.Error(), } - failMu.Unlock() + failures++ return } - resultCh <- &claimKeyRes + + for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys { + res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) + for deviceID, keys := range deviceIDToKeys { + res.OneTimeKeys[userID][deviceID] = keys + claimed += len(keys) + } + } }(d, k) } - // Close the result channel when the goroutines have quit so the for .. range exits - go func() { - wg.Wait() - close(resultCh) - }() - - keysClaimed := 0 - for result := range resultCh { - for userID, nest := range result.OneTimeKeys { - res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) - for deviceID, nest2 := range nest { - res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage) - for keyIDWithAlgo, otk := range nest2 { - keyJSON, err := json.Marshal(otk) - if err != nil { - continue - } - res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON - keysClaimed++ - } - } - } - } - util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") + wg.Wait() + util.GetLogger(ctx).WithFields(logrus.Fields{ + "num_keys": claimed, + "num_failures": failures, + }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys)) } func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {