mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-07 06:03:09 -06:00
Add servername to account_profiles
This commit is contained in:
parent
3cdefcf765
commit
093cbe483a
|
|
@ -23,13 +23,14 @@ import (
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/userapi/types"
|
"github.com/matrix-org/dendrite/userapi/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Profile interface {
|
type Profile interface {
|
||||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
|
SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) error
|
||||||
SetDisplayName(ctx context.Context, localpart string, displayName string) error
|
SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Account interface {
|
type Account interface {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
|
|
||||||
|
func LoadProfilePrimaryKey(m *sqlutil.Migrations, s gomatrixserverlib.ServerName) {
|
||||||
|
serverName = s
|
||||||
|
m.AddMigration(UpProfilePrimaryKey, DownProfilePrimaryKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpProfilePrimaryKey(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(fmt.Sprintf(`ALTER TABLE account_profiles ADD COLUMN IF NOT EXISTS servername TEXT NOT NULL DEFAULT '%s';
|
||||||
|
ALTER TABLE account_profiles DROP CONSTRAINT account_profiles_pkey;
|
||||||
|
ALTER TABLE account_profiles ADD PRIMARY KEY (localpart, servername);
|
||||||
|
ALTER TABLE account_profiles ALTER COLUMN servername DROP DEFAULT;`, serverName))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownProfilePrimaryKey(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`ALTER TABLE account_profiles DROP COLUMN IF EXISTS servername;
|
||||||
|
ALTER TABLE account_profiles ADD PRIMARY KEY(localpart);`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -23,34 +23,38 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
-- Stores data about accounts profiles.
|
-- Stores data about accounts profiles.
|
||||||
CREATE TABLE IF NOT EXISTS account_profiles (
|
CREATE TABLE IF NOT EXISTS account_profiles (
|
||||||
-- The Matrix user ID localpart for this account
|
-- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL,
|
||||||
|
-- The server this user belongs to
|
||||||
|
servername TEXT NOT NULL,
|
||||||
-- The display name for this account
|
-- The display name for this account
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
-- The URL of the avatar for this account
|
-- The URL of the avatar for this account
|
||||||
avatar_url TEXT
|
avatar_url TEXT,
|
||||||
|
PRIMARY KEY (localpart, servername)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertProfileSQL = "" +
|
const insertProfileSQL = "" +
|
||||||
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
"INSERT INTO account_profiles(localpart, display_name, avatar_url, servername) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectProfileByLocalpartSQL = "" +
|
const selectProfileByLocalpartSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1 AND servername = $2"
|
||||||
|
|
||||||
const setAvatarURLSQL = "" +
|
const setAvatarURLSQL = "" +
|
||||||
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2 AND servername = $3"
|
||||||
|
|
||||||
const setDisplayNameSQL = "" +
|
const setDisplayNameSQL = "" +
|
||||||
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2 AND servername = $3"
|
||||||
|
|
||||||
const selectProfilesBySearchSQL = "" +
|
const selectProfilesBySearchSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
"SELECT localpart, display_name, avatar_url, servername FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
|
|
||||||
type profilesStatements struct {
|
type profilesStatements struct {
|
||||||
serverNoticesLocalpart string
|
serverNoticesLocalpart string
|
||||||
|
|
@ -79,17 +83,17 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) InsertProfile(
|
func (s *profilesStatements) InsertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "", serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
profile := authtypes.Profile{ServerName: string(serverName)}
|
||||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -99,16 +103,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
|
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
|
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -126,7 +130,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.ServerName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if profile.Localpart != s.serverNoticesLocalpart {
|
if profile.Localpart != s.serverNoticesLocalpart {
|
||||||
|
|
|
||||||
|
|
@ -43,9 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
// preparing statements for columns that don't exist yet
|
// preparing statements for columns that don't exist yet
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if _, err = db.Exec(profilesSchema); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
deltas.LoadIsActive(m)
|
deltas.LoadIsActive(m)
|
||||||
//deltas.LoadLastSeenTSIP(m)
|
//deltas.LoadLastSeenTSIP(m)
|
||||||
deltas.LoadAddAccountType(m)
|
deltas.LoadAddAccountType(m)
|
||||||
|
deltas.LoadProfilePrimaryKey(m, serverName)
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -83,28 +83,28 @@ func (d *Database) GetAccountByPassword(
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||||
func (d *Database) GetProfileByLocalpart(
|
func (d *Database) GetProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
|
return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetAvatarURL(
|
func (d *Database) SetAvatarURL(
|
||||||
ctx context.Context, localpart string, avatarURL string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
return d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetDisplayName updates the display name of the profile associated with the given
|
// SetDisplayName updates the display name of the profile associated with the given
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetDisplayName(
|
func (d *Database) SetDisplayName(
|
||||||
ctx context.Context, localpart string, displayName string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
return d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -163,7 +163,7 @@ func (d *Database) createAccount(
|
||||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
if err = d.Profiles.InsertProfile(ctx, txn, localpart, d.ServerName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
|
|
||||||
|
func LoadProfilePrimaryKey(m *sqlutil.Migrations, s gomatrixserverlib.ServerName) {
|
||||||
|
serverName = s
|
||||||
|
m.AddMigration(UpProfilePrimaryKey, DownProfilePrimaryKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpProfilePrimaryKey(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(fmt.Sprintf(`
|
||||||
|
ALTER TABLE account_profiles RENAME TO account_profiles_tmp;
|
||||||
|
CREATE TABLE IF NOT EXISTS account_profiles (
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
servername TEXT NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
|
avatar_url TEXT,
|
||||||
|
PRIMARY KEY (localpart, servername)
|
||||||
|
);
|
||||||
|
INSERT
|
||||||
|
INTO account_profiles (
|
||||||
|
localpart, servername, display_name, avatar_url
|
||||||
|
) SELECT
|
||||||
|
localpart, '%s', display_name, avatar_url
|
||||||
|
FROM account_profiles_tmp;
|
||||||
|
DROP TABLE account_profiles_tmp;`, serverName))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownProfilePrimaryKey(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
ALTER TABLE account_profiles RENAME TO account_profiles_tmp;
|
||||||
|
CREATE TABLE IF NOT EXISTS account_profiles (
|
||||||
|
localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
|
display_name TEXT,
|
||||||
|
avatar_url TEXT
|
||||||
|
);
|
||||||
|
INSERT
|
||||||
|
INTO account_profiles (
|
||||||
|
localpart, display_name, avatar_url
|
||||||
|
) SELECT
|
||||||
|
localpart, display_name, avatar_url
|
||||||
|
FROM account_profiles_tmp;
|
||||||
|
DROP TABLE account_profiles_tmp;`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -23,34 +23,38 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
-- Stores data about accounts profiles.
|
-- Stores data about accounts profiles.
|
||||||
CREATE TABLE IF NOT EXISTS account_profiles (
|
CREATE TABLE IF NOT EXISTS account_profiles (
|
||||||
-- The Matrix user ID localpart for this account
|
-- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL,
|
||||||
|
-- The server this user belongs to
|
||||||
|
servername TEXT NOT NULL,
|
||||||
-- The display name for this account
|
-- The display name for this account
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
-- The URL of the avatar for this account
|
-- The URL of the avatar for this account
|
||||||
avatar_url TEXT
|
avatar_url TEXT,
|
||||||
|
PRIMARY KEY (localpart, servername)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertProfileSQL = "" +
|
const insertProfileSQL = "" +
|
||||||
"INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
"INSERT INTO account_profiles(localpart, display_name, avatar_url, servername) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectProfileByLocalpartSQL = "" +
|
const selectProfileByLocalpartSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
|
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1 AND servername = $2"
|
||||||
|
|
||||||
const setAvatarURLSQL = "" +
|
const setAvatarURLSQL = "" +
|
||||||
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
|
"UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2 AND servername = $3"
|
||||||
|
|
||||||
const setDisplayNameSQL = "" +
|
const setDisplayNameSQL = "" +
|
||||||
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
|
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2 AND servername = $3"
|
||||||
|
|
||||||
const selectProfilesBySearchSQL = "" +
|
const selectProfilesBySearchSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
"SELECT localpart, display_name, avatar_url, servername FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
|
|
||||||
type profilesStatements struct {
|
type profilesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
|
@ -81,17 +85,17 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) InsertProfile(
|
func (s *profilesStatements) InsertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "", serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
profile := authtypes.Profile{ServerName: string(serverName)}
|
||||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -101,18 +105,18 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||||
_, err = stmt.ExecContext(ctx, avatarURL, localpart)
|
_, err = stmt.ExecContext(ctx, avatarURL, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||||
_, err = stmt.ExecContext(ctx, displayName, localpart)
|
_, err = stmt.ExecContext(ctx, displayName, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -130,7 +134,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.ServerName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if profile.Localpart != s.serverNoticesLocalpart {
|
if profile.Localpart != s.serverNoticesLocalpart {
|
||||||
|
|
|
||||||
|
|
@ -44,9 +44,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
// preparing statements for columns that don't exist yet
|
// preparing statements for columns that don't exist yet
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if _, err = db.Exec(profilesSchema); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
deltas.LoadIsActive(m)
|
deltas.LoadIsActive(m)
|
||||||
//deltas.LoadLastSeenTSIP(m)
|
//deltas.LoadLastSeenTSIP(m)
|
||||||
deltas.LoadAddAccountType(m)
|
deltas.LoadAddAccountType(m)
|
||||||
|
deltas.LoadProfilePrimaryKey(m, serverName)
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,8 @@ import (
|
||||||
|
|
||||||
const loginTokenLifetime = time.Minute
|
const loginTokenLifetime = time.Minute
|
||||||
|
|
||||||
|
const serverName = "example.com"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
openIDLifetimeMS = time.Minute.Milliseconds()
|
openIDLifetimeMS = time.Minute.Milliseconds()
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
|
|
@ -31,7 +33,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
|
||||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
|
db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{
|
||||||
ConnectionString: config.DataSource(connStr),
|
ConnectionString: config.DataSource(connStr),
|
||||||
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
|
}, serverName, bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewUserAPIDatabase returned %s", err)
|
t.Fatalf("NewUserAPIDatabase returned %s", err)
|
||||||
}
|
}
|
||||||
|
|
@ -370,20 +372,21 @@ func Test_Profile(t *testing.T) {
|
||||||
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
|
|
||||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, serverName)
|
||||||
assert.NoError(t, err, "unable to get profile by localpart")
|
assert.NoError(t, err, "unable to get profile by localpart")
|
||||||
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
|
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart, ServerName: string(serverName)}
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
|
|
||||||
// set avatar & displayname
|
// set avatar & displayname
|
||||||
wantProfile.DisplayName = "Alice"
|
wantProfile.DisplayName = "Alice"
|
||||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||||
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
|
wantProfile.ServerName = string(serverName)
|
||||||
|
err = db.SetDisplayName(ctx, aliceLocalpart, serverName, "Alice")
|
||||||
assert.NoError(t, err, "unable to set displayname")
|
assert.NoError(t, err, "unable to set displayname")
|
||||||
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
err = db.SetAvatarURL(ctx, aliceLocalpart, serverName, "mxc://aliceAvatar")
|
||||||
assert.NoError(t, err, "unable to set avatar url")
|
assert.NoError(t, err, "unable to set avatar url")
|
||||||
// verify profile
|
// verify profile
|
||||||
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart, serverName)
|
||||||
assert.NoError(t, err, "unable to get profile by localpart")
|
assert.NoError(t, err, "unable to get profile by localpart")
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/types"
|
"github.com/matrix-org/dendrite/userapi/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AccountDataTable interface {
|
type AccountDataTable interface {
|
||||||
|
|
@ -82,10 +83,10 @@ type OpenIDTable interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProfileTable interface {
|
type ProfileTable interface {
|
||||||
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
|
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
|
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (err error)
|
||||||
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
|
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (err error)
|
||||||
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
117
userapi/storage/tables/profile_table_test.go
Normal file
117
userapi/storage/tables/profile_table_test.go
Normal file
|
|
@ -0,0 +1,117 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const serverNotice = "notice"
|
||||||
|
|
||||||
|
func mustCreateProfileTable(t *testing.T, dbType test.DBType) (tab tables.ProfileTable, close func()) {
|
||||||
|
var connStr string
|
||||||
|
connStr, close = test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
tab, err = postgres.NewPostgresProfilesTable(db, serverNotice)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
tab, err = sqlite3.NewSQLiteProfilesTable(db, serverNotice)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create profiles table: %v", err)
|
||||||
|
}
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateProfile(t *testing.T, ctx context.Context, tab tables.ProfileTable, localPart string, serverName gomatrixserverlib.ServerName) {
|
||||||
|
if err := tab.InsertProfile(ctx, nil, localPart, serverName); err != nil {
|
||||||
|
t.Fatalf("failed to insert profile: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProfileTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
serverName1 := gomatrixserverlib.ServerName("localhost")
|
||||||
|
serverName2 := gomatrixserverlib.ServerName("notlocalhost")
|
||||||
|
avatarURL := "newAvatarURL"
|
||||||
|
displayName := "newDisplayName"
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, close := mustCreateProfileTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
// Create serverNotice user
|
||||||
|
if err := tab.InsertProfile(ctx, nil, serverNotice, serverName1); err != nil {
|
||||||
|
t.Fatalf("failed to insert profile: %v", err)
|
||||||
|
}
|
||||||
|
// Ensure the a localpart is unique per serverName
|
||||||
|
if err := tab.InsertProfile(ctx, nil, serverNotice, serverName1); err == nil {
|
||||||
|
t.Fatalf("expected SQL insert to fail, but it didn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
mustCreateProfile(t, ctx, tab, "dummy1", serverName1)
|
||||||
|
mustCreateProfile(t, ctx, tab, "dummy1", serverName2)
|
||||||
|
mustCreateProfile(t, ctx, tab, "testing", serverName2)
|
||||||
|
|
||||||
|
if err := tab.SetAvatarURL(ctx, nil, "dummy1", serverName1, avatarURL); err != nil {
|
||||||
|
t.Fatalf("failed to set avatar url: %v", err)
|
||||||
|
}
|
||||||
|
if err := tab.SetDisplayName(ctx, nil, "dummy1", serverName2, displayName); err != nil {
|
||||||
|
t.Fatalf("failed to set avatar url: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify dummy1 on serverName2 is as expected, just to test the function
|
||||||
|
dummy1, err := tab.SelectProfileByLocalpart(ctx, "dummy1", serverName2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to query profile by localpart: %v", err)
|
||||||
|
}
|
||||||
|
// Make sure that only dummy1 on serverName1 got the displayName changed and avatarURL is unchanged
|
||||||
|
if dummy1.AvatarURL == avatarURL {
|
||||||
|
t.Fatalf("expected avatarURL %s, got %s", avatarURL, dummy1.AvatarURL)
|
||||||
|
}
|
||||||
|
if dummy1.DisplayName != displayName {
|
||||||
|
t.Fatalf("expected displayName %s, got %s", displayName, dummy1.DisplayName)
|
||||||
|
}
|
||||||
|
|
||||||
|
searchRes, err := tab.SelectProfilesBySearch(ctx, "dummy", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to search profiles: %v", err)
|
||||||
|
}
|
||||||
|
// serverNotice user and testing should not be returned here, only the dummy users
|
||||||
|
if count := len(searchRes); count > 2 {
|
||||||
|
t.Fatalf("expected 2 results, got %d", count)
|
||||||
|
}
|
||||||
|
for _, profile := range searchRes {
|
||||||
|
if profile.Localpart != "dummy1" {
|
||||||
|
t.Fatalf("got unexpected localpart: %v", profile.Localpart)
|
||||||
|
}
|
||||||
|
// Make sure that only dummy1 on serverName1 got the avatarURL changed and displayName is unchanged
|
||||||
|
if gomatrixserverlib.ServerName(profile.ServerName) == serverName1 && profile.AvatarURL != avatarURL {
|
||||||
|
t.Fatalf("expected avatarURL %s, got %s", avatarURL, profile.AvatarURL)
|
||||||
|
}
|
||||||
|
if gomatrixserverlib.ServerName(profile.ServerName) == serverName1 && profile.DisplayName == displayName {
|
||||||
|
t.Fatalf("expected displayName %s, got %s", displayName, profile.DisplayName)
|
||||||
|
}
|
||||||
|
// Make sure that only dummy1 on serverName1 got the displayName changed and avatarURL is unchanged
|
||||||
|
if gomatrixserverlib.ServerName(profile.ServerName) == serverName2 && profile.AvatarURL == avatarURL {
|
||||||
|
t.Fatalf("expected avatarURL %s, got %s", avatarURL, profile.AvatarURL)
|
||||||
|
}
|
||||||
|
if gomatrixserverlib.ServerName(profile.ServerName) == serverName2 && profile.DisplayName != displayName {
|
||||||
|
t.Fatalf("expected displayName %s, got %s", displayName, profile.DisplayName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue