mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-16 11:23:11 -06:00
Changes API progress
This commit is contained in:
parent
94794485e0
commit
d7664a1c96
|
|
@ -60,7 +60,8 @@ type KeyNotifier struct {
|
||||||
|
|
||||||
var keyProducer = &KeyNotifier{}
|
var keyProducer = &KeyNotifier{}
|
||||||
|
|
||||||
// UploadPKeys this function is for user upload his device key, and one-time-key to a limit at 50 set as default
|
// UploadPKeys enables the user to upload his device
|
||||||
|
// and one time keys with limit at 50 set as default
|
||||||
func UploadPKeys(
|
func UploadPKeys(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
encryptionDB *storage.Database,
|
encryptionDB *storage.Database,
|
||||||
|
|
@ -78,7 +79,7 @@ func UploadPKeys(
|
||||||
&keySpecific,
|
&keySpecific,
|
||||||
userID, deviceID)
|
userID, deviceID)
|
||||||
// numMap is algorithm-num map
|
// numMap is algorithm-num map
|
||||||
numMap, ok := (QueryOneTimeKeys(
|
numMap, ok := (queryOneTimeKeys(
|
||||||
req.Context(),
|
req.Context(),
|
||||||
TYPESUM,
|
TYPESUM,
|
||||||
userID,
|
userID,
|
||||||
|
|
@ -106,7 +107,7 @@ func UploadPKeys(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryPKeys this function is for user query other's device key
|
// QueryPKeys enables the user to query for other devices's keys
|
||||||
func QueryPKeys(
|
func QueryPKeys(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
encryptionDB *storage.Database,
|
encryptionDB *storage.Database,
|
||||||
|
|
@ -122,6 +123,7 @@ func QueryPKeys(
|
||||||
return *reqErr
|
return *reqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FED must return the keys from the other user
|
||||||
/*
|
/*
|
||||||
federation consideration: when user id is in federation, a
|
federation consideration: when user id is in federation, a
|
||||||
query is needed to ask fed for keys.
|
query is needed to ask fed for keys.
|
||||||
|
|
@ -147,6 +149,11 @@ func QueryPKeys(
|
||||||
case <-make(chan interface{}):
|
case <-make(chan interface{}):
|
||||||
// todo : here goes federation chan , still a mocked one
|
// todo : here goes federation chan , still a mocked one
|
||||||
}
|
}
|
||||||
|
// probably some other better error to tell it timed out in FED
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: struct{}{},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// query one's device key from user corresponding to uid
|
// query one's device key from user corresponding to uid
|
||||||
|
|
@ -193,7 +200,7 @@ func QueryPKeys(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaimOneTimeKeys claim for one time key that may be used in session exchange in olm encryption
|
// ClaimOneTimeKeys enables user to claim one time keys for sessions.
|
||||||
func ClaimOneTimeKeys(
|
func ClaimOneTimeKeys(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
encryptionDB *storage.Database,
|
encryptionDB *storage.Database,
|
||||||
|
|
@ -206,6 +213,7 @@ func ClaimOneTimeKeys(
|
||||||
return *reqErr
|
return *reqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// not sure what FED should return here
|
||||||
/*
|
/*
|
||||||
federation consideration: when user id is in federation, a query is needed to ask fed for keys
|
federation consideration: when user id is in federation, a query is needed to ask fed for keys
|
||||||
domain --------+ fed (keys)
|
domain --------+ fed (keys)
|
||||||
|
|
@ -227,20 +235,25 @@ func ClaimOneTimeKeys(
|
||||||
case <-make(chan interface{}):
|
case <-make(chan interface{}):
|
||||||
// todo : here goes federation chan , still a mocked one
|
// todo : here goes federation chan , still a mocked one
|
||||||
}
|
}
|
||||||
|
// probably some other better error to tell it timed out in FED
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: struct{}{},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := claimRq.ClaimDetail
|
content := claimRq.ClaimDetail
|
||||||
for uid, detail := range content {
|
for uid, detail := range content {
|
||||||
for deviceID, al := range detail {
|
for deviceID, alg := range detail {
|
||||||
var alTyp int
|
var algTyp int
|
||||||
if strings.Contains(al, "signed") {
|
if strings.Contains(alg, "signed") {
|
||||||
alTyp = ONETIMEKEYOBJECT
|
algTyp = ONETIMEKEYOBJECT
|
||||||
} else {
|
} else {
|
||||||
alTyp = ONETIMEKEYSTRING
|
algTyp = ONETIMEKEYSTRING
|
||||||
}
|
}
|
||||||
key, err := pickOne(req.Context(), *encryptionDB, uid, deviceID, al)
|
key, err := pickOne(req.Context(), *encryptionDB, uid, deviceID, alg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
claimRp.Failures[uid] = fmt.Sprintf("%s:%s", "fail to get keys for device ", deviceID)
|
claimRp.Failures[uid] = fmt.Sprintf("%s: %s", "failed to get keys for device", deviceID)
|
||||||
}
|
}
|
||||||
claimRp.ClaimBody[uid] = make(map[string]map[string]interface{})
|
claimRp.ClaimBody[uid] = make(map[string]map[string]interface{})
|
||||||
keyPreMap := claimRp.ClaimBody[uid]
|
keyPreMap := claimRp.ClaimBody[uid]
|
||||||
|
|
@ -248,14 +261,14 @@ func ClaimOneTimeKeys(
|
||||||
if keymap == nil {
|
if keymap == nil {
|
||||||
keymap = make(map[string]interface{})
|
keymap = make(map[string]interface{})
|
||||||
}
|
}
|
||||||
switch alTyp {
|
switch algTyp {
|
||||||
case ONETIMEKEYSTRING:
|
case ONETIMEKEYSTRING:
|
||||||
keymap[fmt.Sprintf("%s:%s", al, key.KeyID)] = key.Key
|
keymap[fmt.Sprintf("%s:%s", alg, key.KeyID)] = key.Key
|
||||||
case ONETIMEKEYOBJECT:
|
case ONETIMEKEYOBJECT:
|
||||||
sig := make(map[string]map[string]string)
|
sig := make(map[string]map[string]string)
|
||||||
sig[uid] = make(map[string]string)
|
sig[uid] = make(map[string]string)
|
||||||
sig[uid][fmt.Sprintf("%s:%s", "ed25519", deviceID)] = key.Signature
|
sig[uid][fmt.Sprintf("%s:%s", "ed25519", deviceID)] = key.Signature
|
||||||
keymap[fmt.Sprintf("%s:%s", al, key.KeyID)] = types.KeyObject{Key: key.Key, Signature: sig}
|
keymap[fmt.Sprintf("%s:%s", alg, key.KeyID)] = types.KeyObject{Key: key.Key, Signature: sig}
|
||||||
}
|
}
|
||||||
claimRp.ClaimBody[uid][deviceID] = keymap
|
claimRp.ClaimBody[uid][deviceID] = keymap
|
||||||
}
|
}
|
||||||
|
|
@ -267,13 +280,34 @@ func ClaimOneTimeKeys(
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChangesInKeys returns the changes in the keys after last sync
|
// ChangesInKeys returns the changes in the keys after last sync
|
||||||
func ChangesInKeys(req *http.Request,
|
// each user maintains a chain of the changes when provided by FED
|
||||||
|
func ChangesInKeys(
|
||||||
|
req *http.Request,
|
||||||
encryptionDB *storage.Database,
|
encryptionDB *storage.Database,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
// assuming federation has added keys to the DB,
|
||||||
|
// extracting from the DB here
|
||||||
|
|
||||||
|
// get from FED/Req
|
||||||
|
var readID int
|
||||||
|
var userID string
|
||||||
|
keyChanges, err := encryptionDB.GetKeyChanges(req.Context(), readID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: struct{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
changesRes := types.ChangesResponse{}
|
||||||
|
changesRes.Changed = keyChanges.Changed
|
||||||
|
changesRes.Left = keyChanges.Left
|
||||||
|
|
||||||
|
// delete the extracted keys from the DB
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: struct{}{},
|
JSON: changesRes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -294,8 +328,8 @@ func checkUpload(req *types.UploadEncryptSpecific, typ int) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryOneTimeKeys todo: complete this field through claim type
|
// queryOneTimeKeys todo: complete this field through claim type
|
||||||
func QueryOneTimeKeys(
|
func queryOneTimeKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
typ int,
|
typ int,
|
||||||
userID, deviceID string,
|
userID, deviceID string,
|
||||||
|
|
@ -320,7 +354,8 @@ func persistKeys(
|
||||||
userID,
|
userID,
|
||||||
deviceID string,
|
deviceID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
// in order to persist keys , a check filtering duplicate should be processed
|
// in order to persist keys , a check,
|
||||||
|
// filtering the duplicates should be processed.
|
||||||
// true stands for counterparts are in request
|
// true stands for counterparts are in request
|
||||||
// situation 1: only device keys
|
// situation 1: only device keys
|
||||||
// situation 2: both device keys and one time keys
|
// situation 2: both device keys and one time keys
|
||||||
|
|
@ -333,11 +368,11 @@ func persistKeys(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if checkUpload(body, BODYONETIMEKEY) {
|
if checkUpload(body, BODYONETIMEKEY) {
|
||||||
if err = bothKeyProcess(ctx, body, userID, deviceID, database, deviceKeys); err != nil {
|
if err = persistBothKeys(ctx, body, userID, deviceID, database, deviceKeys); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err = dkeyProcess(ctx, userID, deviceID, database, deviceKeys); err != nil {
|
if err = persistDeviceKeys(ctx, userID, deviceID, database, deviceKeys); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -345,11 +380,11 @@ func persistKeys(
|
||||||
upnotify(userID)
|
upnotify(userID)
|
||||||
} else {
|
} else {
|
||||||
if checkUpload(body, BODYONETIMEKEY) {
|
if checkUpload(body, BODYONETIMEKEY) {
|
||||||
if err = otmKeyProcess(ctx, body, userID, deviceID, database); err != nil {
|
if err = persistOneTimeKeys(ctx, body, userID, deviceID, database); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return errors.New("failed to touch keys")
|
return errors.New("failed to persist keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
|
@ -454,56 +489,27 @@ func presetDeviceKeysQueryMap(
|
||||||
return deviceKeysQueryMap
|
return deviceKeysQueryMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func bothKeyProcess(
|
func persistBothKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
body *types.UploadEncryptSpecific,
|
body *types.UploadEncryptSpecific,
|
||||||
userID, deviceID string,
|
userID, deviceID string,
|
||||||
database *storage.Database,
|
database *storage.Database,
|
||||||
deviceKeys types.DeviceKeys,
|
deviceKeys types.DeviceKeys,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
// insert one time keys firstly
|
// insert one time keys
|
||||||
onetimeKeys := body.OneTimeKey
|
err = persistOneTimeKeys(ctx, body, userID, deviceID, database)
|
||||||
for alKeyID, val := range onetimeKeys.KeyString {
|
if err != nil {
|
||||||
al := (strings.Split(alKeyID, ":"))[0]
|
return
|
||||||
keyID := (strings.Split(alKeyID, ":"))[1]
|
|
||||||
keyInfo := val
|
|
||||||
keyStringTyp := ONETIMEKEYSTR
|
|
||||||
sig := ""
|
|
||||||
err = database.InsertKey(ctx, deviceID, userID, keyID, keyStringTyp, keyInfo, al, sig)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for alKeyID, val := range onetimeKeys.KeyObject {
|
|
||||||
al := (strings.Split(alKeyID, ":"))[0]
|
|
||||||
keyID := (strings.Split(alKeyID, ":"))[1]
|
|
||||||
keyInfo := val.Key
|
|
||||||
keyObjectTyp := ONETIMEKEYSTR
|
|
||||||
sig := val.Signature[userID][fmt.Sprintf("%s:%s", "ed25519", deviceID)]
|
|
||||||
err = database.InsertKey(ctx, deviceID, userID, keyID, keyObjectTyp, keyInfo, al, sig)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// insert device keys
|
// insert device keys
|
||||||
keys := deviceKeys.Keys
|
err = persistDeviceKeys(ctx, userID, deviceID, database, deviceKeys)
|
||||||
sigs := deviceKeys.Signature
|
if err != nil {
|
||||||
for alDevice, key := range keys {
|
return
|
||||||
al := (strings.Split(alDevice, ":"))[0]
|
|
||||||
keyTyp := DEVICEKEYSTR
|
|
||||||
keyInfo := key
|
|
||||||
keyID := ""
|
|
||||||
sig := sigs[userID][fmt.Sprintf("%s:%s", "ed25519", deviceID)]
|
|
||||||
err = database.InsertKey(
|
|
||||||
ctx, deviceID, userID, keyID, keyTyp, keyInfo, al, sig)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func dkeyProcess(
|
func persistDeviceKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID, deviceID string,
|
userID, deviceID string,
|
||||||
database *storage.Database,
|
database *storage.Database,
|
||||||
|
|
@ -522,7 +528,7 @@ func dkeyProcess(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func otmKeyProcess(
|
func persistOneTimeKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
body *types.UploadEncryptSpecific,
|
body *types.UploadEncryptSpecific,
|
||||||
userID, deviceID string,
|
userID, deviceID string,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
// Copyright 2018 Vector Creations Ltd
|
// Copyright FadeAce and Sumukha PK 2019
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
|
|
@ -17,6 +17,7 @@ package storage
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/encryptoapi/types"
|
"github.com/matrix-org/dendrite/encryptoapi/types"
|
||||||
|
|
@ -35,6 +36,15 @@ CREATE TABLE IF NOT EXISTS encrypt_keys (
|
||||||
signature TEXT NOT NULL
|
signature TEXT NOT NULL
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
const changesTable = `
|
||||||
|
CREATE TABLE IF NOT EXISTS key_changes (
|
||||||
|
read_id INT NOT NULL,
|
||||||
|
user_id TEXT PRIMARY KEY NOT NULL,
|
||||||
|
neighbor_user_id TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
const insertkeySQL = `
|
const insertkeySQL = `
|
||||||
INSERT INTO encrypt_keys (device_id, user_id, key_id, key_type, key_info, algorithm, signature)
|
INSERT INTO encrypt_keys (device_id, user_id, key_id, key_type, key_info, algorithm, signature)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
|
|
@ -43,11 +53,11 @@ const selectkeySQL = `
|
||||||
SELECT user_id, device_id, key_id, key_type, key_info, algorithm, signature FROM encrypt_keys
|
SELECT user_id, device_id, key_id, key_type, key_info, algorithm, signature FROM encrypt_keys
|
||||||
WHERE user_id = $1 AND device_id = $2
|
WHERE user_id = $1 AND device_id = $2
|
||||||
`
|
`
|
||||||
const deleteSinglekeySQL = `
|
const selectSinglekeySQL = `
|
||||||
SELECT user_id, device_id, key_id, key_type, key_info, algorithm, signature FROM encrypt_keys
|
SELECT user_id, device_id, key_id, key_type, key_info, algorithm, signature FROM encrypt_keys
|
||||||
WHERE user_id = $1 AND device_id = $2 AND algorithm = $3
|
WHERE user_id = $1 AND device_id = $2 AND algorithm = $3
|
||||||
`
|
`
|
||||||
const selectSinglekeySQL = `
|
const deleteSinglekeySQL = `
|
||||||
DELETE FROM encrypt_keys
|
DELETE FROM encrypt_keys
|
||||||
WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4
|
WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4
|
||||||
`
|
`
|
||||||
|
|
@ -63,6 +73,12 @@ const selectCountOneTimeKey = `
|
||||||
SELECT algorithm, COUNT(algorithm) FROM encrypt_keys WHERE user_id = $1 AND device_id = $2 AND key_type = 'one_time_key'
|
SELECT algorithm, COUNT(algorithm) FROM encrypt_keys WHERE user_id = $1 AND device_id = $2 AND key_type = 'one_time_key'
|
||||||
GROUP BY algorithm
|
GROUP BY algorithm
|
||||||
`
|
`
|
||||||
|
const insertChangesSQL = `
|
||||||
|
INSERT INTO changesTable (read_id, user_id) VALUES ($1, $2)
|
||||||
|
`
|
||||||
|
const selectChangesSQL = `
|
||||||
|
SELECT read_id, user_id
|
||||||
|
`
|
||||||
|
|
||||||
type keyStatements struct {
|
type keyStatements struct {
|
||||||
insertKeyStmt *sql.Stmt
|
insertKeyStmt *sql.Stmt
|
||||||
|
|
@ -72,6 +88,8 @@ type keyStatements struct {
|
||||||
selectSingleKeyStmt *sql.Stmt
|
selectSingleKeyStmt *sql.Stmt
|
||||||
deleteSingleKeyStmt *sql.Stmt
|
deleteSingleKeyStmt *sql.Stmt
|
||||||
selectCountOneTimeKeyStmt *sql.Stmt
|
selectCountOneTimeKeyStmt *sql.Stmt
|
||||||
|
insertChangesStmt *sql.Stmt
|
||||||
|
selectChangesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyStatements) prepare(db *sql.DB) (err error) {
|
func (s *keyStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
@ -79,6 +97,10 @@ func (s *keyStatements) prepare(db *sql.DB) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
_, err = db.Exec(changesTable)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.insertKeyStmt, err = db.Prepare(insertkeySQL); err != nil {
|
if s.insertKeyStmt, err = db.Prepare(insertkeySQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -91,19 +113,25 @@ func (s *keyStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.selectAllKeyStmt, err = db.Prepare(selectAllkeysSQL); err != nil {
|
if s.selectAllKeyStmt, err = db.Prepare(selectAllkeysSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.deleteSingleKeyStmt, err = db.Prepare(selectSinglekeySQL); err != nil {
|
if s.deleteSingleKeyStmt, err = db.Prepare(deleteSinglekeySQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.selectSingleKeyStmt, err = db.Prepare(deleteSinglekeySQL); err != nil {
|
if s.selectSingleKeyStmt, err = db.Prepare(selectSinglekeySQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.selectCountOneTimeKeyStmt, err = db.Prepare(selectCountOneTimeKey); err != nil {
|
if s.selectCountOneTimeKeyStmt, err = db.Prepare(selectCountOneTimeKey); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.insertChangesStmt, err = db.Prepare(insertChangesSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectChangesStmt, err = db.Prepare(selectChangesSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert keys
|
// insertKeys inserts keys
|
||||||
func (s *keyStatements) insertKey(
|
func (s *keyStatements) insertKey(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
deviceID, userID, keyID, keyTyp, keyInfo, algorithm, signature string,
|
deviceID, userID, keyID, keyTyp, keyInfo, algorithm, signature string,
|
||||||
|
|
@ -113,18 +141,18 @@ func (s *keyStatements) insertKey(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// select by user and device
|
// selectKey selects by user and device
|
||||||
func (s *keyStatements) selectKey(
|
func (s *keyStatements) selectKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx,
|
txn *sql.Tx,
|
||||||
deviceID, userID string,
|
deviceID, userID string,
|
||||||
) ([]types.KeyHolder, error) {
|
) ([]types.KeyHolder, error) {
|
||||||
holders := []types.KeyHolder{}
|
|
||||||
stmt := common.TxStmt(txn, s.selectKeyStmt)
|
stmt := common.TxStmt(txn, s.selectKeyStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, userID, deviceID)
|
rows, err := stmt.QueryContext(ctx, userID, deviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
holders := []types.KeyHolder{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
single := &types.KeyHolder{}
|
single := &types.KeyHolder{}
|
||||||
if err = rows.Scan(
|
if err = rows.Scan(
|
||||||
|
|
@ -141,10 +169,13 @@ func (s *keyStatements) selectKey(
|
||||||
holders = append(holders, *single)
|
holders = append(holders, *single)
|
||||||
}
|
}
|
||||||
err = rows.Close()
|
err = rows.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return holders, err
|
return holders, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// select single one for claim usage
|
// selectSingleKey selects single key for claim usage
|
||||||
func (s *keyStatements) selectSingleKey(
|
func (s *keyStatements) selectSingleKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID, deviceID, algorithm string,
|
userID, deviceID, algorithm string,
|
||||||
|
|
@ -170,7 +201,7 @@ func (s *keyStatements) selectSingleKey(
|
||||||
return holder, err
|
return holder, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// select details by given an array of devices
|
// selectInKeys selects details based on an array of devices
|
||||||
func (s *keyStatements) selectInKeys(
|
func (s *keyStatements) selectInKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID string,
|
userID string,
|
||||||
|
|
@ -225,7 +256,7 @@ func injectKeyHolder(rows *sql.Rows, keyHolder []types.KeyHolder) (holders []typ
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// select by user and device
|
// selectOneTimeKeyCount selects by user and device
|
||||||
func (s *keyStatements) selectOneTimeKeyCount(
|
func (s *keyStatements) selectOneTimeKeyCount(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID, deviceID string,
|
userID, deviceID string,
|
||||||
|
|
@ -249,3 +280,59 @@ func (s *keyStatements) selectOneTimeKeyCount(
|
||||||
err = rows.Close()
|
err = rows.Close()
|
||||||
return holders, err
|
return holders, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// insertChanges inserts into the changes table
|
||||||
|
func (s *keyStatements) insertChanges(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
readID int, userID string, status string,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.insertChangesStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, readID, userID, status)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectChanges returns data from the DB
|
||||||
|
func (s *keyStatements) selectChanges(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
readID int, userID string,
|
||||||
|
) (types.KeyChanges, error) {
|
||||||
|
stmt := common.TxStmt(txn, s.selectChangesStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, readID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return types.KeyChanges{}, err
|
||||||
|
}
|
||||||
|
keyChanges := newKeyChanges()
|
||||||
|
for rows.Next() {
|
||||||
|
var rID, uID, nUID, status string
|
||||||
|
if err = rows.Scan(
|
||||||
|
&rID,
|
||||||
|
&uID,
|
||||||
|
&nUID,
|
||||||
|
&status,
|
||||||
|
); err != nil {
|
||||||
|
return types.KeyChanges{}, err
|
||||||
|
}
|
||||||
|
if keyChanges.UserID == "" {
|
||||||
|
keyChanges.UserID = uID
|
||||||
|
}
|
||||||
|
if status == "changed" {
|
||||||
|
keyChanges.Changed = append(keyChanges.Changed, nUID)
|
||||||
|
} else if status == "left" {
|
||||||
|
keyChanges.Left = append(keyChanges.Left, nUID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = rows.Close()
|
||||||
|
if err != nil {
|
||||||
|
return types.KeyChanges{}, err
|
||||||
|
}
|
||||||
|
return keyChanges, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newKeyChanges() types.KeyChanges {
|
||||||
|
return types.KeyChanges{
|
||||||
|
UserID: "",
|
||||||
|
NeighborUserID: "",
|
||||||
|
Changed: []string{},
|
||||||
|
Left: []string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,10 @@ package storage
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/encryptoapi/types"
|
"github.com/matrix-org/dendrite/encryptoapi/types"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database represents a presence database.
|
// Database represents a presence database.
|
||||||
|
|
@ -129,3 +130,26 @@ func (d *Database) SyncOneTimeCount(
|
||||||
holder, err = d.keyStatements.selectOneTimeKeyCount(ctx, userID, deviceID)
|
holder, err = d.keyStatements.selectOneTimeKeyCount(ctx, userID, deviceID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InsertChanges inserts the changes to the DB
|
||||||
|
func (d *Database) InsertChanges(
|
||||||
|
ctx context.Context,
|
||||||
|
readID int, userID, status string,
|
||||||
|
) error {
|
||||||
|
err := common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.keyStatements.insertChanges(ctx, txn, readID, userID, status)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKeyChanges gets the changed keys from the DB
|
||||||
|
func (d *Database) GetKeyChanges(
|
||||||
|
ctx context.Context,
|
||||||
|
readID int, userID string,
|
||||||
|
) (keyChanges types.KeyChanges, err error) {
|
||||||
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) (err error) {
|
||||||
|
keyChanges, err = d.keyStatements.selectChanges(ctx, txn, readID, userID)
|
||||||
|
return
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,3 +31,11 @@ type AlHolder struct {
|
||||||
DeviceID,
|
DeviceID,
|
||||||
SupportedAlgorithm string
|
SupportedAlgorithm string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KeyChanges holds the changed keys data
|
||||||
|
type KeyChanges struct {
|
||||||
|
UserID string
|
||||||
|
NeighborUserID string
|
||||||
|
Changed []string
|
||||||
|
Left []string
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,8 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServer
|
||||||
// vars := mux.Vars(req)
|
// vars := mux.Vars(req)
|
||||||
// eventType := vars["eventType"]
|
// eventType := vars["eventType"]
|
||||||
// txnID := vars["txnId"]
|
// txnID := vars["txnId"]
|
||||||
// return SendToDevice(req, device.UserID, syncDB, deviceDB, eventType, txnID, notifier)
|
// roomID := vars["roomId"]
|
||||||
|
// return SendToDevice(req, device.UserID, roomID, syncDB, deviceDB, eventType, txnID, notifier)
|
||||||
// }),
|
// }),
|
||||||
// ).Methods(http.MethodPut, http.MethodOptions)
|
// ).Methods(http.MethodPut, http.MethodOptions)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ package routing
|
||||||
|
|
||||||
// import (
|
// import (
|
||||||
// "encoding/json"
|
// "encoding/json"
|
||||||
|
// "net/http"
|
||||||
|
|
||||||
// "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
// "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
// "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
// "github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
|
||||||
// "github.com/matrix-org/dendrite/clientapi/httputil"
|
// "github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
|
|
@ -10,13 +12,13 @@ package routing
|
||||||
// "github.com/matrix-org/dendrite/syncapi/types"
|
// "github.com/matrix-org/dendrite/syncapi/types"
|
||||||
// "github.com/matrix-org/gomatrixserverlib"
|
// "github.com/matrix-org/gomatrixserverlib"
|
||||||
// "github.com/matrix-org/util"
|
// "github.com/matrix-org/util"
|
||||||
// "net/http"
|
|
||||||
// )
|
// )
|
||||||
|
|
||||||
// // SendToDevice this is a function for calling process of send-to-device messages those bypassed DAG
|
// // SendToDevice this is a function for calling process of send-to-device messages those bypassed DAG
|
||||||
// func SendToDevice(
|
// func SendToDevice(
|
||||||
// req *http.Request,
|
// req *http.Request,
|
||||||
// sender string,
|
// sender string,
|
||||||
|
// roomID string,
|
||||||
// syncDB *storage.SyncServerDatasource,
|
// syncDB *storage.SyncServerDatasource,
|
||||||
// deviceDB *devices.Database,
|
// deviceDB *devices.Database,
|
||||||
// eventType, txnID string,
|
// eventType, txnID string,
|
||||||
|
|
@ -48,7 +50,7 @@ package routing
|
||||||
// Event: jsonBuffer,
|
// Event: jsonBuffer,
|
||||||
// EventTyp: eventType,
|
// EventTyp: eventType,
|
||||||
// }
|
// }
|
||||||
// var pos int64
|
// // var pos int64
|
||||||
|
|
||||||
// // wildcard all devices
|
// // wildcard all devices
|
||||||
// if device == "*" {
|
// if device == "*" {
|
||||||
|
|
@ -58,7 +60,8 @@ package routing
|
||||||
// deviceCollection, err = deviceDB.GetDevicesByLocalpart(ctx, localpart)
|
// deviceCollection, err = deviceDB.GetDevicesByLocalpart(ctx, localpart)
|
||||||
// for _, val := range deviceCollection {
|
// for _, val := range deviceCollection {
|
||||||
// pos, err = syncDB.InsertStdMessage(ctx, ev, txnID, uid, val.ID)
|
// pos, err = syncDB.InsertStdMessage(ctx, ev, txnID, uid, val.ID)
|
||||||
// notifier.OnNewEvent(nil, uid, types.StreamPosition(pos))
|
// // NEEDS MAJOR CHANGES
|
||||||
|
// // notifier.OnNewEvent(nil, roomID, uid, types.StreamPosition(pos))
|
||||||
// }
|
// }
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// return util.JSONResponse{
|
// return util.JSONResponse{
|
||||||
|
|
@ -78,7 +81,8 @@ package routing
|
||||||
// JSON: struct{}{},
|
// JSON: struct{}{},
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
// notifier.OnNewEvent(nil, uid, types.StreamPosition(pos))
|
// // NEEDS MAJOR CHANGES
|
||||||
|
// // notifier.OnNewEvent(nil, roomID, uid, types.StreamPosition(pos))
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue