Return the new profile when updating display_name or avatar_url

This commit is contained in:
Till Faelligen 2022-10-21 07:45:11 +02:00
parent 9041491201
commit ce9b9cf87f
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
11 changed files with 81 additions and 84 deletions

View file

@ -19,6 +19,8 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/httputil"
@ -27,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrix"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -126,20 +127,6 @@ func SetAvatarURL(
} }
} }
res := &userapi.QueryProfileResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError()
}
oldProfile := &authtypes.Profile{
Localpart: localpart,
DisplayName: res.DisplayName,
AvatarURL: res.AvatarURL,
}
setRes := &userapi.PerformSetAvatarURLResponse{} setRes := &userapi.PerformSetAvatarURLResponse{}
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
Localpart: localpart, Localpart: localpart,
@ -159,14 +146,8 @@ func SetAvatarURL(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
newProfile := authtypes.Profile{
Localpart: localpart,
DisplayName: oldProfile.DisplayName,
AvatarURL: r.AvatarURL,
}
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, req.Context(), roomsRes.RoomIDs, *setRes.Profile, userID, cfg, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
@ -255,24 +236,11 @@ func SetDisplayName(
} }
} }
pRes := &userapi.QueryProfileResponse{} profileRes := &userapi.PerformUpdateDisplayNameResponse{}
err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
UserID: userID,
}, pRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
return jsonerror.InternalServerError()
}
oldProfile := &authtypes.Profile{
Localpart: localpart,
DisplayName: pRes.DisplayName,
AvatarURL: pRes.AvatarURL,
}
err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
Localpart: localpart, Localpart: localpart,
DisplayName: r.DisplayName, DisplayName: r.DisplayName,
}, &struct{}{}) }, profileRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -288,14 +256,8 @@ func SetDisplayName(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
newProfile := authtypes.Profile{
Localpart: localpart,
DisplayName: r.DisplayName,
AvatarURL: oldProfile.AvatarURL,
}
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, req.Context(), res.RoomIDs, *profileRes.Profile, userID, cfg, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:

View file

@ -96,7 +96,7 @@ type ClientUserAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
@ -579,7 +579,9 @@ type Notification struct {
type PerformSetAvatarURLRequest struct { type PerformSetAvatarURLRequest struct {
Localpart, AvatarURL string Localpart, AvatarURL string
} }
type PerformSetAvatarURLResponse struct{} type PerformSetAvatarURLResponse struct {
Profile *authtypes.Profile `json:"profile"`
}
type QueryNumericLocalpartResponse struct { type QueryNumericLocalpartResponse struct {
ID int64 ID int64
@ -606,6 +608,10 @@ type PerformUpdateDisplayNameRequest struct {
Localpart, DisplayName string Localpart, DisplayName string
} }
type PerformUpdateDisplayNameResponse struct {
Profile *authtypes.Profile `json:"profile"`
}
type QueryLocalpartForThreePIDRequest struct { type QueryLocalpartForThreePIDRequest struct {
ThreePID, Medium string ThreePID, Medium string
} }

View file

@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req
return err return err
} }
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error { func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error {
err := t.Impl.SetDisplayName(ctx, req, res) err := t.Impl.SetDisplayName(ctx, req, res)
util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
return err return err

View file

@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil return nil
} }
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { if _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err return err
} }
@ -813,7 +813,9 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
} }
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) profile, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
res.Profile = profile
return err
} }
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
@ -847,8 +849,10 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
} }
} }
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error { func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) profile, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
res.Profile = profile
return err
} }
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {

View file

@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword(
func (h *httpUserInternalAPI) SetDisplayName( func (h *httpUserInternalAPI) SetDisplayName(
ctx context.Context, ctx context.Context,
request *api.PerformUpdateDisplayNameRequest, request *api.PerformUpdateDisplayNameRequest,
response *struct{}, response *api.PerformUpdateDisplayNameResponse,
) error { ) error {
return httputil.CallInternalRPCAPI( return httputil.CallInternalRPCAPI(
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath, "SetDisplayName", h.apiURL+PerformSetDisplayNamePath,

View file

@ -29,8 +29,8 @@ import (
type Profile interface { type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, error)
SetDisplayName(ctx context.Context, localpart string, displayName string) error SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, error)
} }
type Account interface { type Account interface {

View file

@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" + const setAvatarURLSQL = "" +
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
" RETURNING display_name"
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
" RETURNING avatar_url"
const selectProfilesBySearchSQL = "" + const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@ -100,16 +102,26 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) { ) (*authtypes.Profile, error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) profile := &authtypes.Profile{
return Localpart: localpart,
AvatarURL: avatarURL,
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
return profile, err
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) { ) (*authtypes.Profile, error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) profile := &authtypes.Profile{
return Localpart: localpart,
DisplayName: displayName,
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
return profile, err
} }
func (s *profilesStatements) SelectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart(
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL( func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context, localpart string, avatarURL string,
) error { ) (profile *authtypes.Profile, err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) profile, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
return err
}) })
return
} }
// SetDisplayName updates the display name of the profile associated with the given // SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName( func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context, localpart string, displayName string,
) error { ) (profile *authtypes.Profile, err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) profile, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
return err
}) })
return
} }
// SetPassword sets the account password to the given hash. // SetPassword sets the account password to the given hash.

View file

@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" + const setAvatarURLSQL = "" +
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
" RETURNING display_name"
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
" RETURNING avatar_url"
const selectProfilesBySearchSQL = "" + const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@ -102,18 +104,26 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) { ) (*authtypes.Profile, error) {
profile := &authtypes.Profile{
Localpart: localpart,
AvatarURL: avatarURL,
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
_, err = stmt.ExecContext(ctx, avatarURL, localpart) err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
return return profile, err
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) { ) (*authtypes.Profile, error) {
profile := &authtypes.Profile{
Localpart: localpart,
DisplayName: displayName,
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
_, err = stmt.ExecContext(ctx, displayName, localpart) err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
return return profile, err
} }
func (s *profilesStatements) SelectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -382,14 +382,13 @@ func Test_Profile(t *testing.T) {
// set avatar & displayname // set avatar & displayname
wantProfile.DisplayName = "Alice" wantProfile.DisplayName = "Alice"
wantProfile.AvatarURL = "mxc://aliceAvatar" gotProfile, err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice") assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname") assert.NoError(t, err, "unable to set displayname")
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url") assert.NoError(t, err, "unable to set avatar url")
// verify profile
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get profile by localpart")
assert.Equal(t, wantProfile, gotProfile) assert.Equal(t, wantProfile, gotProfile)
// search profiles // search profiles

View file

@ -84,8 +84,8 @@ type OpenIDTable interface {
type ProfileTable interface { type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
} }