diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 835a43b2d..0b22e87e8 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -128,7 +128,7 @@ const deleteMembershipSQL = "" + "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2" const selectRoomsWithMembershipSQL = "" + - "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = ANY($2) and forgotten = false" // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // joined to. Since this information is used to populate the user directory, we will @@ -347,10 +347,10 @@ func (s *membershipStatements) UpdateMembership( func (s *membershipStatements) SelectRoomsWithMembership( ctx context.Context, txn *sql.Tx, - userID types.EventStateKeyNID, membershipState tables.MembershipState, + userIDs []types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) - rows, err := stmt.QueryContext(ctx, membershipState, userID) + rows, err := stmt.QueryContext(ctx, membershipState, pq.Array(userIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index 202b0abc1..6bba87c01 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` +const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL` + type userRoomKeysStatements struct { insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt + selectUserRoomKeysStmt *sql.Stmt } func CreateUserRoomKeysTable(db *sql.DB) error { @@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL}, + {&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL}, }.Prepare(db) } @@ -150,3 +154,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq } return result, rows.Err() } + +func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt) + + rows, err := stmt.QueryContext(ctx, userNID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + var result []ed25519.PrivateKey + var pk ed25519.PrivateKey + + for rows.Next() { + if err = rows.Scan(&pk); err != nil { + return nil, err + } + result = append(result, pk) + } + + return result, rows.Err() +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index fc3ace6a6..92c9751d3 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1361,14 +1361,38 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership default: return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) } - stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + stateKeyNID, err := d.EventStateKeyNIDs(ctx, []string{userID}) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) } - roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState) + + // get the pseudo IDs, if any, as otherwise we don't get the correct room list + pseudoIDKeys, err := d.UserRoomKeyTable.SelectPrivateKeysForUserNID(ctx, nil, stateKeyNID[userID]) + if err != nil { + return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectPrivateKeysForUserNID: %w", err) + } + senderIDs := make([]string, len(pseudoIDKeys)) + var senderID spec.SenderID + for _, key := range pseudoIDKeys { + senderID = spec.SenderIDFromPseudoIDKey(key) + senderIDs = append(senderIDs, string(senderID)) + } + + stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, senderIDs) + if err != nil { + return nil, fmt.Errorf("GetRoomsByMembership: failed to EventStateKeyNIDs: %w", err) + } + + stateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)+1) + stateKeyNIDs = append(stateKeyNIDs, stateKeyNID[userID]) + for _, stateKeyNID := range stateKeyNIDMap { + stateKeyNIDs = append(stateKeyNIDs, stateKeyNID) + } + + roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNIDs, membershipState) if err != nil { return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 977788d50..834f908d9 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -100,7 +100,7 @@ const updateMembershipForgetRoom = "" + " WHERE room_nid = $2 AND target_nid = $3" const selectRoomsWithMembershipSQL = "" + - "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid IN ($2) and forgotten = false" // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // joined to. Since this information is used to populate the user directory, we will @@ -297,10 +297,27 @@ func (s *membershipStatements) UpdateMembership( } func (s *membershipStatements) SelectRoomsWithMembership( - ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState, + ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) - rows, err := stmt.QueryContext(ctx, membershipState, userID) + + query := strings.Replace(selectRoomsWithMembershipSQL, "($2)", sqlutil.QueryVariadicOffset(len(userIDs), 1), 1) + + var stmt *sql.Stmt + var err error + if txn != nil { + stmt, err = txn.PrepareContext(ctx, query) + } else { + stmt, err = s.db.Prepare(query) + } + if err != nil { + return nil, err + } + params := make([]any, len(userIDs)+1) + params[0] = membershipState + for i, userID := range userIDs { + params[i+1] = userID + } + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 5d6ddc9a8..56c2432b2 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -56,6 +56,8 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` +const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL` + type userRoomKeysStatements struct { db *sql.DB insertUserRoomPrivateKeyStmt *sql.Stmt @@ -63,6 +65,7 @@ type userRoomKeysStatements struct { selectUserRoomKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt //selectUserNIDsStmt *sql.Stmt //prepared at runtime + selectUserRoomKeysStmt *sql.Stmt } func CreateUserRoomKeysTable(db *sql.DB) error { @@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) { {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, + {&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL}, //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime }.Prepare(db) } @@ -165,3 +169,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq } return result, rows.Err() } + +func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) { + stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt) + + rows, err := stmt.QueryContext(ctx, userNID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + + var result []ed25519.PrivateKey + var pk ed25519.PrivateKey + + for rows.Next() { + if err = rows.Scan(&pk); err != nil { + return nil, err + } + result = append(result, pk) + } + + return result, rows.Err() +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 445c1223f..58a89b4d9 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -142,7 +142,7 @@ type Membership interface { SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error) - SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) + SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) @@ -198,6 +198,8 @@ type UserRoomKeys interface { // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // If a senderKey can't be found, it is omitted in the result. BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) + // SelectRoomNIDs selects all roomNIDs for a specific user + SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) } // StrippedEvent represents a stripped event for returning extracted content values. diff --git a/roomserver/storage/tables/membership_table_test.go b/roomserver/storage/tables/membership_table_test.go index c4524ee44..37fb4a47c 100644 --- a/roomserver/storage/tables/membership_table_test.go +++ b/roomserver/storage/tables/membership_table_test.go @@ -99,12 +99,13 @@ func TestMembershipTable(t *testing.T) { assert.Equal(t, 10, len(members)) // Get correct user - roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, userNIDs[1], tables.MembershipStateLeaveOrBan) + roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[1]}, tables.MembershipStateLeaveOrBan) + t.Logf("XXX: %v", userNIDs[1:1]) assert.NoError(t, err) assert.Equal(t, []types.RoomNID{1}, roomNIDs) // User is not joined to room - roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, userNIDs[5], tables.MembershipStateJoin) + roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[5]}, tables.MembershipStateJoin) assert.NoError(t, err) assert.Equal(t, 0, len(roomNIDs)) diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go index 2809771b4..4ca3bea7c 100644 --- a/roomserver/storage/tables/user_room_keys_table_test.go +++ b/roomserver/storage/tables/user_room_keys_table_test.go @@ -115,6 +115,11 @@ func TestUserRoomKeysTable(t *testing.T) { assert.NoError(t, err) assert.Equal(t, key4, gotPublicKey) + // query rooms for a specific user + var pks []ed25519.PrivateKey + pks, err = tab.SelectPrivateKeysForUserNID(context.Background(), txn, userNID) + assert.NoError(t, err) + assert.Equal(t, []ed25519.PrivateKey{key}, pks) return nil }) assert.NoError(t, err)