diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 0685c7352..afa7a1a24 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -19,6 +19,8 @@ import ( "net/http" "time" + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -27,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" @@ -126,20 +127,6 @@ func SetAvatarURL( } } - res := &userapi.QueryProfileResponse{} - err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{ - UserID: userID, - }, res) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed") - return jsonerror.InternalServerError() - } - oldProfile := &authtypes.Profile{ - Localpart: localpart, - DisplayName: res.DisplayName, - AvatarURL: res.AvatarURL, - } - setRes := &userapi.PerformSetAvatarURLResponse{} if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{ Localpart: localpart, @@ -159,14 +146,8 @@ func SetAvatarURL( return jsonerror.InternalServerError() } - newProfile := authtypes.Profile{ - Localpart: localpart, - DisplayName: oldProfile.DisplayName, - AvatarURL: r.AvatarURL, - } - events, err := buildMembershipEvents( - req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, + req.Context(), roomsRes.RoomIDs, *setRes.Profile, userID, cfg, evTime, rsAPI, ) switch e := err.(type) { case nil: @@ -255,24 +236,11 @@ func SetDisplayName( } } - pRes := &userapi.QueryProfileResponse{} - err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{ - UserID: userID, - }, pRes) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed") - return jsonerror.InternalServerError() - } - oldProfile := &authtypes.Profile{ - Localpart: localpart, - DisplayName: pRes.DisplayName, - AvatarURL: pRes.AvatarURL, - } - + profileRes := &userapi.PerformUpdateDisplayNameResponse{} err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{ Localpart: localpart, DisplayName: r.DisplayName, - }, &struct{}{}) + }, profileRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") return jsonerror.InternalServerError() @@ -288,14 +256,8 @@ func SetDisplayName( return jsonerror.InternalServerError() } - newProfile := authtypes.Profile{ - Localpart: localpart, - DisplayName: r.DisplayName, - AvatarURL: oldProfile.AvatarURL, - } - events, err := buildMembershipEvents( - req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI, + req.Context(), res.RoomIDs, *profileRes.Profile, userID, cfg, evTime, rsAPI, ) switch e := err.(type) { case nil: diff --git a/userapi/api/api.go b/userapi/api/api.go index 66ee9c7c8..226799fd4 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -96,7 +96,7 @@ type ClientUserAPI interface { PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error - SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error + SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error @@ -579,7 +579,9 @@ type Notification struct { type PerformSetAvatarURLRequest struct { Localpart, AvatarURL string } -type PerformSetAvatarURLResponse struct{} +type PerformSetAvatarURLResponse struct { + Profile *authtypes.Profile `json:"profile"` +} type QueryNumericLocalpartResponse struct { ID int64 @@ -606,6 +608,10 @@ type PerformUpdateDisplayNameRequest struct { Localpart, DisplayName string } +type PerformUpdateDisplayNameResponse struct { + Profile *authtypes.Profile `json:"profile"` +} + type QueryLocalpartForThreePIDRequest struct { ThreePID, Medium string } diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 7e2f69615..90834f7e3 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req return err } -func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error { +func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error { err := t.Impl.SetDisplayName(ctx, req, res) util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res)) return err diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 2f7795dfe..5f9cef396 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { return err } @@ -813,7 +813,9 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + profile, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + res.Profile = profile + return err } func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { @@ -847,8 +849,10 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } } -func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error { - return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) +func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { + profile, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) + res.Profile = profile + return err } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index a375d6caa..aa5d46d9f 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword( func (h *httpUserInternalAPI) SetDisplayName( ctx context.Context, request *api.PerformUpdateDisplayNameRequest, - response *struct{}, + response *api.PerformUpdateDisplayNameResponse, ) error { return httputil.CallInternalRPCAPI( "SetDisplayName", h.apiURL+PerformSetDisplayNamePath, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 02efe7afe..f897c023b 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -29,8 +29,8 @@ import ( type Profile interface { GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error - SetDisplayName(ctx context.Context, localpart string, displayName string) error + SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, error) + SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, error) } type Account interface { diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index f686127be..749ced1c0 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + + " RETURNING display_name" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + + " RETURNING avatar_url" const selectProfilesBySearchSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" @@ -100,16 +102,26 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, -) (err error) { - _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) - return +) (*authtypes.Profile, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + AvatarURL: avatarURL, + } + stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) + err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + return profile, err } func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, -) (err error) { - _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) - return +) (*authtypes.Profile, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + DisplayName: displayName, + } + stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) + err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + return profile, err } func (s *profilesStatements) SelectProfilesBySearch( diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 4e28f7b5a..ad1b55728 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart( // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) +) (profile *authtypes.Profile, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + profile, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + return err }) + return } // SetDisplayName updates the display name of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) +) (profile *authtypes.Profile, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + profile, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + return err }) + return } // SetPassword sets the account password to the given hash. diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 267daf044..63d7ae4b3 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + + " RETURNING display_name" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + + " RETURNING avatar_url" const selectProfilesBySearchSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" @@ -102,18 +104,26 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, -) (err error) { +) (*authtypes.Profile, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + AvatarURL: avatarURL, + } stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - _, err = stmt.ExecContext(ctx, avatarURL, localpart) - return + err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + return profile, err } func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, -) (err error) { +) (*authtypes.Profile, error) { + profile := &authtypes.Profile{ + Localpart: localpart, + DisplayName: displayName, + } stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - _, err = stmt.ExecContext(ctx, displayName, localpart) - return + err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + return profile, err } func (s *profilesStatements) SelectProfilesBySearch( diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 8e5b32b6a..d72ebc94c 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -382,14 +382,13 @@ func Test_Profile(t *testing.T) { // set avatar & displayname wantProfile.DisplayName = "Alice" - wantProfile.AvatarURL = "mxc://aliceAvatar" - err = db.SetDisplayName(ctx, aliceLocalpart, "Alice") + gotProfile, err = db.SetDisplayName(ctx, aliceLocalpart, "Alice") + assert.Equal(t, wantProfile, gotProfile) assert.NoError(t, err, "unable to set displayname") - err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + + wantProfile.AvatarURL = "mxc://aliceAvatar" + gotProfile, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") - // verify profile - gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart) - assert.NoError(t, err, "unable to get profile by localpart") assert.Equal(t, wantProfile, gotProfile) // search profiles diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index cc4287997..47a254cfc 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -84,8 +84,8 @@ type OpenIDTable interface { type ProfileTable interface { InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) }