mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-26 16:21:55 -06:00
Add possibility to query all user keys; Get all joined rooms
This commit is contained in:
parent
23cd7877a1
commit
bc8e83fd28
|
@ -128,7 +128,7 @@ const deleteMembershipSQL = "" +
|
||||||
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
|
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
const selectRoomsWithMembershipSQL = "" +
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = ANY($2) and forgotten = false"
|
||||||
|
|
||||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||||
// joined to. Since this information is used to populate the user directory, we will
|
// joined to. Since this information is used to populate the user directory, we will
|
||||||
|
@ -347,10 +347,10 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
|
|
||||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
) ([]types.RoomNID, error) {
|
) ([]types.RoomNID, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
rows, err := stmt.QueryContext(ctx, membershipState, pq.Array(userIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
|
||||||
|
|
||||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
|
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
|
||||||
|
|
||||||
|
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
|
||||||
|
|
||||||
type userRoomKeysStatements struct {
|
type userRoomKeysStatements struct {
|
||||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||||
insertUserRoomPublicKeyStmt *sql.Stmt
|
insertUserRoomPublicKeyStmt *sql.Stmt
|
||||||
selectUserRoomKeyStmt *sql.Stmt
|
selectUserRoomKeyStmt *sql.Stmt
|
||||||
selectUserRoomPublicKeyStmt *sql.Stmt
|
selectUserRoomPublicKeyStmt *sql.Stmt
|
||||||
selectUserNIDsStmt *sql.Stmt
|
selectUserNIDsStmt *sql.Stmt
|
||||||
|
selectUserRoomKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateUserRoomKeysTable(db *sql.DB) error {
|
func CreateUserRoomKeysTable(db *sql.DB) error {
|
||||||
|
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
||||||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||||
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
||||||
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
|
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
|
||||||
|
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,3 +154,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
|
||||||
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, userNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
|
||||||
|
|
||||||
|
var result []ed25519.PrivateKey
|
||||||
|
var pk ed25519.PrivateKey
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&pk); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -1361,14 +1361,38 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
|
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
|
||||||
}
|
}
|
||||||
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
|
stateKeyNID, err := d.EventStateKeyNIDs(ctx, []string{userID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
||||||
}
|
}
|
||||||
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
|
|
||||||
|
// get the pseudo IDs, if any, as otherwise we don't get the correct room list
|
||||||
|
pseudoIDKeys, err := d.UserRoomKeyTable.SelectPrivateKeysForUserNID(ctx, nil, stateKeyNID[userID])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectPrivateKeysForUserNID: %w", err)
|
||||||
|
}
|
||||||
|
senderIDs := make([]string, len(pseudoIDKeys))
|
||||||
|
var senderID spec.SenderID
|
||||||
|
for _, key := range pseudoIDKeys {
|
||||||
|
senderID = spec.SenderIDFromPseudoIDKey(key)
|
||||||
|
senderIDs = append(senderIDs, string(senderID))
|
||||||
|
}
|
||||||
|
|
||||||
|
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, senderIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: failed to EventStateKeyNIDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)+1)
|
||||||
|
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID[userID])
|
||||||
|
for _, stateKeyNID := range stateKeyNIDMap {
|
||||||
|
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNIDs, membershipState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
|
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,7 +100,7 @@ const updateMembershipForgetRoom = "" +
|
||||||
" WHERE room_nid = $2 AND target_nid = $3"
|
" WHERE room_nid = $2 AND target_nid = $3"
|
||||||
|
|
||||||
const selectRoomsWithMembershipSQL = "" +
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid IN ($2) and forgotten = false"
|
||||||
|
|
||||||
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||||
// joined to. Since this information is used to populate the user directory, we will
|
// joined to. Since this information is used to populate the user directory, we will
|
||||||
|
@ -297,10 +297,27 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
) ([]types.RoomNID, error) {
|
) ([]types.RoomNID, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
|
||||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
query := strings.Replace(selectRoomsWithMembershipSQL, "($2)", sqlutil.QueryVariadicOffset(len(userIDs), 1), 1)
|
||||||
|
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
var err error
|
||||||
|
if txn != nil {
|
||||||
|
stmt, err = txn.PrepareContext(ctx, query)
|
||||||
|
} else {
|
||||||
|
stmt, err = s.db.Prepare(query)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
params := make([]any, len(userIDs)+1)
|
||||||
|
params[0] = membershipState
|
||||||
|
for i, userID := range userIDs {
|
||||||
|
params[i+1] = userID
|
||||||
|
}
|
||||||
|
rows, err := stmt.QueryContext(ctx, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,6 +56,8 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
|
||||||
|
|
||||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
|
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
|
||||||
|
|
||||||
|
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
|
||||||
|
|
||||||
type userRoomKeysStatements struct {
|
type userRoomKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||||
|
@ -63,6 +65,7 @@ type userRoomKeysStatements struct {
|
||||||
selectUserRoomKeyStmt *sql.Stmt
|
selectUserRoomKeyStmt *sql.Stmt
|
||||||
selectUserRoomPublicKeyStmt *sql.Stmt
|
selectUserRoomPublicKeyStmt *sql.Stmt
|
||||||
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
|
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
|
||||||
|
selectUserRoomKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateUserRoomKeysTable(db *sql.DB) error {
|
func CreateUserRoomKeysTable(db *sql.DB) error {
|
||||||
|
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
||||||
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
||||||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||||
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
||||||
|
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
|
||||||
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
|
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
@ -165,3 +169,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
|
||||||
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
|
||||||
|
|
||||||
|
rows, err := stmt.QueryContext(ctx, userNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
|
||||||
|
|
||||||
|
var result []ed25519.PrivateKey
|
||||||
|
var pk ed25519.PrivateKey
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&pk); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -142,7 +142,7 @@ type Membership interface {
|
||||||
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
|
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
|
||||||
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||||
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
|
||||||
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
|
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
|
||||||
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||||
|
@ -198,6 +198,8 @@ type UserRoomKeys interface {
|
||||||
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
|
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
|
||||||
// If a senderKey can't be found, it is omitted in the result.
|
// If a senderKey can't be found, it is omitted in the result.
|
||||||
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
|
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
|
||||||
|
// SelectRoomNIDs selects all roomNIDs for a specific user
|
||||||
|
SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StrippedEvent represents a stripped event for returning extracted content values.
|
// StrippedEvent represents a stripped event for returning extracted content values.
|
||||||
|
|
|
@ -99,12 +99,13 @@ func TestMembershipTable(t *testing.T) {
|
||||||
assert.Equal(t, 10, len(members))
|
assert.Equal(t, 10, len(members))
|
||||||
|
|
||||||
// Get correct user
|
// Get correct user
|
||||||
roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, userNIDs[1], tables.MembershipStateLeaveOrBan)
|
roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[1]}, tables.MembershipStateLeaveOrBan)
|
||||||
|
t.Logf("XXX: %v", userNIDs[1:1])
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, []types.RoomNID{1}, roomNIDs)
|
assert.Equal(t, []types.RoomNID{1}, roomNIDs)
|
||||||
|
|
||||||
// User is not joined to room
|
// User is not joined to room
|
||||||
roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, userNIDs[5], tables.MembershipStateJoin)
|
roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[5]}, tables.MembershipStateJoin)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 0, len(roomNIDs))
|
assert.Equal(t, 0, len(roomNIDs))
|
||||||
|
|
||||||
|
|
|
@ -115,6 +115,11 @@ func TestUserRoomKeysTable(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, key4, gotPublicKey)
|
assert.Equal(t, key4, gotPublicKey)
|
||||||
|
|
||||||
|
// query rooms for a specific user
|
||||||
|
var pks []ed25519.PrivateKey
|
||||||
|
pks, err = tab.SelectPrivateKeysForUserNID(context.Background(), txn, userNID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []ed25519.PrivateKey{key}, pks)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
Loading…
Reference in a new issue