diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 5f7bfb187..ba03a352f 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -117,3 +117,40 @@ func QueryKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse { }, } } + +type claimKeysRequest struct { + TimeoutMS int `json:"timeout"` + // The keys to be claimed. A map from user ID, to a map from device ID to algorithm name. + OneTimeKeys map[string]map[string]string `json:"one_time_keys"` +} + +func (r *claimKeysRequest) GetTimeout() time.Duration { + if r.TimeoutMS == 0 { + return 10 * time.Second + } + return time.Duration(r.TimeoutMS) * time.Millisecond +} + +func ClaimKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse { + var r claimKeysRequest + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + claimRes := api.PerformClaimKeysResponse{} + keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ + OneTimeKeys: r.OneTimeKeys, + Timeout: r.GetTimeout(), + }, &claimRes) + if claimRes.Error != nil { + util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "one_time_keys": claimRes.OneTimeKeys, + "failures": claimRes.Failures, + }, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 492b7e253..c9ed5ea5c 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -714,4 +714,9 @@ func Setup( return QueryKeys(req, keyAPI) }), ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/keys/claim", + httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return ClaimKeys(req, keyAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) } diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 0f6cb7979..d42fb60cf 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -23,6 +23,7 @@ import ( type KeyInternalAPI interface { PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) + // PerformClaimKeys claims one-time keys for use in pre-key messages PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) } @@ -102,9 +103,17 @@ func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyEr } type PerformClaimKeysRequest struct { + // Map of user_id to device_id to algorithm name + OneTimeKeys map[string]map[string]string + Timeout time.Duration } type PerformClaimKeysResponse struct { + // Map of user_id to device_id to algorithm:key_id to key JSON + OneTimeKeys map[string]map[string]map[string]json.RawMessage + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Set if there was a fatal error processing this action Error *KeyError } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 5be87aa41..041732dc4 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -37,9 +37,39 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform a.uploadDeviceKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res) } + func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { + res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) + res.Failures = make(map[string]interface{}) + // wrap request map in a top-level by-domain map + domainToDeviceKeys := make(map[string]map[string]map[string]string) + for userID, val := range req.OneTimeKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + continue // ignore invalid users + } + nested, ok := domainToDeviceKeys[string(serverName)] + if !ok { + nested = make(map[string]map[string]string) + } + nested[userID] = val + domainToDeviceKeys[string(serverName)] = nested + } + // claim local keys + if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok { + keys, err := a.DB.ClaimKeys(ctx, local) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), + } + } + mergeInto(res.OneTimeKeys, keys) + delete(domainToDeviceKeys, string(a.ThisServer)) + } + // TODO: claim remote keys } + func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) @@ -166,3 +196,19 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) { // TODO } + +func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) { + for _, key := range src { + _, ok := dst[key.UserID] + if !ok { + dst[key.UserID] = make(map[string]map[string]json.RawMessage) + } + _, ok = dst[key.UserID][key.DeviceID] + if !ok { + dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) + } + for keyID, keyJSON := range key.KeyJSON { + dst[key.UserID][key.DeviceID][keyID] = keyJSON + } + } +} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index a626c66a6..7a0328bd7 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -39,4 +39,8 @@ type Database interface { // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) + + // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key + // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. + ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) } diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index b8aee72bd..a9d05548b 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -52,11 +52,19 @@ const selectKeysSQL = "" + const selectKeysCountSQL = "" + "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt - selectKeysCountStmt *sql.Stmt + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt } func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { @@ -76,6 +84,12 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { return nil, err } + if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { + return nil, err + } + if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { + return nil, err + } return s, nil } @@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api. return rows.Err() }) } + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index d5ac6458a..156b5b415 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -48,3 +49,26 @@ func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) e func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) } + +func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { + var result []api.OneTimeKeys + err := sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for userID, deviceToAlgo := range userToDeviceToAlgorithm { + for deviceID, algo := range deviceToAlgo { + keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) + if err != nil { + return err + } + if keyJSON != nil { + result = append(result, api.OneTimeKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: keyJSON, + }) + } + } + } + return nil + }) + return result, err +} diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index 86e91268e..fecf533e6 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -52,11 +52,19 @@ const selectKeysSQL = "" + const selectKeysCountSQL = "" + "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt - selectKeysCountStmt *sql.Stmt + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt } func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { @@ -76,6 +84,12 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { return nil, err } + if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { + return nil, err + } + if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { + return nil, err + } return s, nil } @@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api. return rows.Err() }) } + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 1f7f686b9..216be773b 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -16,6 +16,7 @@ package tables import ( "context" + "database/sql" "encoding/json" "github.com/matrix-org/dendrite/keyserver/api" @@ -24,6 +25,9 @@ import ( type OneTimeKeys interface { SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. + // Returns an empty map if the key does not exist. + SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) } type DeviceKeys interface { diff --git a/sytest-whitelist b/sytest-whitelist index a3df4e0cb..f21432fbd 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -124,6 +124,7 @@ Should reject keys claiming to belong to a different user Can query device keys using POST Can query specific device keys using POST query for user with no keys returns empty key dict +Can claim one time key using POST Can add account data Can add account data to room Can get account data without syncing