diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index f7cd1810a..0f68c8b9f 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -23,13 +23,14 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" ) type Profile interface { - GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error - SetDisplayName(ctx context.Context, localpart string, displayName string) error + SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) error + SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) error } type Account interface { diff --git a/userapi/storage/postgres/deltas/2022060815233800_userprofile_servername.go b/userapi/storage/postgres/deltas/2022060815233800_userprofile_servername.go new file mode 100644 index 000000000..0f9a769d6 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022060815233800_userprofile_servername.go @@ -0,0 +1,36 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +var serverName gomatrixserverlib.ServerName + +func LoadProfilePrimaryKey(m *sqlutil.Migrations, s gomatrixserverlib.ServerName) { + serverName = s + m.AddMigration(UpProfilePrimaryKey, DownProfilePrimaryKey) +} + +func UpProfilePrimaryKey(tx *sql.Tx) error { + _, err := tx.Exec(fmt.Sprintf(`ALTER TABLE account_profiles ADD COLUMN IF NOT EXISTS servername TEXT NOT NULL DEFAULT '%s'; + ALTER TABLE account_profiles DROP CONSTRAINT account_profiles_pkey; + ALTER TABLE account_profiles ADD PRIMARY KEY (localpart, servername); + ALTER TABLE account_profiles ALTER COLUMN servername DROP DEFAULT;`, serverName)) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownProfilePrimaryKey(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE account_profiles DROP COLUMN IF EXISTS servername; + ALTER TABLE account_profiles ADD PRIMARY KEY(localpart);`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 6d336eb8e..1fab35cfe 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -23,34 +23,38 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS account_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + -- The server this user belongs to + servername TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account - avatar_url TEXT + avatar_url TEXT, + PRIMARY KEY (localpart, servername) ); ` const insertProfileSQL = "" + - "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO account_profiles(localpart, display_name, avatar_url, servername) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1 AND servername = $2" const setAvatarURLSQL = "" + - "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2 AND servername = $3" const setDisplayNameSQL = "" + - "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2 AND servername = $3" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, display_name, avatar_url, servername FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { serverNoticesLocalpart string @@ -79,17 +83,17 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "", serverName) return } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { - var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( + profile := authtypes.Profile{ServerName: string(serverName)} + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { @@ -99,16 +103,16 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string, ) (err error) { - _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) + _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart, serverName) return } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string, ) (err error) { - _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) + _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart, serverName) return } @@ -126,7 +130,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.ServerName); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index b9afb5a56..c70122d65 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -43,9 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, // preparing statements for columns that don't exist yet return nil, err } + if _, err = db.Exec(profilesSchema); err != nil { + return nil, err + } deltas.LoadIsActive(m) //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) + deltas.LoadProfilePrimaryKey(m, serverName) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 0cf713dac..310cf826f 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -83,28 +83,28 @@ func (d *Database) GetAccountByPassword( // GetProfileByLocalpart returns the profile associated with the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart. func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { - return d.Profiles.SelectProfileByLocalpart(ctx, localpart) + return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName) } // SetAvatarURL updates the avatar URL of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + return d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL) }) } // SetDisplayName updates the display name of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + return d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName) }) } @@ -163,7 +163,7 @@ func (d *Database) createAccount( if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { return nil, sqlutil.ErrUserExists } - if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { + if err = d.Profiles.InsertProfile(ctx, txn, localpart, d.ServerName); err != nil { return nil, err } pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) diff --git a/userapi/storage/sqlite3/deltas/2022060815233800_userprofile_servername.go b/userapi/storage/sqlite3/deltas/2022060815233800_userprofile_servername.go new file mode 100644 index 000000000..529ac5799 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022060815233800_userprofile_servername.go @@ -0,0 +1,60 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +var serverName gomatrixserverlib.ServerName + +func LoadProfilePrimaryKey(m *sqlutil.Migrations, s gomatrixserverlib.ServerName) { + serverName = s + m.AddMigration(UpProfilePrimaryKey, DownProfilePrimaryKey) +} + +func UpProfilePrimaryKey(tx *sql.Tx) error { + _, err := tx.Exec(fmt.Sprintf(` + ALTER TABLE account_profiles RENAME TO account_profiles_tmp; + CREATE TABLE IF NOT EXISTS account_profiles ( + localpart TEXT NOT NULL, + servername TEXT NOT NULL, + display_name TEXT, + avatar_url TEXT, + PRIMARY KEY (localpart, servername) + ); + INSERT + INTO account_profiles ( + localpart, servername, display_name, avatar_url + ) SELECT + localpart, '%s', display_name, avatar_url + FROM account_profiles_tmp; + DROP TABLE account_profiles_tmp;`, serverName)) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownProfilePrimaryKey(tx *sql.Tx) error { + _, err := tx.Exec(` + ALTER TABLE account_profiles RENAME TO account_profiles_tmp; + CREATE TABLE IF NOT EXISTS account_profiles ( + localpart TEXT NOT NULL PRIMARY KEY, + display_name TEXT, + avatar_url TEXT + ); + INSERT + INTO account_profiles ( + localpart, display_name, avatar_url + ) SELECT + localpart, display_name, avatar_url + FROM account_profiles_tmp; + DROP TABLE account_profiles_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index 3050ff4b5..7995c4f7c 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -23,34 +23,38 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS account_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + -- The server this user belongs to + servername TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account - avatar_url TEXT + avatar_url TEXT, + PRIMARY KEY (localpart, servername) ); ` const insertProfileSQL = "" + - "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO account_profiles(localpart, display_name, avatar_url, servername) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1 AND servername = $2" const setAvatarURLSQL = "" + - "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2 AND servername = $3" const setDisplayNameSQL = "" + - "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2 AND servername = $3" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, display_name, avatar_url, servername FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { db *sql.DB @@ -81,17 +85,17 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "", serverName) return err } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { - var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( + profile := authtypes.Profile{ServerName: string(serverName)} + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { @@ -101,18 +105,18 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - _, err = stmt.ExecContext(ctx, avatarURL, localpart) + _, err = stmt.ExecContext(ctx, avatarURL, localpart, serverName) return } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - _, err = stmt.ExecContext(ctx, displayName, localpart) + _, err = stmt.ExecContext(ctx, displayName, localpart, serverName) return } @@ -130,7 +134,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.ServerName); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index a822f687d..cacf7e1b5 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -44,9 +44,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, // preparing statements for columns that don't exist yet return nil, err } + if _, err = db.Exec(profilesSchema); err != nil { + return nil, err + } deltas.LoadIsActive(m) //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) + deltas.LoadProfilePrimaryKey(m, serverName) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index a26097338..0bfce8c7b 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -22,6 +22,8 @@ import ( const loginTokenLifetime = time.Minute +const serverName = "example.com" + var ( openIDLifetimeMS = time.Minute.Milliseconds() ctx = context.Background() @@ -31,7 +33,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") + }, serverName, bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { t.Fatalf("NewUserAPIDatabase returned %s", err) } @@ -370,20 +372,21 @@ func Test_Profile(t *testing.T) { _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") - gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) + gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, serverName) assert.NoError(t, err, "unable to get profile by localpart") - wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} + wantProfile := &authtypes.Profile{Localpart: aliceLocalpart, ServerName: string(serverName)} assert.Equal(t, wantProfile, gotProfile) // set avatar & displayname wantProfile.DisplayName = "Alice" wantProfile.AvatarURL = "mxc://aliceAvatar" - err = db.SetDisplayName(ctx, aliceLocalpart, "Alice") + wantProfile.ServerName = string(serverName) + err = db.SetDisplayName(ctx, aliceLocalpart, serverName, "Alice") assert.NoError(t, err, "unable to set displayname") - err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + err = db.SetAvatarURL(ctx, aliceLocalpart, serverName, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") // verify profile - gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart) + gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart, serverName) assert.NoError(t, err, "unable to get profile by localpart") assert.Equal(t, wantProfile, gotProfile) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 2fe955670..092540ae1 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" ) type AccountDataTable interface { @@ -82,10 +83,10 @@ type OpenIDTable interface { } type ProfileTable interface { - InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error - SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error + SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (err error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (err error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } diff --git a/userapi/storage/tables/profile_table_test.go b/userapi/storage/tables/profile_table_test.go new file mode 100644 index 000000000..d8aebd279 --- /dev/null +++ b/userapi/storage/tables/profile_table_test.go @@ -0,0 +1,117 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +const serverNotice = "notice" + +func mustCreateProfileTable(t *testing.T, dbType test.DBType) (tab tables.ProfileTable, close func()) { + var connStr string + connStr, close = test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatal(err) + } + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresProfilesTable(db, serverNotice) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteProfilesTable(db, serverNotice) + } + if err != nil { + t.Fatalf("failed to create profiles table: %v", err) + } + return tab, close +} + +func mustCreateProfile(t *testing.T, ctx context.Context, tab tables.ProfileTable, localPart string, serverName gomatrixserverlib.ServerName) { + if err := tab.InsertProfile(ctx, nil, localPart, serverName); err != nil { + t.Fatalf("failed to insert profile: %v", err) + } +} + +func TestProfileTable(t *testing.T) { + ctx := context.Background() + serverName1 := gomatrixserverlib.ServerName("localhost") + serverName2 := gomatrixserverlib.ServerName("notlocalhost") + avatarURL := "newAvatarURL" + displayName := "newDisplayName" + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateProfileTable(t, dbType) + defer close() + // Create serverNotice user + if err := tab.InsertProfile(ctx, nil, serverNotice, serverName1); err != nil { + t.Fatalf("failed to insert profile: %v", err) + } + // Ensure the a localpart is unique per serverName + if err := tab.InsertProfile(ctx, nil, serverNotice, serverName1); err == nil { + t.Fatalf("expected SQL insert to fail, but it didn't") + } + + mustCreateProfile(t, ctx, tab, "dummy1", serverName1) + mustCreateProfile(t, ctx, tab, "dummy1", serverName2) + mustCreateProfile(t, ctx, tab, "testing", serverName2) + + if err := tab.SetAvatarURL(ctx, nil, "dummy1", serverName1, avatarURL); err != nil { + t.Fatalf("failed to set avatar url: %v", err) + } + if err := tab.SetDisplayName(ctx, nil, "dummy1", serverName2, displayName); err != nil { + t.Fatalf("failed to set avatar url: %v", err) + } + + // Verify dummy1 on serverName2 is as expected, just to test the function + dummy1, err := tab.SelectProfileByLocalpart(ctx, "dummy1", serverName2) + if err != nil { + t.Fatalf("failed to query profile by localpart: %v", err) + } + // Make sure that only dummy1 on serverName1 got the displayName changed and avatarURL is unchanged + if dummy1.AvatarURL == avatarURL { + t.Fatalf("expected avatarURL %s, got %s", avatarURL, dummy1.AvatarURL) + } + if dummy1.DisplayName != displayName { + t.Fatalf("expected displayName %s, got %s", displayName, dummy1.DisplayName) + } + + searchRes, err := tab.SelectProfilesBySearch(ctx, "dummy", 10) + if err != nil { + t.Fatalf("unable to search profiles: %v", err) + } + // serverNotice user and testing should not be returned here, only the dummy users + if count := len(searchRes); count > 2 { + t.Fatalf("expected 2 results, got %d", count) + } + for _, profile := range searchRes { + if profile.Localpart != "dummy1" { + t.Fatalf("got unexpected localpart: %v", profile.Localpart) + } + // Make sure that only dummy1 on serverName1 got the avatarURL changed and displayName is unchanged + if gomatrixserverlib.ServerName(profile.ServerName) == serverName1 && profile.AvatarURL != avatarURL { + t.Fatalf("expected avatarURL %s, got %s", avatarURL, profile.AvatarURL) + } + if gomatrixserverlib.ServerName(profile.ServerName) == serverName1 && profile.DisplayName == displayName { + t.Fatalf("expected displayName %s, got %s", displayName, profile.DisplayName) + } + // Make sure that only dummy1 on serverName1 got the displayName changed and avatarURL is unchanged + if gomatrixserverlib.ServerName(profile.ServerName) == serverName2 && profile.AvatarURL == avatarURL { + t.Fatalf("expected avatarURL %s, got %s", avatarURL, profile.AvatarURL) + } + if gomatrixserverlib.ServerName(profile.ServerName) == serverName2 && profile.DisplayName != displayName { + t.Fatalf("expected displayName %s, got %s", displayName, profile.DisplayName) + } + } + + }) +}