Return OTK counts when inserting new keys

This commit is contained in:
Kegan Dougal 2020-07-15 10:37:32 +01:00
parent 17017eefc1
commit 07cbfce2d0
6 changed files with 71 additions and 21 deletions

View file

@ -108,17 +108,15 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
} }
} }
// store one-time keys // store one-time keys
if err := a.db.StoreOneTimeKeys(ctx, key); err != nil { counts, err := a.db.StoreOneTimeKeys(ctx, key)
if err != nil {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
Error: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()), Error: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()),
}) })
continue
} }
// collect counts // collect counts
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, api.OneTimeKeysCount{ res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
DeviceID: key.DeviceID,
UserID: key.UserID,
KeyCount: nil,
})
} }
} }

View file

@ -27,7 +27,7 @@ type Database interface {
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
// StoreOneTimeKeys persists the given one-time keys. // StoreOneTimeKeys persists the given one-time keys.
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error

View file

@ -49,10 +49,14 @@ const upsertKeysSQL = "" +
const selectKeysSQL = "" + const selectKeysSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
type oneTimeKeysStatements struct { type oneTimeKeysStatements struct {
db *sql.DB db *sql.DB
upsertKeysStmt *sql.Stmt upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt
} }
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
@ -69,6 +73,9 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -100,9 +107,14 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
return result, rows.Err() return result, rows.Err()
} }
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error { func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix() now := time.Now().Unix()
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { counts := &api.OneTimeKeysCount{
DeviceID: keys.DeviceID,
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
@ -112,6 +124,20 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
return err return err
} }
} }
return nil rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil {
return err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return err
}
counts.KeyCount[algorithm] = count
}
return rows.Err()
}) })
} }

View file

@ -33,7 +33,7 @@ func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID str
return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
} }
func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error { func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys) return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
} }

View file

@ -49,10 +49,14 @@ const upsertKeysSQL = "" +
const selectKeysSQL = "" + const selectKeysSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
type oneTimeKeysStatements struct { type oneTimeKeysStatements struct {
db *sql.DB db *sql.DB
upsertKeysStmt *sql.Stmt upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt
} }
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
@ -69,6 +73,9 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -100,9 +107,14 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
return result, rows.Err() return result, rows.Err()
} }
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error { func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix() now := time.Now().Unix()
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { counts := &api.OneTimeKeysCount{
DeviceID: keys.DeviceID,
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
@ -112,6 +124,20 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
return err return err
} }
} }
return nil rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil {
return err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return err
}
counts.KeyCount[algorithm] = count
}
return rows.Err()
}) })
} }

View file

@ -23,7 +23,7 @@ import (
type OneTimeKeys interface { type OneTimeKeys interface {
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
} }
type DeviceKeys interface { type DeviceKeys interface {