From 4eee0ea4f5516e50eac0927555d012c6dbb01b22 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 4 Aug 2021 11:25:23 +0100 Subject: [PATCH] Fix sytest, hopefully --- keyserver/internal/internal.go | 56 ++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 0126fa066..c5711e73c 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -304,20 +304,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques return // nothing to query } - // add in any cross-signing requests that need to be made to the list - for domain, forDomain := range domainToCrossSigningKeys { - for userID := range forDomain { - if _, ok := domainToDeviceKeys[domain]; !ok { - domainToDeviceKeys[domain] = make(map[string][]string) - } - if _, ok := domainToDeviceKeys[domain][userID]; !ok { - domainToDeviceKeys[domain][userID] = []string{} - } - } - } - // perform key queries for remote devices - a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) + a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) } func (a *KeyInternalAPI) remoteKeysFromDatabase( @@ -347,18 +335,30 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase( } func (a *KeyInternalAPI) queryRemoteKeys( - ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, + ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, + domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{}, ) { resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys)) // allows us to wait until all federation servers have been poked var wg sync.WaitGroup - wg.Add(len(domainToDeviceKeys)) // mutex for writing directly to res (e.g failures) var respMu sync.Mutex + domains := map[string]struct{}{} + for domain := range domainToDeviceKeys { + domains[domain] = struct{}{} + } + for domain := range domainToCrossSigningKeys { + domains[domain] = struct{}{} + } + wg.Add(len(domains)) + // fan out - for domain, deviceKeys := range domainToDeviceKeys { - go a.queryRemoteKeysOnServer(ctx, domain, deviceKeys, &wg, &respMu, timeout, resultCh, res) + for domain := range domains { + go a.queryRemoteKeysOnServer( + ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain], + &wg, &respMu, timeout, resultCh, res, + ) } // Close the result channel when the goroutines have quit so the for .. range exits @@ -399,8 +399,8 @@ func (a *KeyInternalAPI) queryRemoteKeys( } func (a *KeyInternalAPI) queryRemoteKeysOnServer( - ctx context.Context, serverName string, devKeys map[string][]string, wg *sync.WaitGroup, - respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, + ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{}, + wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, res *api.QueryKeysResponse, ) { defer wg.Done() @@ -409,14 +409,24 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( // for users who we do not have any knowledge about, try to start doing device list updates for them // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but // lack a stream ID. - var userIDsForAllDevices []string + userIDsForAllDevices := map[string]struct{}{} for userID, deviceIDs := range devKeys { if len(deviceIDs) == 0 { - userIDsForAllDevices = append(userIDsForAllDevices, userID) - //delete(devKeys, userID) + userIDsForAllDevices[userID] = struct{}{} + delete(devKeys, userID) } } - for _, userID := range userIDsForAllDevices { + // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing + // a device list update, so we'll populate those back into the /keys/query list if not + for userID := range crossSigningKeys { + if devKeys == nil { + devKeys = map[string][]string{} + } + if _, ok := userIDsForAllDevices[userID]; !ok { + devKeys[userID] = []string{} + } + } + for userID := range userIDsForAllDevices { err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID) if err != nil { logrus.WithFields(logrus.Fields{