diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index bfb2037f8..527990cf9 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -308,12 +308,8 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req // Finally, generate a notification that we updated the signatures. for userID := range req.Signatures { - masterKey := queryRes.MasterKeys[userID] - selfSigningKey := queryRes.SelfSigningKeys[userID] update := eduserverAPI.CrossSigningKeyUpdate{ - UserID: userID, - MasterKey: &masterKey, - SelfSigningKey: &selfSigningKey, + UserID: userID, } if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 371dda6d0..259249217 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -243,45 +243,49 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } domain := string(serverName) // query local devices - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query local device keys: %s", err), + if serverName == a.ThisServer { + deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query local device keys: %s", err), + } + return } - return - } - // pull out display names after we have the keys so we handle wildcards correctly - var dids []string - for _, dk := range deviceKeys { - dids = append(dids, dk.DeviceID) - } - var queryRes userapi.QueryDeviceInfosResponse - err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ - DeviceIDs: dids, - }, &queryRes) - if err != nil { - util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") - } - - if res.DeviceKeys[userID] == nil { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) - } - for _, dk := range deviceKeys { - if len(dk.KeyJSON) == 0 { - continue // don't include blank keys + // pull out display names after we have the keys so we handle wildcards correctly + var dids []string + for _, dk := range deviceKeys { + dids = append(dids, dk.DeviceID) } - // inject display name if known (either locally or remotely) - displayName := dk.DisplayName - if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { - displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName + var queryRes userapi.QueryDeviceInfosResponse + err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ + DeviceIDs: dids, + }, &queryRes) + if err != nil { + util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") } - dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { - DisplayName string `json:"device_display_name,omitempty"` - }{displayName}) - res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON - } + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + for _, dk := range deviceKeys { + if len(dk.KeyJSON) == 0 { + continue // don't include blank keys + } + // inject display name if known (either locally or remotely) + displayName := dk.DisplayName + if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { + displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName + } + dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{displayName}) + res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON + } + } else { + domainToDeviceKeys[domain] = make(map[string][]string) + domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) + } // work out if our cross-signing request for this user was // satisfied, if not add them to the list of things to fetch if _, ok := res.MasterKeys[userID]; !ok { @@ -322,14 +326,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques if err = json.Unmarshal(key, &deviceKey); err != nil { continue } - if deviceKey.Signatures == nil { - deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } for sourceUserID, forSourceUser := range sigMap { for sourceKeyID, sourceSig := range forSourceUser { - if _, ok := deviceKey.Signatures[sourceUserID]; !ok { - deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig } }