From d7664a1c96578f21b75074187b8d673b16378138 Mon Sep 17 00:00:00 2001 From: SUMUKHA-PK Date: Thu, 24 Oct 2019 18:24:28 +0530 Subject: [PATCH] Changes API progress --- encryptoapi/routing/keys.go | 130 +++++++++++----------- encryptoapi/storage/encrypt_keys_table.go | 109 ++++++++++++++++-- encryptoapi/storage/storage.go | 26 ++++- encryptoapi/types/storage.go | 8 ++ syncapi/routing/routing.go | 3 +- syncapi/routing/std.go | 12 +- 6 files changed, 209 insertions(+), 79 deletions(-) diff --git a/encryptoapi/routing/keys.go b/encryptoapi/routing/keys.go index 9995ffe5d..d912172d8 100644 --- a/encryptoapi/routing/keys.go +++ b/encryptoapi/routing/keys.go @@ -60,7 +60,8 @@ type KeyNotifier struct { var keyProducer = &KeyNotifier{} -// UploadPKeys this function is for user upload his device key, and one-time-key to a limit at 50 set as default +// UploadPKeys enables the user to upload his device +// and one time keys with limit at 50 set as default func UploadPKeys( req *http.Request, encryptionDB *storage.Database, @@ -78,7 +79,7 @@ func UploadPKeys( &keySpecific, userID, deviceID) // numMap is algorithm-num map - numMap, ok := (QueryOneTimeKeys( + numMap, ok := (queryOneTimeKeys( req.Context(), TYPESUM, userID, @@ -106,7 +107,7 @@ func UploadPKeys( } } -// QueryPKeys this function is for user query other's device key +// QueryPKeys enables the user to query for other devices's keys func QueryPKeys( req *http.Request, encryptionDB *storage.Database, @@ -122,6 +123,7 @@ func QueryPKeys( return *reqErr } + // FED must return the keys from the other user /* federation consideration: when user id is in federation, a query is needed to ask fed for keys. @@ -147,6 +149,11 @@ func QueryPKeys( case <-make(chan interface{}): // todo : here goes federation chan , still a mocked one } + // probably some other better error to tell it timed out in FED + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: struct{}{}, + } } // query one's device key from user corresponding to uid @@ -193,7 +200,7 @@ func QueryPKeys( } } -// ClaimOneTimeKeys claim for one time key that may be used in session exchange in olm encryption +// ClaimOneTimeKeys enables user to claim one time keys for sessions. func ClaimOneTimeKeys( req *http.Request, encryptionDB *storage.Database, @@ -206,6 +213,7 @@ func ClaimOneTimeKeys( return *reqErr } + // not sure what FED should return here /* federation consideration: when user id is in federation, a query is needed to ask fed for keys domain --------+ fed (keys) @@ -227,20 +235,25 @@ func ClaimOneTimeKeys( case <-make(chan interface{}): // todo : here goes federation chan , still a mocked one } + // probably some other better error to tell it timed out in FED + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: struct{}{}, + } } content := claimRq.ClaimDetail for uid, detail := range content { - for deviceID, al := range detail { - var alTyp int - if strings.Contains(al, "signed") { - alTyp = ONETIMEKEYOBJECT + for deviceID, alg := range detail { + var algTyp int + if strings.Contains(alg, "signed") { + algTyp = ONETIMEKEYOBJECT } else { - alTyp = ONETIMEKEYSTRING + algTyp = ONETIMEKEYSTRING } - key, err := pickOne(req.Context(), *encryptionDB, uid, deviceID, al) + key, err := pickOne(req.Context(), *encryptionDB, uid, deviceID, alg) if err != nil { - claimRp.Failures[uid] = fmt.Sprintf("%s:%s", "fail to get keys for device ", deviceID) + claimRp.Failures[uid] = fmt.Sprintf("%s: %s", "failed to get keys for device", deviceID) } claimRp.ClaimBody[uid] = make(map[string]map[string]interface{}) keyPreMap := claimRp.ClaimBody[uid] @@ -248,14 +261,14 @@ func ClaimOneTimeKeys( if keymap == nil { keymap = make(map[string]interface{}) } - switch alTyp { + switch algTyp { case ONETIMEKEYSTRING: - keymap[fmt.Sprintf("%s:%s", al, key.KeyID)] = key.Key + keymap[fmt.Sprintf("%s:%s", alg, key.KeyID)] = key.Key case ONETIMEKEYOBJECT: sig := make(map[string]map[string]string) sig[uid] = make(map[string]string) sig[uid][fmt.Sprintf("%s:%s", "ed25519", deviceID)] = key.Signature - keymap[fmt.Sprintf("%s:%s", al, key.KeyID)] = types.KeyObject{Key: key.Key, Signature: sig} + keymap[fmt.Sprintf("%s:%s", alg, key.KeyID)] = types.KeyObject{Key: key.Key, Signature: sig} } claimRp.ClaimBody[uid][deviceID] = keymap } @@ -267,13 +280,34 @@ func ClaimOneTimeKeys( } // ChangesInKeys returns the changes in the keys after last sync -func ChangesInKeys(req *http.Request, +// each user maintains a chain of the changes when provided by FED +func ChangesInKeys( + req *http.Request, encryptionDB *storage.Database, ) util.JSONResponse { + // assuming federation has added keys to the DB, + // extracting from the DB here + + // get from FED/Req + var readID int + var userID string + keyChanges, err := encryptionDB.GetKeyChanges(req.Context(), readID, userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: struct{}{}, + } + } + + changesRes := types.ChangesResponse{} + changesRes.Changed = keyChanges.Changed + changesRes.Left = keyChanges.Left + + // delete the extracted keys from the DB return util.JSONResponse{ Code: http.StatusOK, - JSON: struct{}{}, + JSON: changesRes, } } @@ -294,8 +328,8 @@ func checkUpload(req *types.UploadEncryptSpecific, typ int) bool { return true } -// QueryOneTimeKeys todo: complete this field through claim type -func QueryOneTimeKeys( +// queryOneTimeKeys todo: complete this field through claim type +func queryOneTimeKeys( ctx context.Context, typ int, userID, deviceID string, @@ -320,7 +354,8 @@ func persistKeys( userID, deviceID string, ) (err error) { - // in order to persist keys , a check filtering duplicate should be processed + // in order to persist keys , a check, + // filtering the duplicates should be processed. // true stands for counterparts are in request // situation 1: only device keys // situation 2: both device keys and one time keys @@ -333,11 +368,11 @@ func persistKeys( return } if checkUpload(body, BODYONETIMEKEY) { - if err = bothKeyProcess(ctx, body, userID, deviceID, database, deviceKeys); err != nil { + if err = persistBothKeys(ctx, body, userID, deviceID, database, deviceKeys); err != nil { return } } else { - if err = dkeyProcess(ctx, userID, deviceID, database, deviceKeys); err != nil { + if err = persistDeviceKeys(ctx, userID, deviceID, database, deviceKeys); err != nil { return } } @@ -345,11 +380,11 @@ func persistKeys( upnotify(userID) } else { if checkUpload(body, BODYONETIMEKEY) { - if err = otmKeyProcess(ctx, body, userID, deviceID, database); err != nil { + if err = persistOneTimeKeys(ctx, body, userID, deviceID, database); err != nil { return } } else { - return errors.New("failed to touch keys") + return errors.New("failed to persist keys") } } return err @@ -454,56 +489,27 @@ func presetDeviceKeysQueryMap( return deviceKeysQueryMap } -func bothKeyProcess( +func persistBothKeys( ctx context.Context, body *types.UploadEncryptSpecific, userID, deviceID string, database *storage.Database, deviceKeys types.DeviceKeys, ) (err error) { - // insert one time keys firstly - onetimeKeys := body.OneTimeKey - for alKeyID, val := range onetimeKeys.KeyString { - al := (strings.Split(alKeyID, ":"))[0] - keyID := (strings.Split(alKeyID, ":"))[1] - keyInfo := val - keyStringTyp := ONETIMEKEYSTR - sig := "" - err = database.InsertKey(ctx, deviceID, userID, keyID, keyStringTyp, keyInfo, al, sig) - if err != nil { - return - } - } - for alKeyID, val := range onetimeKeys.KeyObject { - al := (strings.Split(alKeyID, ":"))[0] - keyID := (strings.Split(alKeyID, ":"))[1] - keyInfo := val.Key - keyObjectTyp := ONETIMEKEYSTR - sig := val.Signature[userID][fmt.Sprintf("%s:%s", "ed25519", deviceID)] - err = database.InsertKey(ctx, deviceID, userID, keyID, keyObjectTyp, keyInfo, al, sig) - if err != nil { - return - } + // insert one time keys + err = persistOneTimeKeys(ctx, body, userID, deviceID, database) + if err != nil { + return } // insert device keys - keys := deviceKeys.Keys - sigs := deviceKeys.Signature - for alDevice, key := range keys { - al := (strings.Split(alDevice, ":"))[0] - keyTyp := DEVICEKEYSTR - keyInfo := key - keyID := "" - sig := sigs[userID][fmt.Sprintf("%s:%s", "ed25519", deviceID)] - err = database.InsertKey( - ctx, deviceID, userID, keyID, keyTyp, keyInfo, al, sig) - if err != nil { - return - } + err = persistDeviceKeys(ctx, userID, deviceID, database, deviceKeys) + if err != nil { + return } return } -func dkeyProcess( +func persistDeviceKeys( ctx context.Context, userID, deviceID string, database *storage.Database, @@ -522,7 +528,7 @@ func dkeyProcess( return } -func otmKeyProcess( +func persistOneTimeKeys( ctx context.Context, body *types.UploadEncryptSpecific, userID, deviceID string, diff --git a/encryptoapi/storage/encrypt_keys_table.go b/encryptoapi/storage/encrypt_keys_table.go index 97cb6196a..66bf24066 100644 --- a/encryptoapi/storage/encrypt_keys_table.go +++ b/encryptoapi/storage/encrypt_keys_table.go @@ -1,4 +1,4 @@ -// Copyright 2018 Vector Creations Ltd +// Copyright FadeAce and Sumukha PK 2019 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package storage import ( "context" "database/sql" + "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/encryptoapi/types" @@ -35,6 +36,15 @@ CREATE TABLE IF NOT EXISTS encrypt_keys ( signature TEXT NOT NULL ); ` +const changesTable = ` +CREATE TABLE IF NOT EXISTS key_changes ( + read_id INT NOT NULL, + user_id TEXT PRIMARY KEY NOT NULL, + neighbor_user_id TEXT NOT NULL, + status TEXT NOT NULL +); +` + const insertkeySQL = ` INSERT INTO encrypt_keys (device_id, user_id, key_id, key_type, key_info, algorithm, signature) VALUES ($1, $2, $3, $4, $5, $6, $7) @@ -43,11 +53,11 @@ const selectkeySQL = ` SELECT user_id, device_id, key_id, key_type, key_info, algorithm, signature FROM encrypt_keys WHERE user_id = $1 AND device_id = $2 ` -const deleteSinglekeySQL = ` +const selectSinglekeySQL = ` SELECT user_id, device_id, key_id, key_type, key_info, algorithm, signature FROM encrypt_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ` -const selectSinglekeySQL = ` +const deleteSinglekeySQL = ` DELETE FROM encrypt_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4 ` @@ -63,6 +73,12 @@ const selectCountOneTimeKey = ` SELECT algorithm, COUNT(algorithm) FROM encrypt_keys WHERE user_id = $1 AND device_id = $2 AND key_type = 'one_time_key' GROUP BY algorithm ` +const insertChangesSQL = ` +INSERT INTO changesTable (read_id, user_id) VALUES ($1, $2) +` +const selectChangesSQL = ` +SELECT read_id, user_id +` type keyStatements struct { insertKeyStmt *sql.Stmt @@ -72,6 +88,8 @@ type keyStatements struct { selectSingleKeyStmt *sql.Stmt deleteSingleKeyStmt *sql.Stmt selectCountOneTimeKeyStmt *sql.Stmt + insertChangesStmt *sql.Stmt + selectChangesStmt *sql.Stmt } func (s *keyStatements) prepare(db *sql.DB) (err error) { @@ -79,6 +97,10 @@ func (s *keyStatements) prepare(db *sql.DB) (err error) { if err != nil { return } + _, err = db.Exec(changesTable) + if err != nil { + return + } if s.insertKeyStmt, err = db.Prepare(insertkeySQL); err != nil { return } @@ -91,19 +113,25 @@ func (s *keyStatements) prepare(db *sql.DB) (err error) { if s.selectAllKeyStmt, err = db.Prepare(selectAllkeysSQL); err != nil { return } - if s.deleteSingleKeyStmt, err = db.Prepare(selectSinglekeySQL); err != nil { + if s.deleteSingleKeyStmt, err = db.Prepare(deleteSinglekeySQL); err != nil { return } - if s.selectSingleKeyStmt, err = db.Prepare(deleteSinglekeySQL); err != nil { + if s.selectSingleKeyStmt, err = db.Prepare(selectSinglekeySQL); err != nil { return } if s.selectCountOneTimeKeyStmt, err = db.Prepare(selectCountOneTimeKey); err != nil { return } + if s.insertChangesStmt, err = db.Prepare(insertChangesSQL); err != nil { + return + } + if s.selectChangesStmt, err = db.Prepare(selectChangesSQL); err != nil { + return + } return } -// insert keys +// insertKeys inserts keys func (s *keyStatements) insertKey( ctx context.Context, txn *sql.Tx, deviceID, userID, keyID, keyTyp, keyInfo, algorithm, signature string, @@ -113,18 +141,18 @@ func (s *keyStatements) insertKey( return err } -// select by user and device +// selectKey selects by user and device func (s *keyStatements) selectKey( ctx context.Context, txn *sql.Tx, deviceID, userID string, ) ([]types.KeyHolder, error) { - holders := []types.KeyHolder{} stmt := common.TxStmt(txn, s.selectKeyStmt) rows, err := stmt.QueryContext(ctx, userID, deviceID) if err != nil { return nil, err } + holders := []types.KeyHolder{} for rows.Next() { single := &types.KeyHolder{} if err = rows.Scan( @@ -141,10 +169,13 @@ func (s *keyStatements) selectKey( holders = append(holders, *single) } err = rows.Close() + if err != nil { + return nil, err + } return holders, err } -// select single one for claim usage +// selectSingleKey selects single key for claim usage func (s *keyStatements) selectSingleKey( ctx context.Context, userID, deviceID, algorithm string, @@ -170,7 +201,7 @@ func (s *keyStatements) selectSingleKey( return holder, err } -// select details by given an array of devices +// selectInKeys selects details based on an array of devices func (s *keyStatements) selectInKeys( ctx context.Context, userID string, @@ -225,7 +256,7 @@ func injectKeyHolder(rows *sql.Rows, keyHolder []types.KeyHolder) (holders []typ return } -// select by user and device +// selectOneTimeKeyCount selects by user and device func (s *keyStatements) selectOneTimeKeyCount( ctx context.Context, userID, deviceID string, @@ -249,3 +280,59 @@ func (s *keyStatements) selectOneTimeKeyCount( err = rows.Close() return holders, err } + +// insertChanges inserts into the changes table +func (s *keyStatements) insertChanges( + ctx context.Context, txn *sql.Tx, + readID int, userID string, status string, +) error { + stmt := common.TxStmt(txn, s.insertChangesStmt) + _, err := stmt.ExecContext(ctx, readID, userID, status) + return err +} + +// selectChanges returns data from the DB +func (s *keyStatements) selectChanges( + ctx context.Context, txn *sql.Tx, + readID int, userID string, +) (types.KeyChanges, error) { + stmt := common.TxStmt(txn, s.selectChangesStmt) + rows, err := stmt.QueryContext(ctx, readID, userID) + if err != nil { + return types.KeyChanges{}, err + } + keyChanges := newKeyChanges() + for rows.Next() { + var rID, uID, nUID, status string + if err = rows.Scan( + &rID, + &uID, + &nUID, + &status, + ); err != nil { + return types.KeyChanges{}, err + } + if keyChanges.UserID == "" { + keyChanges.UserID = uID + } + if status == "changed" { + keyChanges.Changed = append(keyChanges.Changed, nUID) + } else if status == "left" { + keyChanges.Left = append(keyChanges.Left, nUID) + } + } + err = rows.Close() + if err != nil { + return types.KeyChanges{}, err + } + return keyChanges, nil +} + +func newKeyChanges() types.KeyChanges { + return types.KeyChanges{ + UserID: "", + NeighborUserID: "", + Changed: []string{}, + Left: []string{}, + } +} diff --git a/encryptoapi/storage/storage.go b/encryptoapi/storage/storage.go index 372221630..eeb90965e 100644 --- a/encryptoapi/storage/storage.go +++ b/encryptoapi/storage/storage.go @@ -17,9 +17,10 @@ package storage import ( "context" "database/sql" + "strings" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/encryptoapi/types" - "strings" ) // Database represents a presence database. @@ -129,3 +130,26 @@ func (d *Database) SyncOneTimeCount( holder, err = d.keyStatements.selectOneTimeKeyCount(ctx, userID, deviceID) return } + +// InsertChanges inserts the changes to the DB +func (d *Database) InsertChanges( + ctx context.Context, + readID int, userID, status string, +) error { + err := common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.keyStatements.insertChanges(ctx, txn, readID, userID, status) + }) + return err +} + +// GetKeyChanges gets the changed keys from the DB +func (d *Database) GetKeyChanges( + ctx context.Context, + readID int, userID string, +) (keyChanges types.KeyChanges, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) (err error) { + keyChanges, err = d.keyStatements.selectChanges(ctx, txn, readID, userID) + return + }) + return +} diff --git a/encryptoapi/types/storage.go b/encryptoapi/types/storage.go index c3d098398..472639e8d 100644 --- a/encryptoapi/types/storage.go +++ b/encryptoapi/types/storage.go @@ -31,3 +31,11 @@ type AlHolder struct { DeviceID, SupportedAlgorithm string } + +// KeyChanges holds the changed keys data +type KeyChanges struct { + UserID string + NeighborUserID string + Changed []string + Left []string +} diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 0c7376da2..98cba7efb 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -80,7 +80,8 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServer // vars := mux.Vars(req) // eventType := vars["eventType"] // txnID := vars["txnId"] - // return SendToDevice(req, device.UserID, syncDB, deviceDB, eventType, txnID, notifier) + // roomID := vars["roomId"] + // return SendToDevice(req, device.UserID, roomID, syncDB, deviceDB, eventType, txnID, notifier) // }), // ).Methods(http.MethodPut, http.MethodOptions) } diff --git a/syncapi/routing/std.go b/syncapi/routing/std.go index 3e78e6bad..ea5308265 100644 --- a/syncapi/routing/std.go +++ b/syncapi/routing/std.go @@ -2,6 +2,8 @@ package routing // import ( // "encoding/json" +// "net/http" + // "github.com/matrix-org/dendrite/clientapi/auth/authtypes" // "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" // "github.com/matrix-org/dendrite/clientapi/httputil" @@ -10,13 +12,13 @@ package routing // "github.com/matrix-org/dendrite/syncapi/types" // "github.com/matrix-org/gomatrixserverlib" // "github.com/matrix-org/util" -// "net/http" // ) // // SendToDevice this is a function for calling process of send-to-device messages those bypassed DAG // func SendToDevice( // req *http.Request, // sender string, +// roomID string, // syncDB *storage.SyncServerDatasource, // deviceDB *devices.Database, // eventType, txnID string, @@ -48,7 +50,7 @@ package routing // Event: jsonBuffer, // EventTyp: eventType, // } -// var pos int64 +// // var pos int64 // // wildcard all devices // if device == "*" { @@ -58,7 +60,8 @@ package routing // deviceCollection, err = deviceDB.GetDevicesByLocalpart(ctx, localpart) // for _, val := range deviceCollection { // pos, err = syncDB.InsertStdMessage(ctx, ev, txnID, uid, val.ID) -// notifier.OnNewEvent(nil, uid, types.StreamPosition(pos)) +// // NEEDS MAJOR CHANGES +// // notifier.OnNewEvent(nil, roomID, uid, types.StreamPosition(pos)) // } // if err != nil { // return util.JSONResponse{ @@ -78,7 +81,8 @@ package routing // JSON: struct{}{}, // } // } -// notifier.OnNewEvent(nil, uid, types.StreamPosition(pos)) +// // NEEDS MAJOR CHANGES +// // notifier.OnNewEvent(nil, roomID, uid, types.StreamPosition(pos)) // } // }