diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 5021871d9..2f8b93d07 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -214,7 +214,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.Localpart); err != nil { return err } @@ -412,7 +412,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) } - prof, err := a.DB.GetProfileByLocalpart(ctx, local) + prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain) if err != nil { if err == sql.ErrNoRows { return nil @@ -883,7 +883,7 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL) res.Profile = profile res.Changed = changed return err @@ -921,7 +921,7 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { - profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) + profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName) res.Profile = profile res.Changed = changed return err diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 4a0f66f85..d1b4e1b59 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -29,10 +29,10 @@ import ( ) type Profile interface { - GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error) + SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) } type Account interface { diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 2511f32ca..28d24a4d0 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` @@ -39,27 +40,27 @@ CREATE TABLE IF NOT EXISTS userapi_profiles ( ` const insertProfileSQL = "" + - "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2" const setAvatarURLSQL = "" + "UPDATE userapi_profiles AS new" + " SET avatar_url = $1" + " FROM userapi_profiles AS old" + - " WHERE new.localpart = $2" + + " WHERE new.localpart = $2 AND new.server_name = $3" + " RETURNING new.display_name, old.avatar_url <> new.avatar_url" const setDisplayNameSQL = "" + "UPDATE userapi_profiles AS new" + " SET display_name = $1" + " FROM userapi_profiles AS old" + - " WHERE new.localpart = $2" + + " WHERE new.localpart = $2 AND new.server_name = $3" + " RETURNING new.avatar_url, old.display_name <> new.display_name" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { serverNoticesLocalpart string @@ -88,18 +89,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( + &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { return nil, err @@ -108,28 +111,34 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ - Localpart: localpart, - AvatarURL: avatarURL, + Localpart: localpart, + ServerName: string(serverName), + AvatarURL: avatarURL, } var changed bool stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed) + err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed) return profile, changed, err } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, + ServerName: string(serverName), DisplayName: displayName, } var changed bool stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed) + err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed) return profile, changed, err } @@ -147,7 +156,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 7f4ec8a1a..a1ce9bb39 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -87,18 +87,21 @@ func (d *Database) GetAccountByPassword( // GetProfileByLocalpart returns the profile associated with the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart. func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { - return d.Profiles.SelectProfileByLocalpart(ctx, localpart) + return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName) } // SetAvatarURL updates the avatar URL of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL) return err }) return @@ -107,10 +110,12 @@ func (d *Database) SetAvatarURL( // SetDisplayName updates the display name of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName) return err }) return @@ -175,7 +180,7 @@ func (d *Database) createAccount( if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil { return nil, sqlutil.ErrUserExists } - if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { + if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil { return nil, err } pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 679ffd32d..4f01f3813 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` @@ -39,21 +40,21 @@ CREATE TABLE IF NOT EXISTS userapi_profiles ( ` const insertProfileSQL = "" + - "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING display_name" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING avatar_url" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { db *sql.DB @@ -84,18 +85,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return err } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( + &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { return nil, err @@ -104,13 +107,16 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ - Localpart: localpart, - AvatarURL: avatarURL, + Localpart: localpart, + ServerName: string(serverName), + AvatarURL: avatarURL, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -118,18 +124,20 @@ func (s *profilesStatements) SetAvatarURL( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName) return profile, true, err } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, DisplayName: displayName, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -137,7 +145,7 @@ func (s *profilesStatements) SetDisplayName( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL) return profile, true, err } @@ -155,7 +163,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index c7bf7dc7f..668a84772 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -375,27 +375,27 @@ func Test_Profile(t *testing.T) { _, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") - gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) + gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get profile by localpart") wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} assert.Equal(t, wantProfile, gotProfile) // set avatar & displayname wantProfile.DisplayName = "Alice" - gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice") + gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice") assert.Equal(t, wantProfile, gotProfile) assert.NoError(t, err, "unable to set displayname") assert.True(t, changed) wantProfile.AvatarURL = "mxc://aliceAvatar" - gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.True(t, changed) // Setting the same avatar again doesn't change anything wantProfile.AvatarURL = "mxc://aliceAvatar" - gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.False(t, changed) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index e5ee0daaf..2480ddf85 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -84,10 +84,10 @@ type OpenIDTable interface { } type ProfileTable interface { - InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error - SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error) + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error + SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 759e6fbc9..23acfa6f3 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -84,10 +84,10 @@ func TestQueryProfile(t *testing.T) { if err != nil { t.Fatalf("failed to make account: %s", err) } - if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { + if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { t.Fatalf("failed to set avatar url: %s", err) } - if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { + if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { t.Fatalf("failed to set display name: %s", err) }