Implement claiming one-time keys locally (#1210)

* Add API shape for claiming keys

* Implement claiming one-time keys locally

Fairly boring, nothing too special going on.
This commit is contained in:
Kegsay 2020-07-21 14:47:53 +01:00 committed by GitHub
parent d76eb1b994
commit 1d72ce8b7a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 202 additions and 8 deletions

View file

@ -117,3 +117,40 @@ func QueryKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse {
}, },
} }
} }
type claimKeysRequest struct {
TimeoutMS int `json:"timeout"`
// The keys to be claimed. A map from user ID, to a map from device ID to algorithm name.
OneTimeKeys map[string]map[string]string `json:"one_time_keys"`
}
func (r *claimKeysRequest) GetTimeout() time.Duration {
if r.TimeoutMS == 0 {
return 10 * time.Second
}
return time.Duration(r.TimeoutMS) * time.Millisecond
}
func ClaimKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse {
var r claimKeysRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
claimRes := api.PerformClaimKeysResponse{}
keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: r.OneTimeKeys,
Timeout: r.GetTimeout(),
}, &claimRes)
if claimRes.Error != nil {
util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
"one_time_keys": claimRes.OneTimeKeys,
"failures": claimRes.Failures,
},
}
}

View file

@ -714,4 +714,9 @@ func Setup(
return QueryKeys(req, keyAPI) return QueryKeys(req, keyAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/keys/claim",
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return ClaimKeys(req, keyAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
} }

View file

@ -23,6 +23,7 @@ import (
type KeyInternalAPI interface { type KeyInternalAPI interface {
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
// PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
} }
@ -102,9 +103,17 @@ func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyEr
} }
type PerformClaimKeysRequest struct { type PerformClaimKeysRequest struct {
// Map of user_id to device_id to algorithm name
OneTimeKeys map[string]map[string]string
Timeout time.Duration
} }
type PerformClaimKeysResponse struct { type PerformClaimKeysResponse struct {
// Map of user_id to device_id to algorithm:key_id to key JSON
OneTimeKeys map[string]map[string]map[string]json.RawMessage
// Map of remote server domain to error JSON
Failures map[string]interface{}
// Set if there was a fatal error processing this action
Error *KeyError Error *KeyError
} }

View file

@ -37,9 +37,39 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
a.uploadDeviceKeys(ctx, req, res) a.uploadDeviceKeys(ctx, req, res)
a.uploadOneTimeKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res)
} }
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
// wrap request map in a top-level by-domain map
domainToDeviceKeys := make(map[string]map[string]map[string]string)
for userID, val := range req.OneTimeKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
continue // ignore invalid users
}
nested, ok := domainToDeviceKeys[string(serverName)]
if !ok {
nested = make(map[string]map[string]string)
}
nested[userID] = val
domainToDeviceKeys[string(serverName)] = nested
}
// claim local keys
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
keys, err := a.DB.ClaimKeys(ctx, local)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
}
}
mergeInto(res.OneTimeKeys, keys)
delete(domainToDeviceKeys, string(a.ThisServer))
}
// TODO: claim remote keys
} }
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{}) res.Failures = make(map[string]interface{})
@ -166,3 +196,19 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) { func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
// TODO // TODO
} }
func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) {
for _, key := range src {
_, ok := dst[key.UserID]
if !ok {
dst[key.UserID] = make(map[string]map[string]json.RawMessage)
}
_, ok = dst[key.UserID][key.DeviceID]
if !ok {
dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
}
for keyID, keyJSON := range key.KeyJSON {
dst[key.UserID][key.DeviceID][keyID] = keyJSON
}
}
}

View file

@ -39,4 +39,8 @@ type Database interface {
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
} }

View file

@ -52,11 +52,19 @@ const selectKeysSQL = "" +
const selectKeysCountSQL = "" + const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
const deleteOneTimeKeySQL = "" +
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
const selectKeyByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
type oneTimeKeysStatements struct { type oneTimeKeysStatements struct {
db *sql.DB db *sql.DB
upsertKeysStmt *sql.Stmt upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt selectKeysCountStmt *sql.Stmt
selectKeyByAlgorithmStmt *sql.Stmt
deleteOneTimeKeyStmt *sql.Stmt
} }
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
@ -76,6 +84,12 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err return nil, err
} }
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
return nil, err
}
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
return rows.Err() return rows.Err()
}) })
} }
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -48,3 +49,26 @@ func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) e
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
} }
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
var result []api.OneTimeKeys
err := sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
for userID, deviceToAlgo := range userToDeviceToAlgorithm {
for deviceID, algo := range deviceToAlgo {
keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
if err != nil {
return err
}
if keyJSON != nil {
result = append(result, api.OneTimeKeys{
UserID: userID,
DeviceID: deviceID,
KeyJSON: keyJSON,
})
}
}
}
return nil
})
return result, err
}

View file

@ -52,11 +52,19 @@ const selectKeysSQL = "" +
const selectKeysCountSQL = "" + const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
const deleteOneTimeKeySQL = "" +
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
const selectKeyByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
type oneTimeKeysStatements struct { type oneTimeKeysStatements struct {
db *sql.DB db *sql.DB
upsertKeysStmt *sql.Stmt upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt selectKeysCountStmt *sql.Stmt
selectKeyByAlgorithmStmt *sql.Stmt
deleteOneTimeKeyStmt *sql.Stmt
} }
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
@ -76,6 +84,12 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err return nil, err
} }
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
return nil, err
}
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
return rows.Err() return rows.Err()
}) })
} }
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}

View file

@ -16,6 +16,7 @@ package tables
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
@ -24,6 +25,9 @@ import (
type OneTimeKeys interface { type OneTimeKeys interface {
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
// Returns an empty map if the key does not exist.
SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
} }
type DeviceKeys interface { type DeviceKeys interface {

View file

@ -124,6 +124,7 @@ Should reject keys claiming to belong to a different user
Can query device keys using POST Can query device keys using POST
Can query specific device keys using POST Can query specific device keys using POST
query for user with no keys returns empty key dict query for user with no keys returns empty key dict
Can claim one time key using POST
Can add account data Can add account data
Can add account data to room Can add account data to room
Can get account data without syncing Can get account data without syncing