Changes API progress

This commit is contained in:
SUMUKHA-PK 2019-10-24 18:24:28 +05:30
parent 94794485e0
commit d7664a1c96
6 changed files with 209 additions and 79 deletions

View file

@ -60,7 +60,8 @@ type KeyNotifier struct {
var keyProducer = &KeyNotifier{}
// UploadPKeys this function is for user upload his device key, and one-time-key to a limit at 50 set as default
// UploadPKeys enables the user to upload his device
// and one time keys with limit at 50 set as default
func UploadPKeys(
req *http.Request,
encryptionDB *storage.Database,
@ -78,7 +79,7 @@ func UploadPKeys(
&keySpecific,
userID, deviceID)
// numMap is algorithm-num map
numMap, ok := (QueryOneTimeKeys(
numMap, ok := (queryOneTimeKeys(
req.Context(),
TYPESUM,
userID,
@ -106,7 +107,7 @@ func UploadPKeys(
}
}
// QueryPKeys this function is for user query other's device key
// QueryPKeys enables the user to query for other devices's keys
func QueryPKeys(
req *http.Request,
encryptionDB *storage.Database,
@ -122,6 +123,7 @@ func QueryPKeys(
return *reqErr
}
// FED must return the keys from the other user
/*
federation consideration: when user id is in federation, a
query is needed to ask fed for keys.
@ -147,6 +149,11 @@ func QueryPKeys(
case <-make(chan interface{}):
// todo : here goes federation chan , still a mocked one
}
// probably some other better error to tell it timed out in FED
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: struct{}{},
}
}
// query one's device key from user corresponding to uid
@ -193,7 +200,7 @@ func QueryPKeys(
}
}
// ClaimOneTimeKeys claim for one time key that may be used in session exchange in olm encryption
// ClaimOneTimeKeys enables user to claim one time keys for sessions.
func ClaimOneTimeKeys(
req *http.Request,
encryptionDB *storage.Database,
@ -206,6 +213,7 @@ func ClaimOneTimeKeys(
return *reqErr
}
// not sure what FED should return here
/*
federation consideration: when user id is in federation, a query is needed to ask fed for keys
domain --------+ fed (keys)
@ -227,20 +235,25 @@ func ClaimOneTimeKeys(
case <-make(chan interface{}):
// todo : here goes federation chan , still a mocked one
}
// probably some other better error to tell it timed out in FED
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: struct{}{},
}
}
content := claimRq.ClaimDetail
for uid, detail := range content {
for deviceID, al := range detail {
var alTyp int
if strings.Contains(al, "signed") {
alTyp = ONETIMEKEYOBJECT
for deviceID, alg := range detail {
var algTyp int
if strings.Contains(alg, "signed") {
algTyp = ONETIMEKEYOBJECT
} else {
alTyp = ONETIMEKEYSTRING
algTyp = ONETIMEKEYSTRING
}
key, err := pickOne(req.Context(), *encryptionDB, uid, deviceID, al)
key, err := pickOne(req.Context(), *encryptionDB, uid, deviceID, alg)
if err != nil {
claimRp.Failures[uid] = fmt.Sprintf("%s:%s", "fail to get keys for device ", deviceID)
claimRp.Failures[uid] = fmt.Sprintf("%s: %s", "failed to get keys for device", deviceID)
}
claimRp.ClaimBody[uid] = make(map[string]map[string]interface{})
keyPreMap := claimRp.ClaimBody[uid]
@ -248,14 +261,14 @@ func ClaimOneTimeKeys(
if keymap == nil {
keymap = make(map[string]interface{})
}
switch alTyp {
switch algTyp {
case ONETIMEKEYSTRING:
keymap[fmt.Sprintf("%s:%s", al, key.KeyID)] = key.Key
keymap[fmt.Sprintf("%s:%s", alg, key.KeyID)] = key.Key
case ONETIMEKEYOBJECT:
sig := make(map[string]map[string]string)
sig[uid] = make(map[string]string)
sig[uid][fmt.Sprintf("%s:%s", "ed25519", deviceID)] = key.Signature
keymap[fmt.Sprintf("%s:%s", al, key.KeyID)] = types.KeyObject{Key: key.Key, Signature: sig}
keymap[fmt.Sprintf("%s:%s", alg, key.KeyID)] = types.KeyObject{Key: key.Key, Signature: sig}
}
claimRp.ClaimBody[uid][deviceID] = keymap
}
@ -267,13 +280,34 @@ func ClaimOneTimeKeys(
}
// ChangesInKeys returns the changes in the keys after last sync
func ChangesInKeys(req *http.Request,
// each user maintains a chain of the changes when provided by FED
func ChangesInKeys(
req *http.Request,
encryptionDB *storage.Database,
) util.JSONResponse {
// assuming federation has added keys to the DB,
// extracting from the DB here
// get from FED/Req
var readID int
var userID string
keyChanges, err := encryptionDB.GetKeyChanges(req.Context(), readID, userID)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: struct{}{},
}
}
changesRes := types.ChangesResponse{}
changesRes.Changed = keyChanges.Changed
changesRes.Left = keyChanges.Left
// delete the extracted keys from the DB
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
JSON: changesRes,
}
}
@ -294,8 +328,8 @@ func checkUpload(req *types.UploadEncryptSpecific, typ int) bool {
return true
}
// QueryOneTimeKeys todo: complete this field through claim type
func QueryOneTimeKeys(
// queryOneTimeKeys todo: complete this field through claim type
func queryOneTimeKeys(
ctx context.Context,
typ int,
userID, deviceID string,
@ -320,7 +354,8 @@ func persistKeys(
userID,
deviceID string,
) (err error) {
// in order to persist keys , a check filtering duplicate should be processed
// in order to persist keys , a check,
// filtering the duplicates should be processed.
// true stands for counterparts are in request
// situation 1: only device keys
// situation 2: both device keys and one time keys
@ -333,11 +368,11 @@ func persistKeys(
return
}
if checkUpload(body, BODYONETIMEKEY) {
if err = bothKeyProcess(ctx, body, userID, deviceID, database, deviceKeys); err != nil {
if err = persistBothKeys(ctx, body, userID, deviceID, database, deviceKeys); err != nil {
return
}
} else {
if err = dkeyProcess(ctx, userID, deviceID, database, deviceKeys); err != nil {
if err = persistDeviceKeys(ctx, userID, deviceID, database, deviceKeys); err != nil {
return
}
}
@ -345,11 +380,11 @@ func persistKeys(
upnotify(userID)
} else {
if checkUpload(body, BODYONETIMEKEY) {
if err = otmKeyProcess(ctx, body, userID, deviceID, database); err != nil {
if err = persistOneTimeKeys(ctx, body, userID, deviceID, database); err != nil {
return
}
} else {
return errors.New("failed to touch keys")
return errors.New("failed to persist keys")
}
}
return err
@ -454,56 +489,27 @@ func presetDeviceKeysQueryMap(
return deviceKeysQueryMap
}
func bothKeyProcess(
func persistBothKeys(
ctx context.Context,
body *types.UploadEncryptSpecific,
userID, deviceID string,
database *storage.Database,
deviceKeys types.DeviceKeys,
) (err error) {
// insert one time keys firstly
onetimeKeys := body.OneTimeKey
for alKeyID, val := range onetimeKeys.KeyString {
al := (strings.Split(alKeyID, ":"))[0]
keyID := (strings.Split(alKeyID, ":"))[1]
keyInfo := val
keyStringTyp := ONETIMEKEYSTR
sig := ""
err = database.InsertKey(ctx, deviceID, userID, keyID, keyStringTyp, keyInfo, al, sig)
// insert one time keys
err = persistOneTimeKeys(ctx, body, userID, deviceID, database)
if err != nil {
return
}
}
for alKeyID, val := range onetimeKeys.KeyObject {
al := (strings.Split(alKeyID, ":"))[0]
keyID := (strings.Split(alKeyID, ":"))[1]
keyInfo := val.Key
keyObjectTyp := ONETIMEKEYSTR
sig := val.Signature[userID][fmt.Sprintf("%s:%s", "ed25519", deviceID)]
err = database.InsertKey(ctx, deviceID, userID, keyID, keyObjectTyp, keyInfo, al, sig)
if err != nil {
return
}
}
// insert device keys
keys := deviceKeys.Keys
sigs := deviceKeys.Signature
for alDevice, key := range keys {
al := (strings.Split(alDevice, ":"))[0]
keyTyp := DEVICEKEYSTR
keyInfo := key
keyID := ""
sig := sigs[userID][fmt.Sprintf("%s:%s", "ed25519", deviceID)]
err = database.InsertKey(
ctx, deviceID, userID, keyID, keyTyp, keyInfo, al, sig)
err = persistDeviceKeys(ctx, userID, deviceID, database, deviceKeys)
if err != nil {
return
}
}
return
}
func dkeyProcess(
func persistDeviceKeys(
ctx context.Context,
userID, deviceID string,
database *storage.Database,
@ -522,7 +528,7 @@ func dkeyProcess(
return
}
func otmKeyProcess(
func persistOneTimeKeys(
ctx context.Context,
body *types.UploadEncryptSpecific,
userID, deviceID string,

View file

@ -1,4 +1,4 @@
// Copyright 2018 Vector Creations Ltd
// Copyright FadeAce and Sumukha PK 2019
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@ package storage
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/encryptoapi/types"
@ -35,6 +36,15 @@ CREATE TABLE IF NOT EXISTS encrypt_keys (
signature TEXT NOT NULL
);
`
const changesTable = `
CREATE TABLE IF NOT EXISTS key_changes (
read_id INT NOT NULL,
user_id TEXT PRIMARY KEY NOT NULL,
neighbor_user_id TEXT NOT NULL,
status TEXT NOT NULL
);
`
const insertkeySQL = `
INSERT INTO encrypt_keys (device_id, user_id, key_id, key_type, key_info, algorithm, signature)
VALUES ($1, $2, $3, $4, $5, $6, $7)
@ -43,11 +53,11 @@ const selectkeySQL = `
SELECT user_id, device_id, key_id, key_type, key_info, algorithm, signature FROM encrypt_keys
WHERE user_id = $1 AND device_id = $2
`
const deleteSinglekeySQL = `
const selectSinglekeySQL = `
SELECT user_id, device_id, key_id, key_type, key_info, algorithm, signature FROM encrypt_keys
WHERE user_id = $1 AND device_id = $2 AND algorithm = $3
`
const selectSinglekeySQL = `
const deleteSinglekeySQL = `
DELETE FROM encrypt_keys
WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4
`
@ -63,6 +73,12 @@ const selectCountOneTimeKey = `
SELECT algorithm, COUNT(algorithm) FROM encrypt_keys WHERE user_id = $1 AND device_id = $2 AND key_type = 'one_time_key'
GROUP BY algorithm
`
const insertChangesSQL = `
INSERT INTO changesTable (read_id, user_id) VALUES ($1, $2)
`
const selectChangesSQL = `
SELECT read_id, user_id
`
type keyStatements struct {
insertKeyStmt *sql.Stmt
@ -72,6 +88,8 @@ type keyStatements struct {
selectSingleKeyStmt *sql.Stmt
deleteSingleKeyStmt *sql.Stmt
selectCountOneTimeKeyStmt *sql.Stmt
insertChangesStmt *sql.Stmt
selectChangesStmt *sql.Stmt
}
func (s *keyStatements) prepare(db *sql.DB) (err error) {
@ -79,6 +97,10 @@ func (s *keyStatements) prepare(db *sql.DB) (err error) {
if err != nil {
return
}
_, err = db.Exec(changesTable)
if err != nil {
return
}
if s.insertKeyStmt, err = db.Prepare(insertkeySQL); err != nil {
return
}
@ -91,19 +113,25 @@ func (s *keyStatements) prepare(db *sql.DB) (err error) {
if s.selectAllKeyStmt, err = db.Prepare(selectAllkeysSQL); err != nil {
return
}
if s.deleteSingleKeyStmt, err = db.Prepare(selectSinglekeySQL); err != nil {
if s.deleteSingleKeyStmt, err = db.Prepare(deleteSinglekeySQL); err != nil {
return
}
if s.selectSingleKeyStmt, err = db.Prepare(deleteSinglekeySQL); err != nil {
if s.selectSingleKeyStmt, err = db.Prepare(selectSinglekeySQL); err != nil {
return
}
if s.selectCountOneTimeKeyStmt, err = db.Prepare(selectCountOneTimeKey); err != nil {
return
}
if s.insertChangesStmt, err = db.Prepare(insertChangesSQL); err != nil {
return
}
if s.selectChangesStmt, err = db.Prepare(selectChangesSQL); err != nil {
return
}
return
}
// insert keys
// insertKeys inserts keys
func (s *keyStatements) insertKey(
ctx context.Context, txn *sql.Tx,
deviceID, userID, keyID, keyTyp, keyInfo, algorithm, signature string,
@ -113,18 +141,18 @@ func (s *keyStatements) insertKey(
return err
}
// select by user and device
// selectKey selects by user and device
func (s *keyStatements) selectKey(
ctx context.Context,
txn *sql.Tx,
deviceID, userID string,
) ([]types.KeyHolder, error) {
holders := []types.KeyHolder{}
stmt := common.TxStmt(txn, s.selectKeyStmt)
rows, err := stmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
holders := []types.KeyHolder{}
for rows.Next() {
single := &types.KeyHolder{}
if err = rows.Scan(
@ -141,10 +169,13 @@ func (s *keyStatements) selectKey(
holders = append(holders, *single)
}
err = rows.Close()
if err != nil {
return nil, err
}
return holders, err
}
// select single one for claim usage
// selectSingleKey selects single key for claim usage
func (s *keyStatements) selectSingleKey(
ctx context.Context,
userID, deviceID, algorithm string,
@ -170,7 +201,7 @@ func (s *keyStatements) selectSingleKey(
return holder, err
}
// select details by given an array of devices
// selectInKeys selects details based on an array of devices
func (s *keyStatements) selectInKeys(
ctx context.Context,
userID string,
@ -225,7 +256,7 @@ func injectKeyHolder(rows *sql.Rows, keyHolder []types.KeyHolder) (holders []typ
return
}
// select by user and device
// selectOneTimeKeyCount selects by user and device
func (s *keyStatements) selectOneTimeKeyCount(
ctx context.Context,
userID, deviceID string,
@ -249,3 +280,59 @@ func (s *keyStatements) selectOneTimeKeyCount(
err = rows.Close()
return holders, err
}
// insertChanges inserts into the changes table
func (s *keyStatements) insertChanges(
ctx context.Context, txn *sql.Tx,
readID int, userID string, status string,
) error {
stmt := common.TxStmt(txn, s.insertChangesStmt)
_, err := stmt.ExecContext(ctx, readID, userID, status)
return err
}
// selectChanges returns data from the DB
func (s *keyStatements) selectChanges(
ctx context.Context, txn *sql.Tx,
readID int, userID string,
) (types.KeyChanges, error) {
stmt := common.TxStmt(txn, s.selectChangesStmt)
rows, err := stmt.QueryContext(ctx, readID, userID)
if err != nil {
return types.KeyChanges{}, err
}
keyChanges := newKeyChanges()
for rows.Next() {
var rID, uID, nUID, status string
if err = rows.Scan(
&rID,
&uID,
&nUID,
&status,
); err != nil {
return types.KeyChanges{}, err
}
if keyChanges.UserID == "" {
keyChanges.UserID = uID
}
if status == "changed" {
keyChanges.Changed = append(keyChanges.Changed, nUID)
} else if status == "left" {
keyChanges.Left = append(keyChanges.Left, nUID)
}
}
err = rows.Close()
if err != nil {
return types.KeyChanges{}, err
}
return keyChanges, nil
}
func newKeyChanges() types.KeyChanges {
return types.KeyChanges{
UserID: "",
NeighborUserID: "",
Changed: []string{},
Left: []string{},
}
}

View file

@ -17,9 +17,10 @@ package storage
import (
"context"
"database/sql"
"strings"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/encryptoapi/types"
"strings"
)
// Database represents a presence database.
@ -129,3 +130,26 @@ func (d *Database) SyncOneTimeCount(
holder, err = d.keyStatements.selectOneTimeKeyCount(ctx, userID, deviceID)
return
}
// InsertChanges inserts the changes to the DB
func (d *Database) InsertChanges(
ctx context.Context,
readID int, userID, status string,
) error {
err := common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.keyStatements.insertChanges(ctx, txn, readID, userID, status)
})
return err
}
// GetKeyChanges gets the changed keys from the DB
func (d *Database) GetKeyChanges(
ctx context.Context,
readID int, userID string,
) (keyChanges types.KeyChanges, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) (err error) {
keyChanges, err = d.keyStatements.selectChanges(ctx, txn, readID, userID)
return
})
return
}

View file

@ -31,3 +31,11 @@ type AlHolder struct {
DeviceID,
SupportedAlgorithm string
}
// KeyChanges holds the changed keys data
type KeyChanges struct {
UserID string
NeighborUserID string
Changed []string
Left []string
}

View file

@ -80,7 +80,8 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServer
// vars := mux.Vars(req)
// eventType := vars["eventType"]
// txnID := vars["txnId"]
// return SendToDevice(req, device.UserID, syncDB, deviceDB, eventType, txnID, notifier)
// roomID := vars["roomId"]
// return SendToDevice(req, device.UserID, roomID, syncDB, deviceDB, eventType, txnID, notifier)
// }),
// ).Methods(http.MethodPut, http.MethodOptions)
}

View file

@ -2,6 +2,8 @@ package routing
// import (
// "encoding/json"
// "net/http"
// "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
// "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
// "github.com/matrix-org/dendrite/clientapi/httputil"
@ -10,13 +12,13 @@ package routing
// "github.com/matrix-org/dendrite/syncapi/types"
// "github.com/matrix-org/gomatrixserverlib"
// "github.com/matrix-org/util"
// "net/http"
// )
// // SendToDevice this is a function for calling process of send-to-device messages those bypassed DAG
// func SendToDevice(
// req *http.Request,
// sender string,
// roomID string,
// syncDB *storage.SyncServerDatasource,
// deviceDB *devices.Database,
// eventType, txnID string,
@ -48,7 +50,7 @@ package routing
// Event: jsonBuffer,
// EventTyp: eventType,
// }
// var pos int64
// // var pos int64
// // wildcard all devices
// if device == "*" {
@ -58,7 +60,8 @@ package routing
// deviceCollection, err = deviceDB.GetDevicesByLocalpart(ctx, localpart)
// for _, val := range deviceCollection {
// pos, err = syncDB.InsertStdMessage(ctx, ev, txnID, uid, val.ID)
// notifier.OnNewEvent(nil, uid, types.StreamPosition(pos))
// // NEEDS MAJOR CHANGES
// // notifier.OnNewEvent(nil, roomID, uid, types.StreamPosition(pos))
// }
// if err != nil {
// return util.JSONResponse{
@ -78,7 +81,8 @@ package routing
// JSON: struct{}{},
// }
// }
// notifier.OnNewEvent(nil, uid, types.StreamPosition(pos))
// // NEEDS MAJOR CHANGES
// // notifier.OnNewEvent(nil, roomID, uid, types.StreamPosition(pos))
// }
// }