Always return OTK counts

This commit is contained in:
Neil Alexander 2021-02-26 10:43:46 +00:00
parent 3069079e37
commit db6a3cda5a
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
2 changed files with 18 additions and 6 deletions

View file

@ -108,6 +108,8 @@ type OneTimeKeysCount struct {
// PerformUploadKeysRequest is the request to PerformUploadKeys // PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct { type PerformUploadKeysRequest struct {
UserID string // User performing the request
DeviceID string // Device performing the request
DeviceKeys []DeviceKeys DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys OneTimeKeys []OneTimeKeys
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update

View file

@ -513,6 +513,16 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
} }
func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if len(req.OneTimeKeys) == 0 {
counts, err := a.DB.OneTimeKeysCount(ctx, req.DeviceID, req.UserID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err),
}
}
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
return
}
for _, key := range req.OneTimeKeys { for _, key := range req.OneTimeKeys {
// grab existing keys based on (user/device/algorithm/key ID) // grab existing keys based on (user/device/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
@ -521,9 +531,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
keyIDsWithAlgorithms[i] = keyIDWithAlgo keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++ i++
} }
existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, key.UserID, key.DeviceID, keyIDsWithAlgorithms) existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
if err != nil { if err != nil {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: "failed to query existing one-time keys: " + err.Error(), Err: "failed to query existing one-time keys: " + err.Error(),
}) })
continue continue
@ -531,8 +541,8 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
for keyIDWithAlgo := range existingKeys { for keyIDWithAlgo := range existingKeys {
// if keys exist and the JSON doesn't match, error out as the key already exists // if keys exist and the JSON doesn't match, error out as the key already exists
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", key.UserID, key.DeviceID, keyIDWithAlgo), Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
}) })
continue continue
} }
@ -540,8 +550,8 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
// store one-time keys // store one-time keys
counts, err := a.DB.StoreOneTimeKeys(ctx, key) counts, err := a.DB.StoreOneTimeKeys(ctx, key)
if err != nil { if err != nil {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()), Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
}) })
continue continue
} }