From 51b30ab9d365929617fafcedda6fc521b4418576 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 21 Oct 2022 09:22:00 +0200 Subject: [PATCH] Add "changed" return value Deduplicate code in the clientapi --- clientapi/routing/profile.go | 81 ++++++++++++----------- userapi/api/api.go | 2 + userapi/internal/api.go | 8 ++- userapi/storage/interface.go | 4 +- userapi/storage/postgres/profile_table.go | 28 +++++--- userapi/storage/shared/storage.go | 8 +-- userapi/storage/sqlite3/profile_table.go | 26 ++++++-- userapi/storage/storage_test.go | 10 ++- userapi/storage/tables/interface.go | 4 +- 9 files changed, 103 insertions(+), 68 deletions(-) diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index afa7a1a24..c9647eb1b 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -135,35 +135,17 @@ func SetAvatarURL( util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") return jsonerror.InternalServerError() } - - var roomsRes api.QueryRoomsForUserResponse - err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ - UserID: device.UserID, - WantMembership: "join", - }, &roomsRes) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() - } - - events, err := buildMembershipEvents( - req.Context(), roomsRes.RoomIDs, *setRes.Profile, userID, cfg, evTime, rsAPI, - ) - switch e := err.(type) { - case nil: - case gomatrixserverlib.BadJSONError: + // No need to build new membership events, since nothing changed + if !setRes.Changed { return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(e.Error()), + Code: http.StatusOK, + JSON: struct{}{}, } - default: - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError() } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() + response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime) + if err != nil { + return response } return util.JSONResponse{ @@ -245,19 +227,42 @@ func SetDisplayName( util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") return jsonerror.InternalServerError() } + // No need to build new membership events, since nothing changed + if !profileRes.Changed { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + } + response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime) + if err != nil { + return response + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func updateProfile( + ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device, + profile *authtypes.Profile, + userID string, cfg *config.ClientAPI, evTime time.Time, +) (util.JSONResponse, error) { var res api.QueryRoomsForUserResponse - err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ + err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") - return jsonerror.InternalServerError() + util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") + return jsonerror.InternalServerError(), err } events, err := buildMembershipEvents( - req.Context(), res.RoomIDs, *profileRes.Profile, userID, cfg, evTime, rsAPI, + ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) switch e := err.(type) { case nil: @@ -265,21 +270,17 @@ func SetDisplayName( return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(e.Error()), - } + }, e default: - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") - return jsonerror.InternalServerError() + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed") + return jsonerror.InternalServerError(), e } - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") - return jsonerror.InternalServerError() - } - - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { + util.GetLogger(ctx).WithError(err).Error("SendEvents failed") + return jsonerror.InternalServerError(), err } + return util.JSONResponse{}, nil } // getProfile gets the full profile of a user by querying the database or a diff --git a/userapi/api/api.go b/userapi/api/api.go index 226799fd4..eef29144a 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -581,6 +581,7 @@ type PerformSetAvatarURLRequest struct { } type PerformSetAvatarURLResponse struct { Profile *authtypes.Profile `json:"profile"` + Changed bool `json:"changed"` } type QueryNumericLocalpartResponse struct { @@ -610,6 +611,7 @@ type PerformUpdateDisplayNameRequest struct { type PerformUpdateDisplayNameResponse struct { Profile *authtypes.Profile `json:"profile"` + Changed bool `json:"changed"` } type QueryLocalpartForThreePIDRequest struct { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 5f9cef396..63044eedb 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { return err } @@ -813,8 +813,9 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - profile, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) res.Profile = profile + res.Changed = changed return err } @@ -850,8 +851,9 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { - profile, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) + profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) res.Profile = profile + res.Changed = changed return err } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index f897c023b..fb12b53af 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -29,8 +29,8 @@ import ( type Profile interface { GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, error) - SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error) } type Account interface { diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 749ced1c0..2753b23d9 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -44,12 +44,18 @@ const selectProfileByLocalpartSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + - " RETURNING display_name" + "UPDATE userapi_profiles AS new" + + " SET avatar_url = $1" + + " FROM userapi_profiles AS old" + + " WHERE new.localpart = $2" + + " RETURNING new.display_name, old.avatar_url <> new.avatar_url" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + - " RETURNING avatar_url" + "UPDATE userapi_profiles AS new" + + " SET display_name = $1" + + " FROM userapi_profiles AS old" + + " WHERE new.localpart = $2" + + " RETURNING new.avatar_url, old.display_name <> new.display_name" const selectProfilesBySearchSQL = "" + "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" @@ -102,26 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, -) (*authtypes.Profile, error) { +) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, AvatarURL: avatarURL, } + var changed bool stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) - return profile, err + err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed) + return profile, changed, err } func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, -) (*authtypes.Profile, error) { +) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, DisplayName: displayName, } + var changed bool stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) - return profile, err + err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed) + return profile, changed, err } func (s *profilesStatements) SelectProfilesBySearch( diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index ad1b55728..f8b6ad311 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -96,9 +96,9 @@ func (d *Database) GetProfileByLocalpart( // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( ctx context.Context, localpart string, avatarURL string, -) (profile *authtypes.Profile, err error) { +) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) return err }) return @@ -108,9 +108,9 @@ func (d *Database) SetAvatarURL( // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( ctx context.Context, localpart string, displayName string, -) (profile *authtypes.Profile, err error) { +) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) return err }) return diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 63d7ae4b3..b6130a1e3 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -104,26 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, -) (*authtypes.Profile, error) { +) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, AvatarURL: avatarURL, } + old, err := s.SelectProfileByLocalpart(ctx, localpart) + if err != nil { + return old, false, err + } + if old.AvatarURL == avatarURL { + return old, false, nil + } stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) - return profile, err + err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + return profile, true, err } func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, -) (*authtypes.Profile, error) { +) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, DisplayName: displayName, } + old, err := s.SelectProfileByLocalpart(ctx, localpart) + if err != nil { + return old, false, err + } + if old.DisplayName == displayName { + return old, false, nil + } stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) - return profile, err + err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + return profile, true, err } func (s *profilesStatements) SelectProfilesBySearch( diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index d72ebc94c..db58cbe50 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -387,9 +387,17 @@ func Test_Profile(t *testing.T) { assert.NoError(t, err, "unable to set displayname") wantProfile.AvatarURL = "mxc://aliceAvatar" - gotProfile, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + gotProfile, changed, err := db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) + assert.True(t, changed) + + // Setting the same avatar again doesn't change anything + wantProfile.AvatarURL = "mxc://aliceAvatar" + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + assert.NoError(t, err, "unable to set avatar url") + assert.Equal(t, wantProfile, gotProfile) + assert.False(t, changed) // search profiles searchRes, err := db.SearchProfiles(ctx, "Alice", 2) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 47a254cfc..1b239e442 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -84,8 +84,8 @@ type OpenIDTable interface { type ProfileTable interface { InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) }