diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index c5711e73c..28638c294 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -300,12 +300,38 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques // attempt to satisfy key queries from the local database first as we should get device updates pushed to us domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) - if len(domainToDeviceKeys) == 0 && len(domainToCrossSigningKeys) == 0 { - return // nothing to query + if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { + // perform key queries for remote devices + a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) } - // perform key queries for remote devices - a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) + // Finally, append signatures that we know about + // TODO: This is horrible because we need to round-trip the signature from + // JSON, add the signatures and marshal it again, for some reason? + for userID, forUserID := range res.DeviceKeys { + for keyID, key := range forUserID { + sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, userID, gomatrixserverlib.KeyID(keyID)) + if err != nil { + logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") + continue + } + if len(sigMap) == 0 { + continue + } + var deviceKey gomatrixserverlib.DeviceKeys + if err = json.Unmarshal(key, &deviceKey); err != nil { + continue + } + for sourceUserID, forSourceUser := range sigMap { + for sourceKeyID, sourceSig := range forSourceUser { + deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig + } + } + if js, err := json.Marshal(deviceKey); err == nil { + res.DeviceKeys[userID][keyID] = js + } + } + } } func (a *KeyInternalAPI) remoteKeysFromDatabase(