diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 15018b038..57058748e 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -194,10 +194,10 @@ type Database interface { ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) - InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) + InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error) InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error) - SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (map[string]string, error) + SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error) } type RoomDatabase interface { diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index f969cd246..5c5b64396 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -51,7 +51,7 @@ const insertUserRoomPublicKeySQL = ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` -const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key = ANY($1)` +const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key = ANY($1)` type userRoomKeysStatements struct { insertUserRoomPrivateKeyStmt *sql.Stmt @@ -102,27 +102,31 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( return result, err } -func (s *userRoomKeysStatements) BulkSelectUserNIDs( - ctx context.Context, - txn *sql.Tx, - senderKeys [][]byte, -) (map[string]types.EventStateKeyNID, error) { +func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt) - rows, err := stmt.QueryContext(ctx, pq.Array(senderKeys)) + roomNIDs := make([]types.RoomNID, 0, len(senderKeys)) + var senders [][]byte + for roomNID := range senderKeys { + roomNIDs = append(roomNIDs, roomNID) + for _, key := range senderKeys[roomNID] { + senders = append(senders, key) + } + } + rows, err := stmt.QueryContext(ctx, pq.Array(senders)) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") - result := make(map[string]types.EventStateKeyNID, len(senderKeys)) + result := make(map[string]types.UserRoomKeyPair, len(senders)+len(roomNIDs)) var publicKey []byte - var userNID types.EventStateKeyNID + userRoomKeyPair := types.UserRoomKeyPair{} for rows.Next() { - if err = rows.Scan(&userNID, &publicKey); err != nil { + if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { return nil, err } - result[string(publicKey)] = userNID + result[string(publicKey)] = userRoomKeyPair } return result, rows.Err() } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 67756d34c..e8761a684 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/caching" @@ -1616,7 +1617,7 @@ func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventS // InsertUserRoomPrivateKey inserts a new user room key for the given user and room. // Returns the newly inserted private key or an existing private key. If there is // an error talking to the database, returns that error. -func (d *Database) InsertUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { +func (d *Database) InsertUserRoomPrivatePublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PrivateKey) (result ed25519.PrivateKey, err error) { uID := userID.String() stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID}) if sErr != nil { @@ -1696,28 +1697,57 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use return } -// SelectUserIDsForPublicKeys returns a map from senderKey -> userID -func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys [][]byte) (result map[string]string, err error) { - result = make(map[string]string, len(publicKeys)) +// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID +func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) { + result = make(map[spec.RoomID]map[string]string, len(publicKeys)) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - var sErr error - var userNIDKeyMap map[string]types.EventStateKeyNID - userNIDKeyMap, sErr = d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, publicKeys) + + // map all roomIDs to roomNIDs + query := make(map[types.RoomNID][]ed25519.PublicKey) + rooms := make(map[types.RoomNID]spec.RoomID) + for roomID, keys := range publicKeys { + roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID.String()) + if !ok { + roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String()) + if rErr != nil { + return rErr + } + if roomInfo == nil { + logrus.Warnf("missing room info for %s, there will be missing users in the response", roomID.String()) + continue + } + roomNID = roomInfo.RoomNID + } + + query[roomNID] = keys + rooms[roomNID] = roomID + } + + // get the user room key pars + userRoomKeyPairMap, sErr := d.UserRoomKeyTable.BulkSelectUserNIDs(ctx, txn, query) if sErr != nil { return sErr } - nids := make([]types.EventStateKeyNID, 0, len(userNIDKeyMap)) - for _, nid := range userNIDKeyMap { - nids = append(nids, nid) + nids := make([]types.EventStateKeyNID, 0, len(userRoomKeyPairMap)) + for _, nid := range userRoomKeyPairMap { + nids = append(nids, nid.EventStateKeyNID) } + // get the userIDs nidMAP, seErr := d.EventStateKeys(ctx, nids) if seErr != nil { return seErr } - for publicKey, userNID := range userNIDKeyMap { - userID := nidMAP[userNID] - result[publicKey] = userID + // build the result map (roomID -> map publicKey -> userID) + for publicKey, userRoomKeyPair := range userRoomKeyPairMap { + userID := nidMAP[userRoomKeyPair.EventStateKeyNID] + roomID := rooms[userRoomKeyPair.RoomNID] + resMap, exists := result[roomID] + if !exists { + resMap = map[string]string{} + } + resMap[publicKey] = userID + result[roomID] = resMap } return nil diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go index ce1c46c25..4fa451bcc 100644 --- a/roomserver/storage/shared/storage_test.go +++ b/roomserver/storage/shared/storage_test.go @@ -148,14 +148,14 @@ func TestUserRoomKeys(t *testing.T) { _, key, err := ed25519.GenerateKey(nil) assert.NoError(t, err) - gotKey, err := db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key) + gotKey, err := db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key) assert.NoError(t, err) assert.Equal(t, gotKey, key) // again, this shouldn't result in an error, but return the existing key _, key2, err := ed25519.GenerateKey(nil) assert.NoError(t, err) - gotKey, err = db.InsertUserRoomPrivateKey(ctx, *userID, *roomID, key2) + gotKey, err = db.InsertUserRoomPrivatePublicKey(ctx, *userID, *roomID, key2) assert.NoError(t, err) assert.Equal(t, gotKey, key) @@ -169,9 +169,18 @@ func TestUserRoomKeys(t *testing.T) { assert.NoError(t, err) assert.Nil(t, gotKey) - userIDs, err := db.SelectUserIDsForPublicKeys(ctx, [][]byte{key.Public().(ed25519.PublicKey)}) + queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{ + *roomID: {key.Public().(ed25519.PublicKey)}, + } + + userIDs, err := db.SelectUserIDsForPublicKeys(ctx, queryUserIDs) assert.NoError(t, err) - assert.NotNil(t, userIDs) + wantKeys := map[spec.RoomID]map[string]string{ + *roomID: { + string(key.Public().(ed25519.PublicKey)): userID.String(), + }, + } + assert.Equal(t, wantKeys, userIDs) // insert key that came in over federation var gotPublicKey, key4 ed255192.PublicKey @@ -186,7 +195,7 @@ func TestUserRoomKeys(t *testing.T) { assert.NoError(t, err) _, err = db.InsertUserRoomPublicKey(context.Background(), *userID, *reallyDoesNotExist, key4) assert.Error(t, err) - _, err = db.InsertUserRoomPrivateKey(context.Background(), *userID, *reallyDoesNotExist, key) + _, err = db.InsertUserRoomPrivatePublicKey(context.Background(), *userID, *reallyDoesNotExist, key) assert.Error(t, err) }) } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 155391a59..d334c616d 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -51,13 +51,13 @@ const insertUserRoomPublicKeySQL = ` const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2` -const selectUserNIDsSQL = `SELECT user_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE pseudo_id_pub_key IN ($1)` +const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` type userRoomKeysStatements struct { insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt - selectUserNIDsStmt *sql.Stmt + //selectUserNIDsStmt *sql.Stmt //prepared at runtime } func CreateUserRoomKeysTable(db *sql.DB) error { @@ -71,7 +71,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { {&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, - {&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime + //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime }.Prepare(db) } @@ -102,25 +102,30 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey( return result, err } -func (s *userRoomKeysStatements) BulkSelectUserNIDs( - ctx context.Context, - txn *sql.Tx, - senderKeys [][]byte, -) (map[string]types.EventStateKeyNID, error) { +func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) { + + roomNIDs := make([]any, 0, len(senderKeys)) + var senders []any + for roomNID := range senderKeys { + roomNIDs = append(roomNIDs, roomNID) + + for _, key := range senderKeys[roomNID] { + senders = append(senders, []byte(key)) + } + } + + selectSQL := strings.Replace(selectUserNIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(senders), len(senderKeys)), 1) + selectSQL = strings.Replace(selectSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) // replace $1 with the roomNIDs - selectSQL := strings.Replace(selectUserNIDsSQL, "($1)", sqlutil.QueryVariadic(len(senderKeys)), 1) selectStmt, err := txn.Prepare(selectSQL) if err != nil { return nil, err } - params := make([]interface{}, len(senderKeys)) - for i := range senderKeys { - params[i] = senderKeys[i] - } + params := append(roomNIDs, senders...) stmt := sqlutil.TxStmt(txn, selectStmt) - defer internal.CloseAndLogIfError(ctx, stmt, "failed to close transaction") + defer internal.CloseAndLogIfError(ctx, stmt, "failed to close statement") rows, err := stmt.QueryContext(ctx, params...) if err != nil { @@ -128,14 +133,14 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs( } defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") - result := make(map[string]types.EventStateKeyNID, len(senderKeys)) + result := make(map[string]types.UserRoomKeyPair, len(params)) var publicKey []byte - var userNID types.EventStateKeyNID + userRoomKeyPair := types.UserRoomKeyPair{} for rows.Next() { - if err = rows.Scan(&userNID, &publicKey); err != nil { + if err = rows.Scan(&userRoomKeyPair.EventStateKeyNID, &userRoomKeyPair.RoomNID, &publicKey); err != nil { return nil, err } - result[string(publicKey)] = userNID + result[string(publicKey)] = userRoomKeyPair } return result, rows.Err() } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 3064628ca..b2578021b 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -189,7 +189,7 @@ type UserRoomKeys interface { InsertUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PrivateKey) (ed25519.PrivateKey, error) InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error) SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error) - BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys [][]byte) (map[string]types.EventStateKeyNID, error) + BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) } // StrippedEvent represents a stripped event for returning extracted content values. diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index 253f2b427..5b4e53a33 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -78,14 +78,18 @@ func TestUserRoomKeysTable(t *testing.T) { assert.Nil(t, gotKey) // query user NIDs for senderKeys - var gotKeys map[string]types.EventStateKeyNID - gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, [][]byte{key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}) + var gotKeys map[string]types.UserRoomKeyPair + query := map[types.RoomNID][]ed25519.PublicKey{ + roomNID: {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, + types.RoomNID(2): {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, // doesn't exist + } + gotKeys, err = tab.BulkSelectUserNIDs(context.Background(), txn, query) assert.NoError(t, err) assert.NotNil(t, gotKeys) - wantKeys := map[string]types.EventStateKeyNID{ - string(key.Public().(ed25519.PublicKey)): userNID, - string(key3.Public().(ed25519.PublicKey)): userNID2, + wantKeys := map[string]types.UserRoomKeyPair{ + string(key.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID}, + string(key3.Public().(ed25519.PublicKey)): {RoomNID: roomNID, EventStateKeyNID: userNID2}, } assert.Equal(t, wantKeys, gotKeys) diff --git a/roomserver/types/types.go b/roomserver/types/types.go index f57978ad5..45a3e25fc 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -44,6 +44,11 @@ type EventMetadata struct { RoomNID RoomNID } +type UserRoomKeyPair struct { + RoomNID RoomNID + EventStateKeyNID EventStateKeyNID +} + // StateSnapshotNID is a numeric ID for the state at an event. type StateSnapshotNID int64