diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 2536c1f76..547447270 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -549,6 +549,16 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( } func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + // get a list of devices from the user API that actually exist, as + // we won't store keys for devices that don't exist + uapidevices := &userapi.QueryDevicesResponse{} + if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + return + } + var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { @@ -563,6 +573,17 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per keysToStore = append(keysToStore, key.WithStreamID(0)) continue // deleted keys don't need sanity checking } + // check that the device in question actually exists in the user + // API before we try and store a key for it + foundDevice := false + for _, d := range uapidevices.Devices { + if d.ID == key.DeviceID { + foundDevice = true + } + } + if !foundDevice { + continue + } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { @@ -578,29 +599,45 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per }) } - // get existing device keys so we can check for changes - existingKeys := make([]api.DeviceMessage, len(keysToStore)) - for i := range keysToStore { - existingKeys[i] = api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, - }, - } - } - if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { + // Get all of the user existing device keys so we can check for changes. + existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil) + if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), } return } + + // Work out whether we have device keys in the keyserver for devices that + // no longer exist in the user API. This is mostly an exercise to ensure + // that we keep some integrity between the two. + var toClean []gomatrixserverlib.KeyID + for _, k := range existingKeys { + found := false + for _, d := range uapidevices.Devices { + if k.UserID == d.UserID && k.DeviceID == d.ID { + found = true + } + } + if !found { + toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID)) + } + } + if len(toClean) > 0 { + if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to clean device keys: %s", err.Error()), + } + return + } + } + if req.OnlyDisplayNameUpdates { // add the display name field from keysToStore into existingKeys keysToStore = appendDisplayNames(existingKeys, keysToStore) } // store the device keys and emit changes - err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore) + err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),