diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 442af8715..d1e93e1f7 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -108,6 +108,8 @@ type OneTimeKeysCount struct { // PerformUploadKeysRequest is the request to PerformUploadKeys type PerformUploadKeysRequest struct { + UserID string // User performing the request + DeviceID string // Device performing the request DeviceKeys []DeviceKeys OneTimeKeys []OneTimeKeys // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 53afe0a60..4884fa60f 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -513,6 +513,16 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per } func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + if len(req.OneTimeKeys) == 0 { + counts, err := a.DB.OneTimeKeysCount(ctx, req.DeviceID, req.UserID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err), + } + } + res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) + return + } for _, key := range req.OneTimeKeys { // grab existing keys based on (user/device/algorithm/key ID) keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) @@ -521,9 +531,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform keyIDsWithAlgorithms[i] = keyIDWithAlgo i++ } - existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, key.UserID, key.DeviceID, keyIDsWithAlgorithms) + existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) if err != nil { - res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ Err: "failed to query existing one-time keys: " + err.Error(), }) continue @@ -531,8 +541,8 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform for keyIDWithAlgo := range existingKeys { // if keys exist and the JSON doesn't match, error out as the key already exists if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { - res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ - Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", key.UserID, key.DeviceID, keyIDWithAlgo), + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo), }) continue } @@ -540,8 +550,8 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform // store one-time keys counts, err := a.DB.StoreOneTimeKeys(ctx, key) if err != nil { - res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ - Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()), + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()), }) continue }