Refactor claimRemoteKeys
commit a2706e6498 (parent a785532463)

@@ -44,7 +44,7 @@ func (a *FederationInternalAPI) ClaimKeys(
 ) (gomatrixserverlib.RespClaimKeys, error) {
 	ctx, cancel := context.WithTimeout(ctx, time.Second*30)
 	defer cancel()
-	ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) {
+	ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) {
 		return a.federation.ClaimKeys(ctx, s, oneTimeKeys)
 	})
 	if err != nil {
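The hunk above swaps doRequestIfNotBackingOffOrBlacklisted for doRequestIfNotBlacklisted: key claims are now attempted even while the destination is in a backoff period, and are only skipped for blacklisted servers. A minimal sketch of that kind of gate, assuming a simple blacklist set (the type and field names here are illustrative, not Dendrite's actual internals):

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// blacklistGate refuses requests only for explicitly blacklisted servers.
// Unlike a backoff-aware gate, it does not skip servers that merely
// failed recently.
type blacklistGate struct {
	mu          sync.RWMutex
	blacklisted map[string]bool
}

var errBlacklisted = errors.New("destination is blacklisted")

// doRequestIfNotBlacklisted mirrors the shape of the wrapper in the diff:
// check the gate, then run the supplied request closure.
func (g *blacklistGate) doRequestIfNotBlacklisted(server string, do func() (interface{}, error)) (interface{}, error) {
	g.mu.RLock()
	bad := g.blacklisted[server]
	g.mu.RUnlock()
	if bad {
		return nil, errBlacklisted
	}
	return do()
}

func main() {
	g := &blacklistGate{blacklisted: map[string]bool{"evil.example": true}}
	res, err := g.doRequestIfNotBlacklisted("matrix.org", func() (interface{}, error) {
		return "claimed keys", nil // stand-in for federation.ClaimKeys
	})
	fmt.Println(res, err) // claimed keys <nil>
}
```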
@@ -128,58 +128,49 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
 func (a *KeyInternalAPI) claimRemoteKeys(
 	ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
 ) {
-	resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys))
-	// allows us to wait until all federation servers have been poked
-	var wg sync.WaitGroup
-	wg.Add(len(domainToDeviceKeys))
-	// mutex for failures
-	var failMu sync.Mutex
-	util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers")
+	var wg sync.WaitGroup // Wait for fan-out goroutines to finish
+	var mu sync.Mutex     // Protects the response struct
+	var claimed int       // Number of keys claimed in total
+	var failures int      // Number of servers we failed to ask
+
+	util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys))
+	wg.Add(len(domainToDeviceKeys))
+
 	// fan out
 	for d, k := range domainToDeviceKeys {
 		go func(domain string, keysToClaim map[string]map[string]string) {
-			defer wg.Done()
 			fedCtx, cancel := context.WithTimeout(ctx, timeout)
 			defer cancel()
+			defer wg.Done()
+
 			claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
+
+			mu.Lock()
+			defer mu.Unlock()
+
 			if err != nil {
 				util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
-				failMu.Lock()
 				res.Failures[domain] = map[string]interface{}{
 					"message": err.Error(),
 				}
-				failMu.Unlock()
+				failures++
 				return
 			}
-			resultCh <- &claimKeyRes
+
+			for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
+				res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
+				for deviceID, keys := range deviceIDToKeys {
+					res.OneTimeKeys[userID][deviceID] = keys
+					claimed += len(keys)
+				}
+			}
 		}(d, k)
 	}
 
-	// Close the result channel when the goroutines have quit so the for .. range exits
-	go func() {
-		wg.Wait()
-		close(resultCh)
-	}()
-
-	keysClaimed := 0
-	for result := range resultCh {
-		for userID, nest := range result.OneTimeKeys {
-			res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
-			for deviceID, nest2 := range nest {
-				res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage)
-				for keyIDWithAlgo, otk := range nest2 {
-					keyJSON, err := json.Marshal(otk)
-					if err != nil {
-						continue
-					}
-					res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
-					keysClaimed++
-				}
-			}
-		}
-	}
-	util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
+	wg.Wait()
+	util.GetLogger(ctx).WithFields(logrus.Fields{
+		"num_keys":     claimed,
+		"num_failures": failures,
+	}).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
 }
 
 func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
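The heart of the refactor is in claimRemoteKeys: the buffered result channel, the extra goroutine that closed it, and the separate drain loop are all replaced by direct writes to the response under a single mutex, with the caller simply waiting on the WaitGroup before logging the totals. A self-contained sketch of the same fan-out/merge pattern, with a made-up fetchKeys standing in for the federation call:

```go
package main

import (
	"fmt"
	"sync"
)

// fetchKeys stands in for the per-server federation request
// (a.FedClient.ClaimKeys in the diff).
func fetchKeys(server string) (map[string]string, error) {
	if server == "down.example" {
		return nil, fmt.Errorf("server unreachable")
	}
	return map[string]string{server + "/device1": "key"}, nil
}

// claimAll fans out one goroutine per server and merges results and
// failures directly into shared maps under one mutex, instead of
// sending them over a channel and draining it in a second loop.
func claimAll(servers []string) (keys map[string]string, failures map[string]error) {
	keys = map[string]string{}
	failures = map[string]error{}

	var wg sync.WaitGroup // Wait for fan-out goroutines to finish
	var mu sync.Mutex     // Protects keys and failures
	wg.Add(len(servers))

	for _, s := range servers {
		go func(server string) {
			defer wg.Done()
			res, err := fetchKeys(server)

			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				failures[server] = err
				return
			}
			for id, k := range res {
				keys[id] = k
			}
		}(s)
	}

	wg.Wait() // Every goroutine has run Done(); all writes are finished.
	return keys, failures
}

func main() {
	keys, failures := claimAll([]string{"a.example", "down.example"})
	fmt.Printf("%d key(s), %d failure(s)\n", len(keys), len(failures))
}
```

One mutex acquisition per server keeps the critical section small, and because wg.Wait() only returns after each goroutine's deferred Done(), all writes are complete before the aggregated counts are read.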