User directory

This commit is contained in:
Neil Alexander 2020-07-27 15:29:48 +01:00
parent 83f038e12b
commit aad94989be
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
21 changed files with 374 additions and 8 deletions

View file

@ -16,7 +16,14 @@ package authtypes
// Profile represents the profile for a Matrix account. // Profile represents the profile for a Matrix account.
type Profile struct { type Profile struct {
Localpart string Localpart string `json:"local_part"`
DisplayName string DisplayName string `json:"display_name"`
AvatarURL string AvatarURL string `json:"avatar_url"`
}
// FullyQualifiedProfile represents the profile for a Matrix account.
type FullyQualifiedProfile struct {
UserID string `json:"user_id"`
DisplayName string `json:"display_name"`
AvatarURL string `json:"avatar_url"`
} }

View file

@ -574,6 +574,26 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
r0mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
postContent := struct {
SearchString string `json:"search_term"`
Limit int `json:"limit"`
}{}
if err := json.NewDecoder(req.Body).Decode(&postContent); err != nil {
return util.ErrorResponse(err)
}
return *SearchUserDirectory(
req.Context(),
userAPI,
stateAPI,
cfg.Matrix.ServerName,
postContent.SearchString,
postContent.Limit,
)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/members", r0mux.Handle("/rooms/{roomID}/members",
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) vars, err := httputil.URLDecodeMapValues(mux.Vars(req))

View file

@ -0,0 +1,113 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type UserDirectoryResponse struct {
Results []authtypes.FullyQualifiedProfile `json:"results"`
Limited bool `json:"limited"`
}
func SearchUserDirectory(
ctx context.Context,
userAPI userapi.UserInternalAPI,
stateAPI currentstateAPI.CurrentStateInternalAPI,
serverName gomatrixserverlib.ServerName,
searchString string,
limit int,
) *util.JSONResponse {
if limit < 10 {
limit = 10
}
results := map[string]authtypes.FullyQualifiedProfile{}
response := &UserDirectoryResponse{
Results: []authtypes.FullyQualifiedProfile{},
Limited: false,
}
// First start searching local users.
userReq := &userapi.QuerySearchProfilesRequest{
SearchString: searchString,
Limit: limit,
}
userRes := &userapi.QuerySearchProfilesResponse{}
if err := userAPI.QuerySearchProfiles(ctx, userReq, userRes); err != nil {
errRes := util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err))
return &errRes
}
for _, user := range userRes.Profiles {
if len(results) == limit {
response.Limited = true
break
}
userID := fmt.Sprintf("@%s:%s", user.Localpart, serverName)
if _, ok := results[userID]; !ok {
results[userID] = authtypes.FullyQualifiedProfile{
UserID: userID,
DisplayName: user.DisplayName,
AvatarURL: user.AvatarURL,
}
}
}
// Then, if we have enough room left in the response,
// start searching for known users from joined rooms.
if len(results) <= limit {
stateReq := &currentstateAPI.QueryKnownUsersRequest{
SearchString: searchString,
Limit: limit - len(results),
}
stateRes := &currentstateAPI.QueryKnownUsersResponse{}
if err := stateAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil {
errRes := util.ErrorResponse(fmt.Errorf("stateAPI.QueryKnownUsers: %w", err))
return &errRes
}
for _, user := range stateRes.Users {
if len(results) == limit {
response.Limited = true
break
}
if _, ok := results[user.UserID]; !ok {
results[user.UserID] = user
}
}
}
for _, result := range results {
response.Results = append(response.Results, result)
}
return &util.JSONResponse{
Code: 200,
JSON: response,
}
}

View file

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -33,6 +34,8 @@ type CurrentStateInternalAPI interface {
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
} }
type QuerySharedUsersRequest struct { type QuerySharedUsersRequest struct {
@ -88,6 +91,15 @@ type QueryCurrentStateResponse struct {
StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent
} }
type QueryKnownUsersRequest struct {
SearchString string `json:"search_string"`
Limit int `json:"limit"`
}
type QueryKnownUsersResponse struct {
Users []authtypes.FullyQualifiedProfile `json:"profiles"`
}
// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. // MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) {
se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents))

View file

@ -17,6 +17,7 @@ package internal
import ( import (
"context" "context"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/currentstateserver/storage" "github.com/matrix-org/dendrite/currentstateserver/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -49,6 +50,19 @@ func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *ap
return nil return nil
} }
func (a *CurrentStateInternalAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
users, err := a.DB.GetKnownUsers(ctx, req.SearchString, req.Limit)
if err != nil {
return err
}
for _, user := range users {
res.Users = append(res.Users, authtypes.FullyQualifiedProfile{
UserID: user,
})
}
return nil
}
func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards)
if err != nil { if err != nil {

View file

@ -30,6 +30,7 @@ const (
QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser" QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser"
QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent" QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent"
QuerySharedUsersPath = "/currentstateserver/querySharedUsers" QuerySharedUsersPath = "/currentstateserver/querySharedUsers"
QueryKnownUsersPath = "/currentstateserver/queryKnownUsers"
) )
// NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API. // NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API.
@ -97,3 +98,13 @@ func (h *httpCurrentStateInternalAPI) QuerySharedUsers(
apiURL := h.apiURL + QuerySharedUsersPath apiURL := h.apiURL + QuerySharedUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpCurrentStateInternalAPI) QueryKnownUsers(
ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers")
defer span.Finish()
apiURL := h.apiURL + QueryKnownUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -77,4 +77,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QuerySharedUsersPath,
httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse {
request := api.QueryKnownUsersRequest{}
response := api.QueryKnownUsersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QueryKnownUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -39,4 +39,6 @@ type Database interface {
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
// GetKnownUsers searches all users that we know about.
GetKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error)
} }

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
@ -81,6 +82,9 @@ const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" + "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" +
" type = 'm.room.member' and content_value = 'join' GROUP BY state_key" " type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
const selectKnownUsersSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE type = 'm.room.member' AND state_key LIKE $1 LIMIT $2"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
@ -90,6 +94,7 @@ type currentRoomStateStatements struct {
selectBulkStateContentStmt *sql.Stmt selectBulkStateContentStmt *sql.Stmt
selectBulkStateContentWildStmt *sql.Stmt selectBulkStateContentWildStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
} }
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -122,6 +127,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err return nil, err
} }
if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -295,3 +303,22 @@ func (s *currentRoomStateStatements) SelectBulkStateContent(
} }
return strippedEvents, rows.Err() return strippedEvents, rows.Err()
} }
func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
result := []string{}
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View file

@ -89,3 +89,7 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs) return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
} }
func (d *Database) GetKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) {
return d.CurrentRoomState.SelectKnownUsers(ctx, searchString, limit)
}

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
@ -69,6 +70,9 @@ const selectBulkStateContentWildSQL = "" +
const selectJoinedUsersSetForRoomsSQL = "" + const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key" "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
const selectKnownUsersSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE type = 'm.room.member' AND state_key LIKE $1 LIMIT $2"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter writer *sqlutil.TransactionWriter
@ -77,6 +81,7 @@ type currentRoomStateStatements struct {
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -103,6 +108,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err return nil, err
} }
if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -315,3 +323,22 @@ func (s *currentRoomStateStatements) SelectBulkStateContent(
} }
return strippedEvents, rows.Err() return strippedEvents, rows.Err()
} }
func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
result := []string{}
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View file

@ -39,6 +39,8 @@ type CurrentRoomState interface {
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined. // counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
// SelectKnownUsers searches all users that we know about.
SelectKnownUsers(ctx context.Context, searchString string, limit int) ([]string, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -31,6 +32,7 @@ type UserInternalAPI interface {
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
} }
// InputAccountDataRequest is the request for InputAccountData // InputAccountDataRequest is the request for InputAccountData
@ -112,6 +114,20 @@ type QueryProfileResponse struct {
AvatarURL string AvatarURL string
} }
// QuerySearchProfilesRequest is the request for QueryProfile
type QuerySearchProfilesRequest struct {
// The search string to match
SearchString string
// How many results to return
Limit int
}
// QuerySearchProfilesResponse is the response for QuerySearchProfilesRequest
type QuerySearchProfilesResponse struct {
// Profiles matching the search
Profiles []authtypes.Profile
}
// PerformAccountCreationRequest is the request for PerformAccountCreation // PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformAccountCreationRequest struct { type PerformAccountCreationRequest struct {
AccountType AccountType // Required: whether this is a guest or user account AccountType AccountType // Required: whether this is a guest or user account

View file

@ -125,6 +125,15 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
return nil return nil
} }
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil {
return err
}
res.Profiles = profiles
return nil
}
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil { if err != nil {

View file

@ -31,11 +31,12 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation" PerformAccountCreationPath = "/userapi/performAccountCreation"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
QueryAccessTokenPath = "/userapi/queryAccessToken" QueryAccessTokenPath = "/userapi/queryAccessToken"
QueryDevicesPath = "/userapi/queryDevices" QueryDevicesPath = "/userapi/queryDevices"
QueryAccountDataPath = "/userapi/queryAccountData" QueryAccountDataPath = "/userapi/queryAccountData"
QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -141,3 +142,11 @@ func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.Que
apiURL := h.apiURL + QueryAccountDataPath apiURL := h.apiURL + QueryAccountDataPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles")
defer span.Finish()
apiURL := h.apiURL + QuerySearchProfilesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -117,4 +117,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryDeviceInfosPath,
httputil.MakeInternalAPI("querySearchProfiles", func(req *http.Request) util.JSONResponse {
request := api.QuerySearchProfilesRequest{}
response := api.QuerySearchProfilesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QuerySearchProfiles(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -49,6 +49,7 @@ type Database interface {
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -17,6 +17,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -45,11 +46,15 @@ const setAvatarURLSQL = "" +
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
insertProfileStmt *sql.Stmt insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt setAvatarURLStmt *sql.Stmt
setDisplayNameStmt *sql.Stmt setDisplayNameStmt *sql.Stmt
selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func (s *profilesStatements) prepare(db *sql.DB) (err error) {
@ -69,6 +74,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
return return
} }
if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil {
return
}
return return
} }
@ -105,3 +113,21 @@ func (s *profilesStatements) setDisplayName(
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return return
} }
func (s *profilesStatements) selectProfilesBySearch(
ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile
rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
for rows.Next() {
var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
profiles = append(profiles, profile)
}
return profiles, nil
}

View file

@ -298,3 +298,10 @@ func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart) return d.accounts.selectAccountByLocalpart(ctx, localpart)
} }
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}

View file

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -46,6 +47,9 @@ const setAvatarURLSQL = "" +
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter writer *sqlutil.TransactionWriter
@ -53,6 +57,7 @@ type profilesStatements struct {
selectProfileByLocalpartStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt setAvatarURLStmt *sql.Stmt
setDisplayNameStmt *sql.Stmt setDisplayNameStmt *sql.Stmt
selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func (s *profilesStatements) prepare(db *sql.DB) (err error) {
@ -74,6 +79,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
return return
} }
if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil {
return
}
return return
} }
@ -112,3 +120,21 @@ func (s *profilesStatements) setDisplayName(
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return return
} }
func (s *profilesStatements) selectProfilesBySearch(
ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile
rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
for rows.Next() {
var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
profiles = append(profiles, profile)
}
return profiles, nil
}

View file

@ -343,3 +343,10 @@ func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart) return d.accounts.selectAccountByLocalpart(ctx, localpart)
} }
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}