Refactor claimRemoteKeys
This commit is contained in:
parent
a169a9121a
commit
d0ca183f49
|
@ -128,58 +128,49 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
|
||||||
func (a *KeyInternalAPI) claimRemoteKeys(
|
func (a *KeyInternalAPI) claimRemoteKeys(
|
||||||
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
|
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
|
||||||
) {
|
) {
|
||||||
resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys))
|
var wg sync.WaitGroup // Wait for fan-out goroutines to finish
|
||||||
// allows us to wait until all federation servers have been poked
|
var mu sync.Mutex // Protects the response struct
|
||||||
var wg sync.WaitGroup
|
var claimed int // Number of keys claimed in total
|
||||||
wg.Add(len(domainToDeviceKeys))
|
var failures int // Number of servers we failed to ask
|
||||||
// mutex for failures
|
|
||||||
var failMu sync.Mutex
|
util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers")
|
||||||
util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers")
|
wg.Add(len(domainToDeviceKeys))
|
||||||
|
|
||||||
// fan out
|
|
||||||
for d, k := range domainToDeviceKeys {
|
for d, k := range domainToDeviceKeys {
|
||||||
go func(domain string, keysToClaim map[string]map[string]string) {
|
go func(domain string, keysToClaim map[string]map[string]string) {
|
||||||
defer wg.Done()
|
|
||||||
fedCtx, cancel := context.WithTimeout(ctx, timeout)
|
fedCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
|
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
|
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
|
||||||
failMu.Lock()
|
|
||||||
res.Failures[domain] = map[string]interface{}{
|
res.Failures[domain] = map[string]interface{}{
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
}
|
}
|
||||||
failMu.Unlock()
|
failures++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resultCh <- &claimKeyRes
|
|
||||||
|
for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
|
||||||
|
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
|
||||||
|
for deviceID, keys := range deviceIDToKeys {
|
||||||
|
res.OneTimeKeys[userID][deviceID] = keys
|
||||||
|
claimed += len(keys)
|
||||||
|
}
|
||||||
|
}
|
||||||
}(d, k)
|
}(d, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the result channel when the goroutines have quit so the for .. range exits
|
wg.Wait()
|
||||||
go func() {
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
wg.Wait()
|
"num_keys": claimed,
|
||||||
close(resultCh)
|
"num_failures": failures,
|
||||||
}()
|
}).Info("Claimed remote keys")
|
||||||
|
|
||||||
keysClaimed := 0
|
|
||||||
for result := range resultCh {
|
|
||||||
for userID, nest := range result.OneTimeKeys {
|
|
||||||
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
|
|
||||||
for deviceID, nest2 := range nest {
|
|
||||||
res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage)
|
|
||||||
for keyIDWithAlgo, otk := range nest2 {
|
|
||||||
keyJSON, err := json.Marshal(otk)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
|
|
||||||
keysClaimed++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
|
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
|
||||||
|
|
Loading…
Reference in a new issue