Add servername to account_profiles

This commit is contained in:
Till Faelligen 2022-06-09 08:22:50 +02:00
parent 3cdefcf765
commit 093cbe483a
11 changed files with 288 additions and 54 deletions

View file

@ -23,13 +23,14 @@ import (
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
) )
type Profile interface { type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) error
} }
type Account interface { type Account interface {

View file

@ -0,0 +1,36 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
var serverName gomatrixserverlib.ServerName
func LoadProfilePrimaryKey(m *sqlutil.Migrations, s gomatrixserverlib.ServerName) {
serverName = s
m.AddMigration(UpProfilePrimaryKey, DownProfilePrimaryKey)
}
func UpProfilePrimaryKey(tx *sql.Tx) error {
_, err := tx.Exec(fmt.Sprintf(`ALTER TABLE account_profiles ADD COLUMN IF NOT EXISTS servername TEXT NOT NULL DEFAULT '%s';
ALTER TABLE account_profiles DROP CONSTRAINT account_profiles_pkey;
ALTER TABLE account_profiles ADD PRIMARY KEY (localpart, servername);
ALTER TABLE account_profiles ALTER COLUMN servername DROP DEFAULT;`, serverName))
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownProfilePrimaryKey(tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE account_profiles DROP COLUMN IF EXISTS servername;
ALTER TABLE account_profiles ADD PRIMARY KEY(localpart);`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -23,34 +23,38 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
const profilesSchema = ` const profilesSchema = `
-- Stores data about accounts profiles. -- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles ( CREATE TABLE IF NOT EXISTS account_profiles (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL,
-- The server this user belongs to
servername TEXT NOT NULL,
-- The display name for this account -- The display name for this account
display_name TEXT, display_name TEXT,
-- The URL of the avatar for this account -- The URL of the avatar for this account
avatar_url TEXT avatar_url TEXT,
PRIMARY KEY (localpart, servername)
); );
` `
const insertProfileSQL = "" + const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" "INSERT INTO account_profiles(localpart, display_name, avatar_url, servername) VALUES ($1, $2, $3, $4)"
const selectProfileByLocalpartSQL = "" + const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1 AND servername = $2"
const setAvatarURLSQL = "" + const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2 AND servername = $3"
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2 AND servername = $3"
const selectProfilesBySearchSQL = "" + const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, display_name, avatar_url, servername FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
serverNoticesLocalpart string serverNoticesLocalpart string
@ -79,17 +83,17 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables
} }
func (s *profilesStatements) InsertProfile( func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "", serverName)
return return
} }
func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile profile := authtypes.Profile{ServerName: string(serverName)}
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
) )
if err != nil { if err != nil {
@ -99,16 +103,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
} }
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string,
) (err error) { ) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart, serverName)
return return
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string,
) (err error) { ) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart, serverName)
return return
} }
@ -126,7 +130,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() { for rows.Next() {
var profile authtypes.Profile var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.ServerName); err != nil {
return nil, err return nil, err
} }
if profile.Localpart != s.serverNoticesLocalpart { if profile.Localpart != s.serverNoticesLocalpart {

View file

@ -43,9 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
// preparing statements for columns that don't exist yet // preparing statements for columns that don't exist yet
return nil, err return nil, err
} }
if _, err = db.Exec(profilesSchema); err != nil {
return nil, err
}
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m) //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m) deltas.LoadAddAccountType(m)
deltas.LoadProfilePrimaryKey(m, serverName)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }

View file

@ -83,28 +83,28 @@ func (d *Database) GetAccountByPassword(
// GetProfileByLocalpart returns the profile associated with the given localpart. // GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart( func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
return d.Profiles.SelectProfileByLocalpart(ctx, localpart) return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
} }
// SetAvatarURL updates the avatar URL of the profile associated with the given // SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL( func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) return d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
}) })
} }
// SetDisplayName updates the display name of the profile associated with the given // SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName( func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) return d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
}) })
} }
@ -163,7 +163,7 @@ func (d *Database) createAccount(
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { if err = d.Profiles.InsertProfile(ctx, txn, localpart, d.ServerName); err != nil {
return nil, err return nil, err
} }
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)

View file

@ -0,0 +1,60 @@
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
var serverName gomatrixserverlib.ServerName
func LoadProfilePrimaryKey(m *sqlutil.Migrations, s gomatrixserverlib.ServerName) {
serverName = s
m.AddMigration(UpProfilePrimaryKey, DownProfilePrimaryKey)
}
func UpProfilePrimaryKey(tx *sql.Tx) error {
_, err := tx.Exec(fmt.Sprintf(`
ALTER TABLE account_profiles RENAME TO account_profiles_tmp;
CREATE TABLE IF NOT EXISTS account_profiles (
localpart TEXT NOT NULL,
servername TEXT NOT NULL,
display_name TEXT,
avatar_url TEXT,
PRIMARY KEY (localpart, servername)
);
INSERT
INTO account_profiles (
localpart, servername, display_name, avatar_url
) SELECT
localpart, '%s', display_name, avatar_url
FROM account_profiles_tmp;
DROP TABLE account_profiles_tmp;`, serverName))
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownProfilePrimaryKey(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE account_profiles RENAME TO account_profiles_tmp;
CREATE TABLE IF NOT EXISTS account_profiles (
localpart TEXT NOT NULL PRIMARY KEY,
display_name TEXT,
avatar_url TEXT
);
INSERT
INTO account_profiles (
localpart, display_name, avatar_url
) SELECT
localpart, display_name, avatar_url
FROM account_profiles_tmp;
DROP TABLE account_profiles_tmp;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -23,34 +23,38 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
const profilesSchema = ` const profilesSchema = `
-- Stores data about accounts profiles. -- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS account_profiles ( CREATE TABLE IF NOT EXISTS account_profiles (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL,
-- The server this user belongs to
servername TEXT NOT NULL,
-- The display name for this account -- The display name for this account
display_name TEXT, display_name TEXT,
-- The URL of the avatar for this account -- The URL of the avatar for this account
avatar_url TEXT avatar_url TEXT,
PRIMARY KEY (localpart, servername)
); );
` `
const insertProfileSQL = "" + const insertProfileSQL = "" +
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" "INSERT INTO account_profiles(localpart, display_name, avatar_url, servername) VALUES ($1, $2, $3, $4)"
const selectProfileByLocalpartSQL = "" + const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1 AND servername = $2"
const setAvatarURLSQL = "" + const setAvatarURLSQL = "" +
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2 AND servername = $3"
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2 AND servername = $3"
const selectProfilesBySearchSQL = "" + const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, display_name, avatar_url, servername FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
db *sql.DB db *sql.DB
@ -81,17 +85,17 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
} }
func (s *profilesStatements) InsertProfile( func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "", serverName)
return err return err
} }
func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile profile := authtypes.Profile{ServerName: string(serverName)}
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
) )
if err != nil { if err != nil {
@ -101,18 +105,18 @@ func (s *profilesStatements) SelectProfileByLocalpart(
} }
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
_, err = stmt.ExecContext(ctx, avatarURL, localpart) _, err = stmt.ExecContext(ctx, avatarURL, localpart, serverName)
return return
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
_, err = stmt.ExecContext(ctx, displayName, localpart) _, err = stmt.ExecContext(ctx, displayName, localpart, serverName)
return return
} }
@ -130,7 +134,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() { for rows.Next() {
var profile authtypes.Profile var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.ServerName); err != nil {
return nil, err return nil, err
} }
if profile.Localpart != s.serverNoticesLocalpart { if profile.Localpart != s.serverNoticesLocalpart {

View file

@ -44,9 +44,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
// preparing statements for columns that don't exist yet // preparing statements for columns that don't exist yet
return nil, err return nil, err
} }
if _, err = db.Exec(profilesSchema); err != nil {
return nil, err
}
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m) //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m) deltas.LoadAddAccountType(m)
deltas.LoadProfilePrimaryKey(m, serverName)
if err = m.RunDeltas(db, dbProperties); err != nil { if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err return nil, err
} }

View file

@ -22,6 +22,8 @@ import (
const loginTokenLifetime = time.Minute const loginTokenLifetime = time.Minute
const serverName = "example.com"
var ( var (
openIDLifetimeMS = time.Minute.Milliseconds() openIDLifetimeMS = time.Minute.Milliseconds()
ctx = context.Background() ctx = context.Background()
@ -31,7 +33,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
connStr, close := test.PrepareDBConnectionString(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr), ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") }, serverName, bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil { if err != nil {
t.Fatalf("NewUserAPIDatabase returned %s", err) t.Fatalf("NewUserAPIDatabase returned %s", err)
} }
@ -370,20 +372,21 @@ func Test_Profile(t *testing.T) {
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, serverName)
assert.NoError(t, err, "unable to get profile by localpart") assert.NoError(t, err, "unable to get profile by localpart")
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} wantProfile := &authtypes.Profile{Localpart: aliceLocalpart, ServerName: string(serverName)}
assert.Equal(t, wantProfile, gotProfile) assert.Equal(t, wantProfile, gotProfile)
// set avatar & displayname // set avatar & displayname
wantProfile.DisplayName = "Alice" wantProfile.DisplayName = "Alice"
wantProfile.AvatarURL = "mxc://aliceAvatar" wantProfile.AvatarURL = "mxc://aliceAvatar"
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice") wantProfile.ServerName = string(serverName)
err = db.SetDisplayName(ctx, aliceLocalpart, serverName, "Alice")
assert.NoError(t, err, "unable to set displayname") assert.NoError(t, err, "unable to set displayname")
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") err = db.SetAvatarURL(ctx, aliceLocalpart, serverName, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url") assert.NoError(t, err, "unable to set avatar url")
// verify profile // verify profile
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart) gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart, serverName)
assert.NoError(t, err, "unable to get profile by localpart") assert.NoError(t, err, "unable to get profile by localpart")
assert.Equal(t, wantProfile, gotProfile) assert.Equal(t, wantProfile, gotProfile)

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
) )
type AccountDataTable interface { type AccountDataTable interface {
@ -82,10 +83,10 @@ type OpenIDTable interface {
} }
type ProfileTable interface { type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (err error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (err error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
} }

View file

@ -0,0 +1,117 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/storage/postgres"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
const serverNotice = "notice"
func mustCreateProfileTable(t *testing.T, dbType test.DBType) (tab tables.ProfileTable, close func()) {
var connStr string
connStr, close = test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatal(err)
}
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresProfilesTable(db, serverNotice)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteProfilesTable(db, serverNotice)
}
if err != nil {
t.Fatalf("failed to create profiles table: %v", err)
}
return tab, close
}
func mustCreateProfile(t *testing.T, ctx context.Context, tab tables.ProfileTable, localPart string, serverName gomatrixserverlib.ServerName) {
if err := tab.InsertProfile(ctx, nil, localPart, serverName); err != nil {
t.Fatalf("failed to insert profile: %v", err)
}
}
func TestProfileTable(t *testing.T) {
ctx := context.Background()
serverName1 := gomatrixserverlib.ServerName("localhost")
serverName2 := gomatrixserverlib.ServerName("notlocalhost")
avatarURL := "newAvatarURL"
displayName := "newDisplayName"
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateProfileTable(t, dbType)
defer close()
// Create serverNotice user
if err := tab.InsertProfile(ctx, nil, serverNotice, serverName1); err != nil {
t.Fatalf("failed to insert profile: %v", err)
}
// Ensure the a localpart is unique per serverName
if err := tab.InsertProfile(ctx, nil, serverNotice, serverName1); err == nil {
t.Fatalf("expected SQL insert to fail, but it didn't")
}
mustCreateProfile(t, ctx, tab, "dummy1", serverName1)
mustCreateProfile(t, ctx, tab, "dummy1", serverName2)
mustCreateProfile(t, ctx, tab, "testing", serverName2)
if err := tab.SetAvatarURL(ctx, nil, "dummy1", serverName1, avatarURL); err != nil {
t.Fatalf("failed to set avatar url: %v", err)
}
if err := tab.SetDisplayName(ctx, nil, "dummy1", serverName2, displayName); err != nil {
t.Fatalf("failed to set avatar url: %v", err)
}
// Verify dummy1 on serverName2 is as expected, just to test the function
dummy1, err := tab.SelectProfileByLocalpart(ctx, "dummy1", serverName2)
if err != nil {
t.Fatalf("failed to query profile by localpart: %v", err)
}
// Make sure that only dummy1 on serverName1 got the displayName changed and avatarURL is unchanged
if dummy1.AvatarURL == avatarURL {
t.Fatalf("expected avatarURL %s, got %s", avatarURL, dummy1.AvatarURL)
}
if dummy1.DisplayName != displayName {
t.Fatalf("expected displayName %s, got %s", displayName, dummy1.DisplayName)
}
searchRes, err := tab.SelectProfilesBySearch(ctx, "dummy", 10)
if err != nil {
t.Fatalf("unable to search profiles: %v", err)
}
// serverNotice user and testing should not be returned here, only the dummy users
if count := len(searchRes); count > 2 {
t.Fatalf("expected 2 results, got %d", count)
}
for _, profile := range searchRes {
if profile.Localpart != "dummy1" {
t.Fatalf("got unexpected localpart: %v", profile.Localpart)
}
// Make sure that only dummy1 on serverName1 got the avatarURL changed and displayName is unchanged
if gomatrixserverlib.ServerName(profile.ServerName) == serverName1 && profile.AvatarURL != avatarURL {
t.Fatalf("expected avatarURL %s, got %s", avatarURL, profile.AvatarURL)
}
if gomatrixserverlib.ServerName(profile.ServerName) == serverName1 && profile.DisplayName == displayName {
t.Fatalf("expected displayName %s, got %s", displayName, profile.DisplayName)
}
// Make sure that only dummy1 on serverName1 got the displayName changed and avatarURL is unchanged
if gomatrixserverlib.ServerName(profile.ServerName) == serverName2 && profile.AvatarURL == avatarURL {
t.Fatalf("expected avatarURL %s, got %s", avatarURL, profile.AvatarURL)
}
if gomatrixserverlib.ServerName(profile.ServerName) == serverName2 && profile.DisplayName != displayName {
t.Fatalf("expected displayName %s, got %s", displayName, profile.DisplayName)
}
}
})
}