diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 7812ceb0b..61e086c2a 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -108,17 +108,15 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } } // store one-time keys - if err := a.db.StoreOneTimeKeys(ctx, key); err != nil { + counts, err := a.db.StoreOneTimeKeys(ctx, key) + if err != nil { res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ Error: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()), }) + continue } // collect counts - res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, api.OneTimeKeysCount{ - DeviceID: key.DeviceID, - UserID: key.UserID, - KeyCount: nil, - }) + res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) } } diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 89b666d18..3697b1970 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -27,7 +27,7 @@ type Database interface { ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) // StoreOneTimeKeys persists the given one-time keys. - StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error + StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index b1f5e5e3f..b8aee72bd 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -49,10 +49,14 @@ const upsertKeysSQL = "" + const selectKeysSQL = "" + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" + type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt } func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { @@ -69,6 +73,9 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { return nil, err } + if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { + return nil, err + } return s, nil } @@ -100,9 +107,14 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d return result, rows.Err() } -func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error { +func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( @@ -112,6 +124,20 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api. return err } } - return nil + rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return err + } + counts.KeyCount[algorithm] = count + } + + return rows.Err() }) } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index cfc6f940f..28e1f4592 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -33,7 +33,7 @@ func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID str return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) } -func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error { +func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys) } diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index 75fc376c3..0f77c8614 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -49,10 +49,14 @@ const upsertKeysSQL = "" + const selectKeysSQL = "" + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" + type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt } func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { @@ -69,6 +73,9 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { return nil, err } + if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { + return nil, err + } return s, nil } @@ -100,9 +107,14 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d return result, rows.Err() } -func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error { +func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( @@ -112,6 +124,20 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api. return err } } - return nil + rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return err + } + counts.KeyCount[algorithm] = count + } + + return rows.Err() }) } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index d2c854e78..20667ffb3 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -23,7 +23,7 @@ import ( type OneTimeKeys interface { SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) error + InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) } type DeviceKeys interface {