mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
Return the new profile when updating display_name or avatar_url
This commit is contained in:
parent
9041491201
commit
ce9b9cf87f
|
|
@ -19,6 +19,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
|
|
@ -27,7 +29,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
|
@ -126,20 +127,6 @@ func SetAvatarURL(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
res := &userapi.QueryProfileResponse{}
|
|
||||||
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
|
|
||||||
UserID: userID,
|
|
||||||
}, res)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
oldProfile := &authtypes.Profile{
|
|
||||||
Localpart: localpart,
|
|
||||||
DisplayName: res.DisplayName,
|
|
||||||
AvatarURL: res.AvatarURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
setRes := &userapi.PerformSetAvatarURLResponse{}
|
setRes := &userapi.PerformSetAvatarURLResponse{}
|
||||||
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
|
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
|
@ -159,14 +146,8 @@ func SetAvatarURL(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
newProfile := authtypes.Profile{
|
|
||||||
Localpart: localpart,
|
|
||||||
DisplayName: oldProfile.DisplayName,
|
|
||||||
AvatarURL: r.AvatarURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
events, err := buildMembershipEvents(
|
events, err := buildMembershipEvents(
|
||||||
req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
|
req.Context(), roomsRes.RoomIDs, *setRes.Profile, userID, cfg, evTime, rsAPI,
|
||||||
)
|
)
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
|
|
@ -255,24 +236,11 @@ func SetDisplayName(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pRes := &userapi.QueryProfileResponse{}
|
profileRes := &userapi.PerformUpdateDisplayNameResponse{}
|
||||||
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
|
|
||||||
UserID: userID,
|
|
||||||
}, pRes)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
oldProfile := &authtypes.Profile{
|
|
||||||
Localpart: localpart,
|
|
||||||
DisplayName: pRes.DisplayName,
|
|
||||||
AvatarURL: pRes.AvatarURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
|
err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
DisplayName: r.DisplayName,
|
DisplayName: r.DisplayName,
|
||||||
}, &struct{}{})
|
}, profileRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
|
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
|
@ -288,14 +256,8 @@ func SetDisplayName(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
newProfile := authtypes.Profile{
|
|
||||||
Localpart: localpart,
|
|
||||||
DisplayName: r.DisplayName,
|
|
||||||
AvatarURL: oldProfile.AvatarURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
events, err := buildMembershipEvents(
|
events, err := buildMembershipEvents(
|
||||||
req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
|
req.Context(), res.RoomIDs, *profileRes.Profile, userID, cfg, evTime, rsAPI,
|
||||||
)
|
)
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,7 @@ type ClientUserAPI interface {
|
||||||
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
|
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
|
||||||
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
||||||
SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
|
SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
|
||||||
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error
|
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error
|
||||||
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
|
||||||
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
||||||
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
|
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
|
||||||
|
|
@ -579,7 +579,9 @@ type Notification struct {
|
||||||
type PerformSetAvatarURLRequest struct {
|
type PerformSetAvatarURLRequest struct {
|
||||||
Localpart, AvatarURL string
|
Localpart, AvatarURL string
|
||||||
}
|
}
|
||||||
type PerformSetAvatarURLResponse struct{}
|
type PerformSetAvatarURLResponse struct {
|
||||||
|
Profile *authtypes.Profile `json:"profile"`
|
||||||
|
}
|
||||||
|
|
||||||
type QueryNumericLocalpartResponse struct {
|
type QueryNumericLocalpartResponse struct {
|
||||||
ID int64
|
ID int64
|
||||||
|
|
@ -606,6 +608,10 @@ type PerformUpdateDisplayNameRequest struct {
|
||||||
Localpart, DisplayName string
|
Localpart, DisplayName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PerformUpdateDisplayNameResponse struct {
|
||||||
|
Profile *authtypes.Profile `json:"profile"`
|
||||||
|
}
|
||||||
|
|
||||||
type QueryLocalpartForThreePIDRequest struct {
|
type QueryLocalpartForThreePIDRequest struct {
|
||||||
ThreePID, Medium string
|
ThreePID, Medium string
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error {
|
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error {
|
||||||
err := t.Impl.SetDisplayName(ctx, req, res)
|
err := t.Impl.SetDisplayName(ctx, req, res)
|
||||||
util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
|
util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
if _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -813,7 +813,9 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
|
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
|
||||||
return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
|
profile, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
|
||||||
|
res.Profile = profile
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
|
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
|
||||||
|
|
@ -847,8 +849,10 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error {
|
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
|
||||||
return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
|
profile, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
|
||||||
|
res.Profile = profile
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
|
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
|
||||||
|
|
|
||||||
|
|
@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword(
|
||||||
func (h *httpUserInternalAPI) SetDisplayName(
|
func (h *httpUserInternalAPI) SetDisplayName(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.PerformUpdateDisplayNameRequest,
|
request *api.PerformUpdateDisplayNameRequest,
|
||||||
response *struct{},
|
response *api.PerformUpdateDisplayNameResponse,
|
||||||
) error {
|
) error {
|
||||||
return httputil.CallInternalRPCAPI(
|
return httputil.CallInternalRPCAPI(
|
||||||
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,
|
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,
|
||||||
|
|
|
||||||
|
|
@ -29,8 +29,8 @@ import (
|
||||||
type Profile interface {
|
type Profile interface {
|
||||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
|
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, error)
|
||||||
SetDisplayName(ctx context.Context, localpart string, displayName string) error
|
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Account interface {
|
type Account interface {
|
||||||
|
|
|
||||||
|
|
@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
||||||
|
|
||||||
const setAvatarURLSQL = "" +
|
const setAvatarURLSQL = "" +
|
||||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
|
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
|
||||||
|
" RETURNING display_name"
|
||||||
|
|
||||||
const setDisplayNameSQL = "" +
|
const setDisplayNameSQL = "" +
|
||||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
|
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
|
||||||
|
" RETURNING avatar_url"
|
||||||
|
|
||||||
const selectProfilesBySearchSQL = "" +
|
const selectProfilesBySearchSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
|
|
@ -100,16 +102,26 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
|
|
||||||
func (s *profilesStatements) SetAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||||
) (err error) {
|
) (*authtypes.Profile, error) {
|
||||||
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
|
profile := &authtypes.Profile{
|
||||||
return
|
Localpart: localpart,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
}
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
|
||||||
|
return profile, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||||
) (err error) {
|
) (*authtypes.Profile, error) {
|
||||||
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
|
profile := &authtypes.Profile{
|
||||||
return
|
Localpart: localpart,
|
||||||
|
DisplayName: displayName,
|
||||||
|
}
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
|
||||||
|
return profile, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SelectProfilesBySearch(
|
func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
|
|
|
||||||
|
|
@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart(
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetAvatarURL(
|
func (d *Database) SetAvatarURL(
|
||||||
ctx context.Context, localpart string, avatarURL string,
|
ctx context.Context, localpart string, avatarURL string,
|
||||||
) error {
|
) (profile *authtypes.Profile, err error) {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
profile, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
||||||
|
return err
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetDisplayName updates the display name of the profile associated with the given
|
// SetDisplayName updates the display name of the profile associated with the given
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetDisplayName(
|
func (d *Database) SetDisplayName(
|
||||||
ctx context.Context, localpart string, displayName string,
|
ctx context.Context, localpart string, displayName string,
|
||||||
) error {
|
) (profile *authtypes.Profile, err error) {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
profile, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
||||||
|
return err
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPassword sets the account password to the given hash.
|
// SetPassword sets the account password to the given hash.
|
||||||
|
|
|
||||||
|
|
@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
||||||
|
|
||||||
const setAvatarURLSQL = "" +
|
const setAvatarURLSQL = "" +
|
||||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
|
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
|
||||||
|
" RETURNING display_name"
|
||||||
|
|
||||||
const setDisplayNameSQL = "" +
|
const setDisplayNameSQL = "" +
|
||||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
|
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
|
||||||
|
" RETURNING avatar_url"
|
||||||
|
|
||||||
const selectProfilesBySearchSQL = "" +
|
const selectProfilesBySearchSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
|
|
@ -102,18 +104,26 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
|
|
||||||
func (s *profilesStatements) SetAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||||
) (err error) {
|
) (*authtypes.Profile, error) {
|
||||||
|
profile := &authtypes.Profile{
|
||||||
|
Localpart: localpart,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||||
_, err = stmt.ExecContext(ctx, avatarURL, localpart)
|
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
|
||||||
return
|
return profile, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||||
) (err error) {
|
) (*authtypes.Profile, error) {
|
||||||
|
profile := &authtypes.Profile{
|
||||||
|
Localpart: localpart,
|
||||||
|
DisplayName: displayName,
|
||||||
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||||
_, err = stmt.ExecContext(ctx, displayName, localpart)
|
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
|
||||||
return
|
return profile, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SelectProfilesBySearch(
|
func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
|
|
|
||||||
|
|
@ -382,14 +382,13 @@ func Test_Profile(t *testing.T) {
|
||||||
|
|
||||||
// set avatar & displayname
|
// set avatar & displayname
|
||||||
wantProfile.DisplayName = "Alice"
|
wantProfile.DisplayName = "Alice"
|
||||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
gotProfile, err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
|
||||||
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
assert.NoError(t, err, "unable to set displayname")
|
assert.NoError(t, err, "unable to set displayname")
|
||||||
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
|
||||||
|
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||||
|
gotProfile, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
||||||
assert.NoError(t, err, "unable to set avatar url")
|
assert.NoError(t, err, "unable to set avatar url")
|
||||||
// verify profile
|
|
||||||
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
|
||||||
assert.NoError(t, err, "unable to get profile by localpart")
|
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
|
|
||||||
// search profiles
|
// search profiles
|
||||||
|
|
|
||||||
|
|
@ -84,8 +84,8 @@ type OpenIDTable interface {
|
||||||
type ProfileTable interface {
|
type ProfileTable interface {
|
||||||
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
|
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
|
||||||
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
|
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, error)
|
||||||
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
|
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, error)
|
||||||
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue