diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 5f3cbe301..e84dc28df 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -80,6 +80,14 @@ type PerformUploadKeysResponse struct { OneTimeKeyCounts []OneTimeKeysCount } +// KeyError sets a key error field on KeyErrors +func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { + if r.KeyErrors[userID] == nil { + r.KeyErrors[userID] = make(map[string]*KeyError) + } + r.KeyErrors[userID][deviceID] = err +} + type PerformClaimKeysRequest struct { } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 594d102cd..89bc102b3 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -51,17 +51,35 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU continue } - if res.KeyErrors[key.UserID] == nil { - res.KeyErrors[key.UserID] = make(map[string]*api.KeyError) - } - res.KeyErrors[key.UserID][key.DeviceID] = &api.KeyError{ + res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ Error: fmt.Sprintf( "user_id or device_id mismatch: users: %s - %s, devices: %s - %s", gotUserID, key.UserID, gotDeviceID, key.DeviceID, ), + }) + } + // get existing device keys so we can check for changes + existingKeys := make([]api.DeviceKeys, len(keysToStore)) + for i := range keysToStore { + existingKeys[i] = api.DeviceKeys{ + UserID: keysToStore[i].UserID, + DeviceID: keysToStore[i].DeviceID, } } + if err := a.db.DeviceKeysJSON(ctx, existingKeys); err != nil { + res.Error = &api.KeyError{ + Error: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), + } + return + } // store the device keys and emit changes + if err := a.db.StoreDeviceKeys(ctx, keysToStore); err != nil { + res.Error = &api.KeyError{ + Error: fmt.Sprintf("failed to store device keys: %s", err.Error()), + } + return + } + a.emitDeviceKeyChanges(existingKeys, keysToStore) } func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { @@ -75,27 +93,30 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } existingKeys, err := a.db.ExistingOneTimeKeys(ctx, key.UserID, key.DeviceID, keyIDsWithAlgorithms) if err != nil { - if res.KeyErrors[key.UserID] == nil { - res.KeyErrors[key.UserID] = make(map[string]*api.KeyError) - } - res.KeyErrors[key.UserID][key.DeviceID] = &api.KeyError{ - Error: "failed to query existing keys: " + err.Error(), - } + res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ + Error: "failed to query existing one-time keys: " + err.Error(), + }) continue } for keyIDWithAlgo := range existingKeys { // if keys exist and the JSON doesn't match, error out as the key already exists if bytes.Compare(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) != 0 { - if res.KeyErrors[key.UserID] == nil { - res.KeyErrors[key.UserID] = make(map[string]*api.KeyError) - } - res.KeyErrors[key.UserID][key.DeviceID] = &api.KeyError{ - Error: fmt.Sprintf("%s device %s: algorithm / key ID %s already exists", key.UserID, key.DeviceID, keyIDWithAlgo), - } + res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ + Error: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", key.UserID, key.DeviceID, keyIDWithAlgo), + }) continue } } // store one-time keys + if err := a.db.StoreOneTimeKeys(ctx, key); err != nil { + res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ + Error: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()), + }) + } } } + +func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) { + // TODO +} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 85426e213..89b666d18 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -17,10 +17,22 @@ package storage import ( "context" "encoding/json" + + "github.com/matrix-org/dendrite/keyserver/api" ) type Database interface { // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + + // StoreOneTimeKeys persists the given one-time keys. + StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error + + // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. + // Returns an error if there was a problem storing the keys. + StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error }