mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Return OTK counts when inserting new keys
This commit is contained in:
parent
17017eefc1
commit
07cbfce2d0
|
|
@ -108,17 +108,15 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// store one-time keys
|
// store one-time keys
|
||||||
if err := a.db.StoreOneTimeKeys(ctx, key); err != nil {
|
counts, err := a.db.StoreOneTimeKeys(ctx, key)
|
||||||
|
if err != nil {
|
||||||
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
||||||
Error: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()),
|
Error: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()),
|
||||||
})
|
})
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
// collect counts
|
// collect counts
|
||||||
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, api.OneTimeKeysCount{
|
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
||||||
DeviceID: key.DeviceID,
|
|
||||||
UserID: key.UserID,
|
|
||||||
KeyCount: nil,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ type Database interface {
|
||||||
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
|
|
||||||
// StoreOneTimeKeys persists the given one-time keys.
|
// StoreOneTimeKeys persists the given one-time keys.
|
||||||
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error
|
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||||
|
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
|
||||||
|
|
|
||||||
|
|
@ -49,10 +49,14 @@ const upsertKeysSQL = "" +
|
||||||
const selectKeysSQL = "" +
|
const selectKeysSQL = "" +
|
||||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
|
const selectKeysCountSQL = "" +
|
||||||
|
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||||
|
|
||||||
type oneTimeKeysStatements struct {
|
type oneTimeKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertKeysStmt *sql.Stmt
|
upsertKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt *sql.Stmt
|
||||||
|
selectKeysCountStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||||
|
|
@ -69,6 +73,9 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||||
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -100,9 +107,14 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error {
|
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
|
counts := &api.OneTimeKeysCount{
|
||||||
|
DeviceID: keys.DeviceID,
|
||||||
|
UserID: keys.UserID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
|
||||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
||||||
|
|
@ -112,6 +124,20 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows.Err()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID str
|
||||||
return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
|
return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error {
|
func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
||||||
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
|
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,10 +49,14 @@ const upsertKeysSQL = "" +
|
||||||
const selectKeysSQL = "" +
|
const selectKeysSQL = "" +
|
||||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
|
const selectKeysCountSQL = "" +
|
||||||
|
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||||
|
|
||||||
type oneTimeKeysStatements struct {
|
type oneTimeKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertKeysStmt *sql.Stmt
|
upsertKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt *sql.Stmt
|
||||||
|
selectKeysCountStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||||
|
|
@ -69,6 +73,9 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||||
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -100,9 +107,14 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error {
|
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
|
counts := &api.OneTimeKeysCount{
|
||||||
|
DeviceID: keys.DeviceID,
|
||||||
|
UserID: keys.UserID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
|
||||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
|
||||||
|
|
@ -112,6 +124,20 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows.Err()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ import (
|
||||||
|
|
||||||
type OneTimeKeys interface {
|
type OneTimeKeys interface {
|
||||||
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error
|
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeviceKeys interface {
|
type DeviceKeys interface {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue