Add possibility to query all user keys; Get all joined rooms
This commit is contained in:
parent
23cd7877a1
commit
bc8e83fd28
|
@ -128,7 +128,7 @@ const deleteMembershipSQL = "" +
|
|||
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
|
||||
|
||||
const selectRoomsWithMembershipSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = ANY($2) and forgotten = false"
|
||||
|
||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||
// joined to. Since this information is used to populate the user directory, we will
|
||||
|
@ -347,10 +347,10 @@ func (s *membershipStatements) UpdateMembership(
|
|||
|
||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
||||
rows, err := stmt.QueryContext(ctx, membershipState, pq.Array(userIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
|
|||
|
||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
|
||||
|
||||
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
|
||||
|
||||
type userRoomKeysStatements struct {
|
||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||
insertUserRoomPublicKeyStmt *sql.Stmt
|
||||
selectUserRoomKeyStmt *sql.Stmt
|
||||
selectUserRoomPublicKeyStmt *sql.Stmt
|
||||
selectUserNIDsStmt *sql.Stmt
|
||||
selectUserRoomKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func CreateUserRoomKeysTable(db *sql.DB) error {
|
||||
|
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
|||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
||||
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
|
||||
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
@ -150,3 +154,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
|
|||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
|
||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, userNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
|
||||
|
||||
var result []ed25519.PrivateKey
|
||||
var pk ed25519.PrivateKey
|
||||
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&pk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, pk)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
|
|
@ -1361,14 +1361,38 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
|
|||
default:
|
||||
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
|
||||
}
|
||||
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
|
||||
stateKeyNID, err := d.EventStateKeyNIDs(ctx, []string{userID})
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
||||
}
|
||||
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
|
||||
|
||||
// get the pseudo IDs, if any, as otherwise we don't get the correct room list
|
||||
pseudoIDKeys, err := d.UserRoomKeyTable.SelectPrivateKeysForUserNID(ctx, nil, stateKeyNID[userID])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectPrivateKeysForUserNID: %w", err)
|
||||
}
|
||||
senderIDs := make([]string, len(pseudoIDKeys))
|
||||
var senderID spec.SenderID
|
||||
for _, key := range pseudoIDKeys {
|
||||
senderID = spec.SenderIDFromPseudoIDKey(key)
|
||||
senderIDs = append(senderIDs, string(senderID))
|
||||
}
|
||||
|
||||
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, senderIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRoomsByMembership: failed to EventStateKeyNIDs: %w", err)
|
||||
}
|
||||
|
||||
stateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)+1)
|
||||
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID[userID])
|
||||
for _, stateKeyNID := range stateKeyNIDMap {
|
||||
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID)
|
||||
}
|
||||
|
||||
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNIDs, membershipState)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ const updateMembershipForgetRoom = "" +
|
|||
" WHERE room_nid = $2 AND target_nid = $3"
|
||||
|
||||
const selectRoomsWithMembershipSQL = "" +
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid IN ($2) and forgotten = false"
|
||||
|
||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||
// joined to. Since this information is used to populate the user directory, we will
|
||||
|
@ -297,10 +297,27 @@ func (s *membershipStatements) UpdateMembership(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
||||
|
||||
query := strings.Replace(selectRoomsWithMembershipSQL, "($2)", sqlutil.QueryVariadicOffset(len(userIDs), 1), 1)
|
||||
|
||||
var stmt *sql.Stmt
|
||||
var err error
|
||||
if txn != nil {
|
||||
stmt, err = txn.PrepareContext(ctx, query)
|
||||
} else {
|
||||
stmt, err = s.db.Prepare(query)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
params := make([]any, len(userIDs)+1)
|
||||
params[0] = membershipState
|
||||
for i, userID := range userIDs {
|
||||
params[i+1] = userID
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -56,6 +56,8 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
|
|||
|
||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
|
||||
|
||||
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
|
||||
|
||||
type userRoomKeysStatements struct {
|
||||
db *sql.DB
|
||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||
|
@ -63,6 +65,7 @@ type userRoomKeysStatements struct {
|
|||
selectUserRoomKeyStmt *sql.Stmt
|
||||
selectUserRoomPublicKeyStmt *sql.Stmt
|
||||
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
|
||||
selectUserRoomKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func CreateUserRoomKeysTable(db *sql.DB) error {
|
||||
|
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
|||
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
||||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
||||
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
|
||||
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
@ -165,3 +169,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
|
|||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
|
||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, userNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
|
||||
|
||||
var result []ed25519.PrivateKey
|
||||
var pk ed25519.PrivateKey
|
||||
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&pk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, pk)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
|
|
@ -142,7 +142,7 @@ type Membership interface {
|
|||
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
|
||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
|
||||
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||
|
@ -198,6 +198,8 @@ type UserRoomKeys interface {
|
|||
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
|
||||
// If a senderKey can't be found, it is omitted in the result.
|
||||
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
|
||||
// SelectRoomNIDs selects all roomNIDs for a specific user
|
||||
SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error)
|
||||
}
|
||||
|
||||
// StrippedEvent represents a stripped event for returning extracted content values.
|
||||
|
|
|
@ -99,12 +99,13 @@ func TestMembershipTable(t *testing.T) {
|
|||
assert.Equal(t, 10, len(members))
|
||||
|
||||
// Get correct user
|
||||
roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, userNIDs[1], tables.MembershipStateLeaveOrBan)
|
||||
roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[1]}, tables.MembershipStateLeaveOrBan)
|
||||
t.Logf("XXX: %v", userNIDs[1:1])
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []types.RoomNID{1}, roomNIDs)
|
||||
|
||||
// User is not joined to room
|
||||
roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, userNIDs[5], tables.MembershipStateJoin)
|
||||
roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[5]}, tables.MembershipStateJoin)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(roomNIDs))
|
||||
|
||||
|
|
|
@ -115,6 +115,11 @@ func TestUserRoomKeysTable(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, key4, gotPublicKey)
|
||||
|
||||
// query rooms for a specific user
|
||||
var pks []ed25519.PrivateKey
|
||||
pks, err = tab.SelectPrivateKeysForUserNID(context.Background(), txn, userNID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []ed25519.PrivateKey{key}, pks)
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
|
Loading…
Reference in a new issue