Refactor claimRemoteKeys

This commit is contained in:
Neil Alexander 2022-10-27 13:31:53 +01:00
parent a169a9121a
commit d0ca183f49
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -128,58 +128,49 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
func (a *KeyInternalAPI) claimRemoteKeys( func (a *KeyInternalAPI) claimRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
) { ) {
resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys)) var wg sync.WaitGroup // Wait for fan-out goroutines to finish
// allows us to wait until all federation servers have been poked var mu sync.Mutex // Protects the response struct
var wg sync.WaitGroup var claimed int // Number of keys claimed in total
wg.Add(len(domainToDeviceKeys)) var failures int // Number of servers we failed to ask
// mutex for failures
var failMu sync.Mutex util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers")
util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers") wg.Add(len(domainToDeviceKeys))
// fan out
for d, k := range domainToDeviceKeys { for d, k := range domainToDeviceKeys {
go func(domain string, keysToClaim map[string]map[string]string) { go func(domain string, keysToClaim map[string]map[string]string) {
defer wg.Done()
fedCtx, cancel := context.WithTimeout(ctx, timeout) fedCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel() defer cancel()
defer wg.Done()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
mu.Lock()
defer mu.Unlock()
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
failMu.Lock()
res.Failures[domain] = map[string]interface{}{ res.Failures[domain] = map[string]interface{}{
"message": err.Error(), "message": err.Error(),
} }
failMu.Unlock() failures++
return return
} }
resultCh <- &claimKeyRes
for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
for deviceID, keys := range deviceIDToKeys {
res.OneTimeKeys[userID][deviceID] = keys
claimed += len(keys)
}
}
}(d, k) }(d, k)
} }
// Close the result channel when the goroutines have quit so the for .. range exits wg.Wait()
go func() { util.GetLogger(ctx).WithFields(logrus.Fields{
wg.Wait() "num_keys": claimed,
close(resultCh) "num_failures": failures,
}() }).Info("Claimed remote keys")
keysClaimed := 0
for result := range resultCh {
for userID, nest := range result.OneTimeKeys {
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
for deviceID, nest2 := range nest {
res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage)
for keyIDWithAlgo, otk := range nest2 {
keyJSON, err := json.Marshal(otk)
if err != nil {
continue
}
res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
keysClaimed++
}
}
}
}
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
} }
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {