diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index e7b9a8500..39eee8210 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -585,6 +585,7 @@ func Setup( } return *SearchUserDirectory( req.Context(), + device, userAPI, stateAPI, cfg.Matrix.ServerName, diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index 774b0e96e..db81ffeae 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -32,6 +32,7 @@ type UserDirectoryResponse struct { func SearchUserDirectory( ctx context.Context, + device *userapi.Device, userAPI userapi.UserInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, serverName gomatrixserverlib.ServerName, @@ -81,6 +82,7 @@ func SearchUserDirectory( if len(results) <= limit { stateReq := ¤tstateAPI.QueryKnownUsersRequest{ + UserID: device.UserID, SearchString: searchString, Limit: limit - len(results), } diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go index c4f4d8357..4ebe29683 100644 --- a/currentstateserver/api/api.go +++ b/currentstateserver/api/api.go @@ -92,6 +92,7 @@ type QueryCurrentStateResponse struct { } type QueryKnownUsersRequest struct { + UserID string `json:"user_id"` SearchString string `json:"search_string"` Limit int `json:"limit"` } diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go index ff4093034..dc2554121 100644 --- a/currentstateserver/internal/api.go +++ b/currentstateserver/internal/api.go @@ -51,7 +51,7 @@ func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *ap } func (a *CurrentStateInternalAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { - users, err := a.DB.GetKnownUsers(ctx, req.SearchString, req.Limit) + users, err := a.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit) if err != nil { return err } diff --git a/currentstateserver/storage/interface.go b/currentstateserver/storage/interface.go index 81b73ee40..5a754b9ea 100644 --- a/currentstateserver/storage/interface.go +++ b/currentstateserver/storage/interface.go @@ -39,6 +39,6 @@ type Database interface { RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) - // GetKnownUsers searches all users that we know about. - GetKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) + // GetKnownUsers searches all users that userID knows about. + GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) } diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go index e95f96119..c5414e449 100644 --- a/currentstateserver/storage/postgres/current_room_state_table.go +++ b/currentstateserver/storage/postgres/current_room_state_table.go @@ -83,7 +83,9 @@ const selectJoinedUsersSetForRoomsSQL = "" + " type = 'm.room.member' and content_value = 'join' GROUP BY state_key" const selectKnownUsersSQL = "" + - "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE type = 'm.room.member' AND state_key LIKE $1 LIMIT $2" + "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY(" + + " SELECT DISTINCT room_id FROM currentstate_current_room_state WHERE state_key=$1 AND TYPE='m.room.member' AND content_value='join'" + + ") AND TYPE='m.room.member' AND content_value='join' AND state_key LIKE $2 LIMIT $3" type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt @@ -304,8 +306,8 @@ func (s *currentRoomStateStatements) SelectBulkStateContent( return strippedEvents, rows.Err() } -func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err == sql.ErrNoRows { return nil, nil } diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go index 40cc94549..bd4329a7d 100644 --- a/currentstateserver/storage/shared/storage.go +++ b/currentstateserver/storage/shared/storage.go @@ -90,6 +90,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs) } -func (d *Database) GetKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) { - return d.CurrentRoomState.SelectKnownUsers(ctx, searchString, limit) +func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { + return d.CurrentRoomState.SelectKnownUsers(ctx, userID, searchString, limit) } diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index d3bc86dde..418b4079b 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -71,7 +71,9 @@ const selectJoinedUsersSetForRoomsSQL = "" + "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key" const selectKnownUsersSQL = "" + - "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE type = 'm.room.member' AND state_key LIKE $1 LIMIT $2" + "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN (" + + " SELECT DISTINCT room_id FROM currentstate_current_room_state WHERE state_key=$1 AND TYPE='m.room.member' AND content_value='join'" + + ") AND TYPE='m.room.member' AND content_value='join' AND state_key LIKE $2 LIMIT $3" type currentRoomStateStatements struct { db *sql.DB @@ -324,8 +326,8 @@ func (s *currentRoomStateStatements) SelectBulkStateContent( return strippedEvents, rows.Err() } -func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err == sql.ErrNoRows { return nil, nil } diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go index 817ee3885..6290e7b3d 100644 --- a/currentstateserver/storage/tables/interface.go +++ b/currentstateserver/storage/tables/interface.go @@ -39,8 +39,8 @@ type CurrentRoomState interface { // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) - // SelectKnownUsers searches all users that we know about. - SelectKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) + // SelectKnownUsers searches all users that userID knows about. + SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) } // StrippedEvent represents a stripped event for returning extracted content values.