Add function to the sync API storage package for filtering shared users

This commit is contained in:
Neil Alexander 2022-07-15 13:09:56 +01:00
parent 69c86295f7
commit 4bf483a7f5
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
6 changed files with 73 additions and 1 deletions

View file

@ -215,6 +215,8 @@ func TrackChangedUsers(
return changed, left, nil
}
// filterSharedUsers takes a list of remote users whose keys have changed and filters
// it down to include only users who the requesting user shares a room with.
func filterSharedUsers(
ctx context.Context, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, usersWithChangedKeys []string,
) (map[string]int, []string) {

View file

@ -54,6 +54,8 @@ type Database interface {
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
// AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room.
AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
// SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)

View file

@ -107,6 +107,11 @@ const selectEventsWithEventIDsSQL = "" +
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)"
const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND state_key = ANY($2) AND membership='join';"
type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
@ -118,6 +123,7 @@ type currentRoomStateStatements struct {
selectJoinedUsersInRoomStmt *sql.Stmt
selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt
selectSharedUsersStmt *sql.Stmt
}
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -156,6 +162,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err
}
if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil {
return nil, err
}
return s, nil
}
@ -379,3 +388,24 @@ func (s *currentRoomStateStatements) SelectStateEvent(
}
return &ev, err
}
func (s *currentRoomStateStatements) SelectSharedUsers(
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, otherUserIDs)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
var stateKey string
result := make([]string, 0, len(otherUserIDs))
for rows.Next() {
if err := rows.Scan(&stateKey); err != nil {
return nil, err
}
result = append(result, stateKey)
}
return result, rows.Err()
}

View file

@ -176,6 +176,10 @@ func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]t
return d.Peeks.SelectPeekingDevices(ctx)
}
func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
return d.CurrentRoomState.SelectSharedUsers(ctx, nil, userID, otherUserIDs)
}
func (d *Database) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {

View file

@ -91,6 +91,11 @@ const selectEventsWithEventIDsSQL = "" +
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
") AND state_key IN ($2) AND membership='join';"
type currentRoomStateStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
@ -100,8 +105,9 @@ type currentRoomStateStatements struct {
selectRoomIDsWithMembershipStmt *sql.Stmt
selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
//selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
selectStateEventStmt *sql.Stmt
//selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
}
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
@ -396,3 +402,29 @@ func (s *currentRoomStateStatements) SelectStateEvent(
}
return &ev, err
}
func (s *currentRoomStateStatements) SelectSharedUsers(
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
) ([]string, error) {
query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
stmt, err := s.db.Prepare(query)
if err != nil {
return nil, fmt.Errorf("SelectSharedUsers s.db.Prepare: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectSharedUsers: stmt.close() failed")
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, userID, otherUserIDs)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
var stateKey string
result := make([]string, 0, len(otherUserIDs))
for rows.Next() {
if err := rows.Scan(&stateKey); err != nil {
return nil, err
}
result = append(result, stateKey)
}
return result, rows.Err()
}

View file

@ -104,6 +104,8 @@ type CurrentRoomState interface {
SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
// SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
}
// BackwardsExtremities keeps track of backwards extremities for a room.