Add "changed" return value

Deduplicate code in the clientapi
This commit is contained in:
Till Faelligen 2022-10-21 09:22:00 +02:00
parent ce9b9cf87f
commit 51b30ab9d3
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
9 changed files with 103 additions and 68 deletions

View file

@ -135,35 +135,17 @@ func SetAvatarURL(
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed") util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// No need to build new membership events, since nothing changed
var roomsRes api.QueryRoomsForUserResponse if !setRes.Changed {
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &roomsRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError()
}
events, err := buildMembershipEvents(
req.Context(), roomsRes.RoomIDs, *setRes.Profile, userID, cfg, evTime, rsAPI,
)
switch e := err.(type) {
case nil:
case gomatrixserverlib.BadJSONError:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusOK,
JSON: jsonerror.BadJSON(e.Error()), JSON: struct{}{},
} }
default:
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError()
} }
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime)
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") if err != nil {
return jsonerror.InternalServerError() return response
} }
return util.JSONResponse{ return util.JSONResponse{
@ -245,19 +227,42 @@ func SetDisplayName(
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed") util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
// No need to build new membership events, since nothing changed
if !profileRes.Changed {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime)
if err != nil {
return response
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func updateProfile(
ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device,
profile *authtypes.Profile,
userID string, cfg *config.ClientAPI, evTime time.Time,
) (util.JSONResponse, error) {
var res api.QueryRoomsForUserResponse var res api.QueryRoomsForUserResponse
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: device.UserID, UserID: device.UserID,
WantMembership: "join", WantMembership: "join",
}, &res) }, &res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError(), err
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
req.Context(), res.RoomIDs, *profileRes.Profile, userID, cfg, evTime, rsAPI, ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
@ -265,21 +270,17 @@ func SetDisplayName(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(e.Error()), JSON: jsonerror.BadJSON(e.Error()),
} }, e
default: default:
util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed") util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError(), e
} }
if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil { if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError(), err
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
} }
return util.JSONResponse{}, nil
} }
// getProfile gets the full profile of a user by querying the database or a // getProfile gets the full profile of a user by querying the database or a

View file

@ -581,6 +581,7 @@ type PerformSetAvatarURLRequest struct {
} }
type PerformSetAvatarURLResponse struct { type PerformSetAvatarURLResponse struct {
Profile *authtypes.Profile `json:"profile"` Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
} }
type QueryNumericLocalpartResponse struct { type QueryNumericLocalpartResponse struct {
@ -610,6 +611,7 @@ type PerformUpdateDisplayNameRequest struct {
type PerformUpdateDisplayNameResponse struct { type PerformUpdateDisplayNameResponse struct {
Profile *authtypes.Profile `json:"profile"` Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
} }
type QueryLocalpartForThreePIDRequest struct { type QueryLocalpartForThreePIDRequest struct {

View file

@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil return nil
} }
if _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err return err
} }
@ -813,8 +813,9 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
} }
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
profile, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
res.Profile = profile res.Profile = profile
res.Changed = changed
return err return err
} }
@ -850,8 +851,9 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
} }
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
profile, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
res.Profile = profile res.Profile = profile
res.Changed = changed
return err return err
} }

View file

@ -29,8 +29,8 @@ import (
type Profile interface { type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, error) SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, error) SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
} }
type Account interface { type Account interface {

View file

@ -44,12 +44,18 @@ const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" + const setAvatarURLSQL = "" +
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles AS new" +
" RETURNING display_name" " SET avatar_url = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles AS new" +
" RETURNING avatar_url" " SET display_name = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" + const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@ -102,26 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{ profile := &authtypes.Profile{
Localpart: localpart, Localpart: localpart,
AvatarURL: avatarURL, AvatarURL: avatarURL,
} }
var changed bool
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
return profile, err return profile, changed, err
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{ profile := &authtypes.Profile{
Localpart: localpart, Localpart: localpart,
DisplayName: displayName, DisplayName: displayName,
} }
var changed bool
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
return profile, err return profile, changed, err
} }
func (s *profilesStatements) SelectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -96,9 +96,9 @@ func (d *Database) GetProfileByLocalpart(
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL( func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context, localpart string, avatarURL string,
) (profile *authtypes.Profile, err error) { ) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
return err return err
}) })
return return
@ -108,9 +108,9 @@ func (d *Database) SetAvatarURL(
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName( func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context, localpart string, displayName string,
) (profile *authtypes.Profile, err error) { ) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
return err return err
}) })
return return

View file

@ -104,26 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{ profile := &authtypes.Profile{
Localpart: localpart, Localpart: localpart,
AvatarURL: avatarURL, AvatarURL: avatarURL,
} }
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.AvatarURL == avatarURL {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
return profile, err return profile, true, err
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{ profile := &authtypes.Profile{
Localpart: localpart, Localpart: localpart,
DisplayName: displayName, DisplayName: displayName,
} }
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.DisplayName == displayName {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
return profile, err return profile, true, err
} }
func (s *profilesStatements) SelectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -387,9 +387,17 @@ func Test_Profile(t *testing.T) {
assert.NoError(t, err, "unable to set displayname") assert.NoError(t, err, "unable to set displayname")
wantProfile.AvatarURL = "mxc://aliceAvatar" wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") gotProfile, changed, err := db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url") assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile) assert.Equal(t, wantProfile, gotProfile)
assert.True(t, changed)
// Setting the same avatar again doesn't change anything
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.False(t, changed)
// search profiles // search profiles
searchRes, err := db.SearchProfiles(ctx, "Alice", 2) searchRes, err := db.SearchProfiles(ctx, "Alice", 2)

View file

@ -84,8 +84,8 @@ type OpenIDTable interface {
type ProfileTable interface { type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, error) SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, error) SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
} }