Fix sytest, hopefully

This commit is contained in:
Neil Alexander 2021-08-04 11:25:23 +01:00
parent 900d05c21a
commit 4eee0ea4f5
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -304,20 +304,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
return // nothing to query return // nothing to query
} }
// add in any cross-signing requests that need to be made to the list
for domain, forDomain := range domainToCrossSigningKeys {
for userID := range forDomain {
if _, ok := domainToDeviceKeys[domain]; !ok {
domainToDeviceKeys[domain] = make(map[string][]string)
}
if _, ok := domainToDeviceKeys[domain][userID]; !ok {
domainToDeviceKeys[domain][userID] = []string{}
}
}
}
// perform key queries for remote devices // perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
} }
func (a *KeyInternalAPI) remoteKeysFromDatabase( func (a *KeyInternalAPI) remoteKeysFromDatabase(
@ -347,18 +335,30 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
} }
func (a *KeyInternalAPI) queryRemoteKeys( func (a *KeyInternalAPI) queryRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
) { ) {
resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys)) resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
// allows us to wait until all federation servers have been poked // allows us to wait until all federation servers have been poked
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(domainToDeviceKeys))
// mutex for writing directly to res (e.g failures) // mutex for writing directly to res (e.g failures)
var respMu sync.Mutex var respMu sync.Mutex
domains := map[string]struct{}{}
for domain := range domainToDeviceKeys {
domains[domain] = struct{}{}
}
for domain := range domainToCrossSigningKeys {
domains[domain] = struct{}{}
}
wg.Add(len(domains))
// fan out // fan out
for domain, deviceKeys := range domainToDeviceKeys { for domain := range domains {
go a.queryRemoteKeysOnServer(ctx, domain, deviceKeys, &wg, &respMu, timeout, resultCh, res) go a.queryRemoteKeysOnServer(
ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain],
&wg, &respMu, timeout, resultCh, res,
)
} }
// Close the result channel when the goroutines have quit so the for .. range exits // Close the result channel when the goroutines have quit so the for .. range exits
@ -399,8 +399,8 @@ func (a *KeyInternalAPI) queryRemoteKeys(
} }
func (a *KeyInternalAPI) queryRemoteKeysOnServer( func (a *KeyInternalAPI) queryRemoteKeysOnServer(
ctx context.Context, serverName string, devKeys map[string][]string, wg *sync.WaitGroup, ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
res *api.QueryKeysResponse, res *api.QueryKeysResponse,
) { ) {
defer wg.Done() defer wg.Done()
@ -409,14 +409,24 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// for users who we do not have any knowledge about, try to start doing device list updates for them // for users who we do not have any knowledge about, try to start doing device list updates for them
// by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
// lack a stream ID. // lack a stream ID.
var userIDsForAllDevices []string userIDsForAllDevices := map[string]struct{}{}
for userID, deviceIDs := range devKeys { for userID, deviceIDs := range devKeys {
if len(deviceIDs) == 0 { if len(deviceIDs) == 0 {
userIDsForAllDevices = append(userIDsForAllDevices, userID) userIDsForAllDevices[userID] = struct{}{}
//delete(devKeys, userID) delete(devKeys, userID)
} }
} }
for _, userID := range userIDsForAllDevices { // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
// a device list update, so we'll populate those back into the /keys/query list if not
for userID := range crossSigningKeys {
if devKeys == nil {
devKeys = map[string][]string{}
}
if _, ok := userIDsForAllDevices[userID]; !ok {
devKeys[userID] = []string{}
}
}
for userID := range userIDsForAllDevices {
err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID) err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{