From aad94989becb52ef0890a7dfd2b0e895b4bd3abe Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 27 Jul 2020 15:29:48 +0100 Subject: [PATCH] User directory --- clientapi/auth/authtypes/profile.go | 13 +- clientapi/routing/routing.go | 20 ++++ clientapi/routing/userdirectory.go | 113 ++++++++++++++++++ currentstateserver/api/api.go | 12 ++ currentstateserver/internal/api.go | 14 +++ currentstateserver/inthttp/client.go | 11 ++ currentstateserver/inthttp/server.go | 13 ++ currentstateserver/storage/interface.go | 2 + .../postgres/current_room_state_table.go | 27 +++++ currentstateserver/storage/shared/storage.go | 4 + .../sqlite3/current_room_state_table.go | 27 +++++ .../storage/tables/interface.go | 2 + userapi/api/api.go | 16 +++ userapi/internal/api.go | 9 ++ userapi/inthttp/client.go | 19 ++- userapi/inthttp/server.go | 13 ++ userapi/storage/accounts/interface.go | 1 + .../accounts/postgres/profile_table.go | 26 ++++ userapi/storage/accounts/postgres/storage.go | 7 ++ .../storage/accounts/sqlite3/profile_table.go | 26 ++++ userapi/storage/accounts/sqlite3/storage.go | 7 ++ 21 files changed, 374 insertions(+), 8 deletions(-) create mode 100644 clientapi/routing/userdirectory.go diff --git a/clientapi/auth/authtypes/profile.go b/clientapi/auth/authtypes/profile.go index 0bc49658b..902850bc0 100644 --- a/clientapi/auth/authtypes/profile.go +++ b/clientapi/auth/authtypes/profile.go @@ -16,7 +16,14 @@ package authtypes // Profile represents the profile for a Matrix account. type Profile struct { - Localpart string - DisplayName string - AvatarURL string + Localpart string `json:"local_part"` + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` +} + +// FullyQualifiedProfile represents the profile for a Matrix account. +type FullyQualifiedProfile struct { + UserID string `json:"user_id"` + DisplayName string `json:"display_name"` + AvatarURL string `json:"avatar_url"` } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 311f64d1b..e7b9a8500 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -574,6 +574,26 @@ func Setup( }), ).Methods(http.MethodGet) + r0mux.Handle("/user_directory/search", + httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + postContent := struct { + SearchString string `json:"search_term"` + Limit int `json:"limit"` + }{} + if err := json.NewDecoder(req.Body).Decode(&postContent); err != nil { + return util.ErrorResponse(err) + } + return *SearchUserDirectory( + req.Context(), + userAPI, + stateAPI, + cfg.Matrix.ServerName, + postContent.SearchString, + postContent.Limit, + ) + }), + ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/rooms/{roomID}/members", httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go new file mode 100644 index 000000000..774b0e96e --- /dev/null +++ b/clientapi/routing/userdirectory.go @@ -0,0 +1,113 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +type UserDirectoryResponse struct { + Results []authtypes.FullyQualifiedProfile `json:"results"` + Limited bool `json:"limited"` +} + +func SearchUserDirectory( + ctx context.Context, + userAPI userapi.UserInternalAPI, + stateAPI currentstateAPI.CurrentStateInternalAPI, + serverName gomatrixserverlib.ServerName, + searchString string, + limit int, +) *util.JSONResponse { + if limit < 10 { + limit = 10 + } + + results := map[string]authtypes.FullyQualifiedProfile{} + response := &UserDirectoryResponse{ + Results: []authtypes.FullyQualifiedProfile{}, + Limited: false, + } + + // First start searching local users. + + userReq := &userapi.QuerySearchProfilesRequest{ + SearchString: searchString, + Limit: limit, + } + userRes := &userapi.QuerySearchProfilesResponse{} + if err := userAPI.QuerySearchProfiles(ctx, userReq, userRes); err != nil { + errRes := util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err)) + return &errRes + } + + for _, user := range userRes.Profiles { + if len(results) == limit { + response.Limited = true + break + } + + userID := fmt.Sprintf("@%s:%s", user.Localpart, serverName) + if _, ok := results[userID]; !ok { + results[userID] = authtypes.FullyQualifiedProfile{ + UserID: userID, + DisplayName: user.DisplayName, + AvatarURL: user.AvatarURL, + } + } + } + + // Then, if we have enough room left in the response, + // start searching for known users from joined rooms. + + if len(results) <= limit { + stateReq := ¤tstateAPI.QueryKnownUsersRequest{ + SearchString: searchString, + Limit: limit - len(results), + } + stateRes := ¤tstateAPI.QueryKnownUsersResponse{} + if err := stateAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil { + errRes := util.ErrorResponse(fmt.Errorf("stateAPI.QueryKnownUsers: %w", err)) + return &errRes + } + + for _, user := range stateRes.Users { + if len(results) == limit { + response.Limited = true + break + } + + if _, ok := results[user.UserID]; !ok { + results[user.UserID] = user + } + } + } + + for _, result := range results { + response.Results = append(response.Results, result) + } + + return &util.JSONResponse{ + Code: 200, + JSON: response, + } +} diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go index b778acb21..c4f4d8357 100644 --- a/currentstateserver/api/api.go +++ b/currentstateserver/api/api.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -33,6 +34,8 @@ type CurrentStateInternalAPI interface { QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error + // QueryKnownUsers returns a list of users that we know about from our joined rooms. + QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error } type QuerySharedUsersRequest struct { @@ -88,6 +91,15 @@ type QueryCurrentStateResponse struct { StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent } +type QueryKnownUsersRequest struct { + SearchString string `json:"search_string"` + Limit int `json:"limit"` +} + +type QueryKnownUsersResponse struct { + Users []authtypes.FullyQualifiedProfile `json:"profiles"` +} + // MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go index c581c524c..ff4093034 100644 --- a/currentstateserver/internal/api.go +++ b/currentstateserver/internal/api.go @@ -17,6 +17,7 @@ package internal import ( "context" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/currentstateserver/storage" "github.com/matrix-org/gomatrixserverlib" @@ -49,6 +50,19 @@ func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *ap return nil } +func (a *CurrentStateInternalAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { + users, err := a.DB.GetKnownUsers(ctx, req.SearchString, req.Limit) + if err != nil { + return err + } + for _, user := range users { + res.Users = append(res.Users, authtypes.FullyQualifiedProfile{ + UserID: user, + }) + } + return nil +} + func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) if err != nil { diff --git a/currentstateserver/inthttp/client.go b/currentstateserver/inthttp/client.go index cce881fff..37d289eaf 100644 --- a/currentstateserver/inthttp/client.go +++ b/currentstateserver/inthttp/client.go @@ -30,6 +30,7 @@ const ( QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser" QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent" QuerySharedUsersPath = "/currentstateserver/querySharedUsers" + QueryKnownUsersPath = "/currentstateserver/queryKnownUsers" ) // NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API. @@ -97,3 +98,13 @@ func (h *httpCurrentStateInternalAPI) QuerySharedUsers( apiURL := h.apiURL + QuerySharedUsersPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpCurrentStateInternalAPI) QueryKnownUsers( + ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers") + defer span.Finish() + + apiURL := h.apiURL + QueryKnownUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/currentstateserver/inthttp/server.go b/currentstateserver/inthttp/server.go index f4e93dcdf..aee900e06 100644 --- a/currentstateserver/inthttp/server.go +++ b/currentstateserver/inthttp/server.go @@ -77,4 +77,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QuerySharedUsersPath, + httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { + request := api.QueryKnownUsersRequest{} + response := api.QueryKnownUsersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.QueryKnownUsers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/currentstateserver/storage/interface.go b/currentstateserver/storage/interface.go index 8deaa3484..81b73ee40 100644 --- a/currentstateserver/storage/interface.go +++ b/currentstateserver/storage/interface.go @@ -39,4 +39,6 @@ type Database interface { RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // GetKnownUsers searches all users that we know about. + GetKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) } diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go index 294f757cf..e95f96119 100644 --- a/currentstateserver/storage/postgres/current_room_state_table.go +++ b/currentstateserver/storage/postgres/current_room_state_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "github.com/lib/pq" "github.com/matrix-org/dendrite/currentstateserver/storage/tables" @@ -81,6 +82,9 @@ const selectJoinedUsersSetForRoomsSQL = "" + "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" + " type = 'm.room.member' and content_value = 'join' GROUP BY state_key" +const selectKnownUsersSQL = "" + + "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE type = 'm.room.member' AND state_key LIKE $1 LIMIT $2" + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -90,6 +94,7 @@ type currentRoomStateStatements struct { selectBulkStateContentStmt *sql.Stmt selectBulkStateContentWildStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -122,6 +127,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { return nil, err } + if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil { + return nil, err + } return s, nil } @@ -295,3 +303,22 @@ func (s *currentRoomStateStatements) SelectBulkStateContent( } return strippedEvents, rows.Err() } + +func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + result := []string{} + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go index dac38790d..40cc94549 100644 --- a/currentstateserver/storage/shared/storage.go +++ b/currentstateserver/storage/shared/storage.go @@ -89,3 +89,7 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs) } + +func (d *Database) GetKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) { + return d.CurrentRoomState.SelectKnownUsers(ctx, searchString, limit) +} diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index 5706fa35c..d3bc86dde 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "strings" "github.com/matrix-org/dendrite/currentstateserver/storage/tables" @@ -69,6 +70,9 @@ const selectBulkStateContentWildSQL = "" + const selectJoinedUsersSetForRoomsSQL = "" + "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key" +const selectKnownUsersSQL = "" + + "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE type = 'm.room.member' AND state_key LIKE $1 LIMIT $2" + type currentRoomStateStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -77,6 +81,7 @@ type currentRoomStateStatements struct { selectRoomIDsWithMembershipStmt *sql.Stmt selectStateEventStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -103,6 +108,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { return nil, err } + if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil { + return nil, err + } return s, nil } @@ -315,3 +323,22 @@ func (s *currentRoomStateStatements) SelectBulkStateContent( } return strippedEvents, rows.Err() } + +func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + result := []string{} + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go index 121bf4fdf..817ee3885 100644 --- a/currentstateserver/storage/tables/interface.go +++ b/currentstateserver/storage/tables/interface.go @@ -39,6 +39,8 @@ type CurrentRoomState interface { // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // SelectKnownUsers searches all users that we know about. + SelectKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) } // StrippedEvent represents a stripped event for returning extracted content values. diff --git a/userapi/api/api.go b/userapi/api/api.go index bd0773f87..5791403ff 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -31,6 +32,7 @@ type UserInternalAPI interface { QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error + QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error } // InputAccountDataRequest is the request for InputAccountData @@ -112,6 +114,20 @@ type QueryProfileResponse struct { AvatarURL string } +// QuerySearchProfilesRequest is the request for QueryProfile +type QuerySearchProfilesRequest struct { + // The search string to match + SearchString string + // How many results to return + Limit int +} + +// QuerySearchProfilesResponse is the response for QuerySearchProfilesRequest +type QuerySearchProfilesResponse struct { + // Profiles matching the search + Profiles []authtypes.Profile +} + // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { AccountType AccountType // Required: whether this is a guest or user account diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 2de8f9607..5b1541967 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -125,6 +125,15 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil return nil } +func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { + profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit) + if err != nil { + return err + } + res.Profiles = profiles + return nil +} + func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index b2b42823f..3e1ac0662 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -31,11 +31,12 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" - QueryProfilePath = "/userapi/queryProfile" - QueryAccessTokenPath = "/userapi/queryAccessToken" - QueryDevicesPath = "/userapi/queryDevices" - QueryAccountDataPath = "/userapi/queryAccountData" - QueryDeviceInfosPath = "/userapi/queryDeviceInfos" + QueryProfilePath = "/userapi/queryProfile" + QueryAccessTokenPath = "/userapi/queryAccessToken" + QueryDevicesPath = "/userapi/queryDevices" + QueryAccountDataPath = "/userapi/queryAccountData" + QueryDeviceInfosPath = "/userapi/queryDeviceInfos" + QuerySearchProfilesPath = "/userapi/querySearchProfiles" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -141,3 +142,11 @@ func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.Que apiURL := h.apiURL + QueryAccountDataPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles") + defer span.Finish() + + apiURL := h.apiURL + QuerySearchProfilesPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index d8e151ad4..d29f4d442 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -117,4 +117,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryDeviceInfosPath, + httputil.MakeInternalAPI("querySearchProfiles", func(req *http.Request) util.JSONResponse { + request := api.QuerySearchProfilesRequest{} + response := api.QuerySearchProfilesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QuerySearchProfiles(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 6f6caf111..86b91e603 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -49,6 +49,7 @@ type Database interface { GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go index d2cbeb8e6..a26221f9e 100644 --- a/userapi/storage/accounts/postgres/profile_table.go +++ b/userapi/storage/accounts/postgres/profile_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -45,11 +46,15 @@ const setAvatarURLSQL = "" + const setDisplayNameSQL = "" + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" +const selectProfilesBySearchSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + type profilesStatements struct { insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt setDisplayNameStmt *sql.Stmt + selectProfilesBySearchStmt *sql.Stmt } func (s *profilesStatements) prepare(db *sql.DB) (err error) { @@ -69,6 +74,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { return } + if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { + return + } return } @@ -105,3 +113,21 @@ func (s *profilesStatements) setDisplayName( _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } + +func (s *profilesStatements) selectProfilesBySearch( + ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + var profiles []authtypes.Profile + rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + for rows.Next() { + var profile authtypes.Profile + if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + return nil, err + } + profiles = append(profiles, profile) + } + return profiles, nil +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index c76b92f10..f56fb6d81 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -298,3 +298,10 @@ func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, ) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } + +// SearchProfiles returns all profiles where the provided localpart or display name +// match any part of the profiles in the database. +func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + return d.profiles.selectProfilesBySearch(ctx, searchString, limit) +} diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index 68cea516d..b7576e688 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -46,6 +47,9 @@ const setAvatarURLSQL = "" + const setDisplayNameSQL = "" + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" +const selectProfilesBySearchSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + type profilesStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -53,6 +57,7 @@ type profilesStatements struct { selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt setDisplayNameStmt *sql.Stmt + selectProfilesBySearchStmt *sql.Stmt } func (s *profilesStatements) prepare(db *sql.DB) (err error) { @@ -74,6 +79,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { return } + if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { + return + } return } @@ -112,3 +120,21 @@ func (s *profilesStatements) setDisplayName( _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } + +func (s *profilesStatements) selectProfilesBySearch( + ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + var profiles []authtypes.Profile + rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + for rows.Next() { + var profile authtypes.Profile + if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + return nil, err + } + profiles = append(profiles, profile) + } + return profiles, nil +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 2d09090fc..722390148 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -343,3 +343,10 @@ func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, ) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } + +// SearchProfiles returns all profiles where the provided localpart or display name +// match any part of the profiles in the database. +func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + return d.profiles.selectProfilesBySearch(ctx, searchString, limit) +}