Add possibility to query all user keys; Get all joined rooms

This commit is contained in:
Till Faelligen 2023-06-29 11:33:28 +02:00
parent 23cd7877a1
commit bc8e83fd28
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
8 changed files with 113 additions and 12 deletions

View file

@ -128,7 +128,7 @@ const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = ANY($2) and forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will
@ -347,10 +347,10 @@ func (s *membershipStatements) UpdateMembership(
func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, membershipState tables.MembershipState,
userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
rows, err := stmt.QueryContext(ctx, membershipState, pq.Array(userIDs))
if err != nil {
return nil, err
}

View file

@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt
selectUserRoomKeysStmt *sql.Stmt
}
func CreateUserRoomKeysTable(db *sql.DB) error {
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
}.Prepare(db)
}
@ -150,3 +154,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
}
return result, rows.Err()
}
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
rows, err := stmt.QueryContext(ctx, userNID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
var result []ed25519.PrivateKey
var pk ed25519.PrivateKey
for rows.Next() {
if err = rows.Scan(&pk); err != nil {
return nil, err
}
result = append(result, pk)
}
return result, rows.Err()
}

View file

@ -1361,14 +1361,38 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
default:
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
}
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
stateKeyNID, err := d.EventStateKeyNIDs(ctx, []string{userID})
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
}
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
// get the pseudo IDs, if any, as otherwise we don't get the correct room list
pseudoIDKeys, err := d.UserRoomKeyTable.SelectPrivateKeysForUserNID(ctx, nil, stateKeyNID[userID])
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectPrivateKeysForUserNID: %w", err)
}
senderIDs := make([]string, len(pseudoIDKeys))
var senderID spec.SenderID
for _, key := range pseudoIDKeys {
senderID = spec.SenderIDFromPseudoIDKey(key)
senderIDs = append(senderIDs, string(senderID))
}
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, senderIDs)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to EventStateKeyNIDs: %w", err)
}
stateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)+1)
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID[userID])
for _, stateKeyNID := range stateKeyNIDMap {
stateKeyNIDs = append(stateKeyNIDs, stateKeyNID)
}
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNIDs, membershipState)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
}

View file

@ -100,7 +100,7 @@ const updateMembershipForgetRoom = "" +
" WHERE room_nid = $2 AND target_nid = $3"
const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid IN ($2) and forgotten = false"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will
@ -297,10 +297,27 @@ func (s *membershipStatements) UpdateMembership(
}
func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
query := strings.Replace(selectRoomsWithMembershipSQL, "($2)", sqlutil.QueryVariadicOffset(len(userIDs), 1), 1)
var stmt *sql.Stmt
var err error
if txn != nil {
stmt, err = txn.PrepareContext(ctx, query)
} else {
stmt, err = s.db.Prepare(query)
}
if err != nil {
return nil, err
}
params := make([]any, len(userIDs)+1)
params[0] = membershipState
for i, userID := range userIDs {
params[i+1] = userID
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}

View file

@ -56,6 +56,8 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
const selectUserRoomKeysSQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND pseudo_id_key IS NOT NULL`
type userRoomKeysStatements struct {
db *sql.DB
insertUserRoomPrivateKeyStmt *sql.Stmt
@ -63,6 +65,7 @@ type userRoomKeysStatements struct {
selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
selectUserRoomKeysStmt *sql.Stmt
}
func CreateUserRoomKeysTable(db *sql.DB) error {
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserRoomKeysStmt, selectUserRoomKeysSQL},
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db)
}
@ -165,3 +169,25 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
}
return result, rows.Err()
}
func (s *userRoomKeysStatements) SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomKeysStmt)
rows, err := stmt.QueryContext(ctx, userNID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
var result []ed25519.PrivateKey
var pk ed25519.PrivateKey
for rows.Next() {
if err = rows.Scan(&pk); err != nil {
return nil, err
}
result = append(result, pk)
}
return result, rows.Err()
}

View file

@ -142,7 +142,7 @@ type Membership interface {
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userIDs []types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
@ -198,6 +198,8 @@ type UserRoomKeys interface {
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
// If a senderKey can't be found, it is omitted in the result.
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
// SelectRoomNIDs selects all roomNIDs for a specific user
SelectPrivateKeysForUserNID(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) ([]ed25519.PrivateKey, error)
}
// StrippedEvent represents a stripped event for returning extracted content values.

View file

@ -99,12 +99,13 @@ func TestMembershipTable(t *testing.T) {
assert.Equal(t, 10, len(members))
// Get correct user
roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, userNIDs[1], tables.MembershipStateLeaveOrBan)
roomNIDs, err := tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[1]}, tables.MembershipStateLeaveOrBan)
t.Logf("XXX: %v", userNIDs[1:1])
assert.NoError(t, err)
assert.Equal(t, []types.RoomNID{1}, roomNIDs)
// User is not joined to room
roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, userNIDs[5], tables.MembershipStateJoin)
roomNIDs, err = tab.SelectRoomsWithMembership(ctx, nil, []types.EventStateKeyNID{userNIDs[5]}, tables.MembershipStateJoin)
assert.NoError(t, err)
assert.Equal(t, 0, len(roomNIDs))

View file

@ -115,6 +115,11 @@ func TestUserRoomKeysTable(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, key4, gotPublicKey)
// query rooms for a specific user
var pks []ed25519.PrivateKey
pks, err = tab.SelectPrivateKeysForUserNID(context.Background(), txn, userNID)
assert.NoError(t, err)
assert.Equal(t, []ed25519.PrivateKey{key}, pks)
return nil
})
assert.NoError(t, err)