From 35142ec42f85f0902b7db26cbb4a0ccc3aceb538 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 6 Jun 2023 12:20:43 +0200 Subject: [PATCH] Add possibility to query userIDs for senderKeys --- roomserver/storage/interface.go | 1 + .../storage/postgres/user_room_keys_table.go | 37 ++++++++++- roomserver/storage/shared/storage.go | 29 +++++++++ roomserver/storage/shared/storage_test.go | 11 +++- .../storage/sqlite3/user_room_keys_table.go | 49 ++++++++++++++- roomserver/storage/tables/interface.go | 1 + .../tables/user_room_keys_table_test.go | 63 +++++++++++++------ 7 files changed, 168 insertions(+), 23 deletions(-) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 1d9697fc5..699b42500 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -192,6 +192,7 @@ type Database interface { InsertUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) SelectUserRoomKey(ctx context.Context, userNID types.EventStateKeyNID, roomNID types.RoomNID) (key ed25519.PrivateKey, err error) + SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (map[string]string, error) } type RoomDatabase interface { diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index e60f6b6a5..4033bac53 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -20,9 +20,12 @@ import ( "database/sql" "errors" + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/sirupsen/logrus" ) const userRoomKeysSchema = ` @@ -30,20 +33,24 @@ CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( user_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL, pseudo_id_key BYTEA NOT NULL, + pseudo_id_pub_key BYTEA NOT NULL, CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) ); ` const insertUserRoomKeySQL = ` - INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3) + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4) ON CONFLICT ON CONSTRAINT roomserver_user_room_keys_pk DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key RETURNING (pseudo_id_key) ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` +const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key = ANY($1)` + type userRoomKeysStatements struct { insertUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt + selectUserIDsStmt *sql.Stmt } func CreateUserRoomKeysTable(db *sql.DB) error { @@ -56,6 +63,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { return s, sqlutil.StatementList{ {&s.insertUserRoomKeyStmt, insertUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserIDsStmt, selectUserNIDsSQL}, }.Prepare(db) } @@ -67,7 +75,7 @@ func (s *userRoomKeysStatements) InsertUserRoomKey( key ed25519.PrivateKey, ) (result ed25519.PrivateKey, err error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt) - err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) return result, err } @@ -85,3 +93,28 @@ func (s *userRoomKeysStatements) SelectUserRoomKey( } return result, err } + +func (s *userRoomKeysStatements) BulkSelectUserNIDs( + ctx context.Context, + txn *sql.Tx, + senderKeys [][]byte, +) (map[string]types.EventStateKeyNID, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserIDsStmt) + logrus.Infof("%#v", senderKeys) + rows, err := stmt.QueryContext(ctx, pq.Array(senderKeys)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + result := make(map[string]types.EventStateKeyNID, len(senderKeys)) + var publicKey []byte + var userNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&userNID, &publicKey); err != nil { + return nil, err + } + result[string(publicKey)] = userNID + } + return result, rows.Err() +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 8dd340e78..71bca2de5 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1619,6 +1619,35 @@ func (d *Database) SelectUserRoomKey(ctx context.Context, userNID types.EventSta return } +// SelectUserIDsForPublicKeys returns a map from senderKey -> userID +func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (result map[string]string, err error) { + result = make(map[string]string, len(publicKeys)) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var sErr error + var userNIDKeyMap map[string]types.EventStateKeyNID + userNIDKeyMap, sErr = d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, publicKeys) + if sErr != nil { + return sErr + } + nids := make([]types.EventStateKeyNID, 0, len(userNIDKeyMap)) + for _, nid := range userNIDKeyMap { + nids = append(nids, nid) + } + nidMAP, seErr := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, txn, nids) + if seErr != nil { + return err + } + + for publicKey, userNID := range userNIDKeyMap { + userID := nidMAP[userNID] + result[publicKey] = userID + } + + return nil + }) + return result, err +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index 7dc45aed4..c4dcabc98 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -115,11 +115,16 @@ func TestUserRoomKeys(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateRoomserverDatabase(t, dbType) defer close() - userNID := types.EventStateKeyNID(1) + roomNID := types.RoomNID(1) _, key, err := ed25519.GenerateKey(nil) assert.NoError(t, err) + // insert dummy event state keys + dummy := test.NewUser(t) + userNID, err := db.GetOrCreateEventStateKeyNID(ctx, &dummy.ID) + assert.NoError(t, err) + gotKey, err := db.InsertUserRoomKey(ctx, userNID, roomNID, key) assert.NoError(t, err) assert.Equal(t, gotKey, key) @@ -139,5 +144,9 @@ func TestUserRoomKeys(t *testing.T) { gotKey, err = db.SelectUserRoomKey(context.Background(), userNID, 2) assert.NoError(t, err) assert.Nil(t, gotKey) + + userIDs, err := db.SelectUserIDsForPublicKeys(ctx, [][]byte{key.Public().(ed25519.PublicKey)}) + assert.NoError(t, err) + assert.NotNil(t, userIDs) }) } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 05ab6d01d..300ae8411 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -19,7 +19,9 @@ import ( "crypto/ed25519" "database/sql" "errors" + "strings" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -30,20 +32,24 @@ CREATE TABLE IF NOT EXISTS roomserver_user_room_keys ( user_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL, pseudo_id_key TEXT NOT NULL, + pseudo_id_pub_key TEXT NOT NULL, CONSTRAINT roomserver_user_room_keys_pk PRIMARY KEY (user_nid, room_nid) ); ` const insertUserRoomKeySQL = ` - INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key) VALUES ($1, $2, $3) + INSERT INTO roomserver_user_room_keys (user_nid, room_nid, pseudo_id_key, pseudo_id_pub_key) VALUES ($1, $2, $3, $4) ON CONFLICT DO UPDATE SET pseudo_id_key = roomserver_user_room_keys.pseudo_id_key RETURNING (pseudo_id_key) ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` +const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key IN ($1)` + type userRoomKeysStatements struct { insertUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt + selectUserIDsStmt *sql.Stmt } func CreateUserRoomKeysTable(db *sql.DB) error { @@ -56,6 +62,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { return s, sqlutil.StatementList{ {&s.insertUserRoomKeyStmt, insertUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, + {&s.selectUserIDsStmt, selectUserNIDsSQL}, //prepared at runtime }.Prepare(db) } @@ -67,7 +74,7 @@ func (s *userRoomKeysStatements) InsertUserRoomKey( key ed25519.PrivateKey, ) (result ed25519.PrivateKey, err error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.insertUserRoomKeyStmt) - err = stmt.QueryRowContext(ctx, userNID, roomNID, key).Scan(&result) + err = stmt.QueryRowContext(ctx, userNID, roomNID, key, key.Public()).Scan(&result) return result, err } @@ -85,3 +92,41 @@ func (s *userRoomKeysStatements) SelectUserRoomKey( } return result, err } + +func (s *userRoomKeysStatements) BulkSelectUserNIDs( + ctx context.Context, + txn *sql.Tx, + senderKeys [][]byte, +) (map[string]types.EventStateKeyNID, error) { + + selectSQL := strings.Replace(selectUserNIDsSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, err + } + + params := make([]interface{}, len(senderKeys)) + for i := range senderKeys { + params[i] = senderKeys[i] + } + + stmt := sqlutil.TxStmt(txn, selectStmt) + defer internal.CloseAndLogIfError(ctx, stmt, "failed to close transaction") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + result := make(map[string]types.EventStateKeyNID, len(senderKeys)) + var publicKey []byte + var userNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&userNID, &publicKey); err != nil { + return nil, err + } + result[string(publicKey)] = userNID + } + return result, rows.Err() +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 8b25617f7..98e27b858 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -188,6 +188,7 @@ type Purge interface { type UserRoomKeys interface { InsertUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) SelectUserRoomKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) + BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys [][]byte) (map[string]types.EventStateKeyNID, error) } // StrippedEvent represents a stripped event for returning extracted content values. diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index 0647813f5..f515460e3 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -3,6 +3,7 @@ package tables_test import ( "context" "crypto/ed25519" + "database/sql" "testing" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -15,7 +16,7 @@ import ( "github.com/stretchr/testify/assert" ) -func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, close func()) { +func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, db *sql.DB, close func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := sqlutil.Open(&config.DatabaseOptions{ @@ -34,35 +35,61 @@ func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.U } assert.NoError(t, err) - return tab, close + return tab, db, close } func TestUserRoomKeysTable(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - tab, close := mustCreateUserRoomKeysTable(t, dbType) + tab, db, close := mustCreateUserRoomKeysTable(t, dbType) defer close() userNID := types.EventStateKeyNID(1) roomNID := types.RoomNID(1) _, key, err := ed25519.GenerateKey(nil) assert.NoError(t, err) - gotKey, err := tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key) - assert.NoError(t, err) - assert.Equal(t, gotKey, key) - // again, this shouldn't result in an error, but return the existing key - _, key2, err := ed25519.GenerateKey(nil) - assert.NoError(t, err) - gotKey, err = tab.InsertUserRoomKey(context.Background(), nil, userNID, roomNID, key2) - assert.NoError(t, err) - assert.Equal(t, gotKey, key) + err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + var gotKey, key2, key3 ed25519.PrivateKey + gotKey, err = tab.InsertUserRoomKey(context.Background(), txn, userNID, roomNID, key) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) - gotKey, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, roomNID) - assert.NoError(t, err) - assert.Equal(t, key, gotKey) + // again, this shouldn't result in an error, but return the existing key + _, key2, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + gotKey, err = tab.InsertUserRoomKey(context.Background(), txn, userNID, roomNID, key2) + assert.NoError(t, err) + assert.Equal(t, gotKey, key) - // Key doesn't exist - gotKey, err = tab.SelectUserRoomKey(context.Background(), nil, userNID, 2) + // add another user + _, key3, err = ed25519.GenerateKey(nil) + assert.NoError(t, err) + userNID2 := types.EventStateKeyNID(2) + _, err = tab.InsertUserRoomKey(context.Background(), txn, userNID2, roomNID, key3) + assert.NoError(t, err) + + gotKey, err = tab.SelectUserRoomKey(context.Background(), txn, userNID, roomNID) + assert.NoError(t, err) + assert.Equal(t, key, gotKey) + + // Key doesn't exist + gotKey, err = tab.SelectUserRoomKey(context.Background(), txn, userNID, 2) + assert.NoError(t, err) + assert.Nil(t, gotKey) + + // query user NIDs for senderKeys + var gotKeys map[string]types.EventStateKeyNID + gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, [][]byte{key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}) + assert.NoError(t, err) + assert.NotNil(t, gotKeys) + + wantKeys := map[string]types.EventStateKeyNID{ + string(key.Public().(ed25519.PublicKey)): userNID, + string(key3.Public().(ed25519.PublicKey)): userNID2, + } + assert.Equal(t, wantKeys, gotKeys) + return nil + }) assert.NoError(t, err) - assert.Nil(t, gotKey) + }) }