mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-12-03 11:41:56 -06:00
Profiles tables
This commit is contained in:
parent
2e74be1bf2
commit
76ac6dbdf1
|
@ -214,7 +214,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.Localpart); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -412,7 +412,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
||||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
|
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
|
||||||
}
|
}
|
||||||
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
|
prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil
|
return nil
|
||||||
|
@ -883,7 +883,7 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
|
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
|
||||||
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
|
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
|
||||||
res.Profile = profile
|
res.Profile = profile
|
||||||
res.Changed = changed
|
res.Changed = changed
|
||||||
return err
|
return err
|
||||||
|
@ -921,7 +921,7 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
|
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
|
||||||
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
|
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
|
||||||
res.Profile = profile
|
res.Profile = profile
|
||||||
res.Changed = changed
|
res.Changed = changed
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -29,10 +29,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Profile interface {
|
type Profile interface {
|
||||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
|
SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||||
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
|
SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Account interface {
|
type Account interface {
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
|
@ -39,27 +40,27 @@ CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertProfileSQL = "" +
|
const insertProfileSQL = "" +
|
||||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectProfileByLocalpartSQL = "" +
|
const selectProfileByLocalpartSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const setAvatarURLSQL = "" +
|
const setAvatarURLSQL = "" +
|
||||||
"UPDATE userapi_profiles AS new" +
|
"UPDATE userapi_profiles AS new" +
|
||||||
" SET avatar_url = $1" +
|
" SET avatar_url = $1" +
|
||||||
" FROM userapi_profiles AS old" +
|
" FROM userapi_profiles AS old" +
|
||||||
" WHERE new.localpart = $2" +
|
" WHERE new.localpart = $2 AND new.server_name = $3" +
|
||||||
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
|
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
|
||||||
|
|
||||||
const setDisplayNameSQL = "" +
|
const setDisplayNameSQL = "" +
|
||||||
"UPDATE userapi_profiles AS new" +
|
"UPDATE userapi_profiles AS new" +
|
||||||
" SET display_name = $1" +
|
" SET display_name = $1" +
|
||||||
" FROM userapi_profiles AS old" +
|
" FROM userapi_profiles AS old" +
|
||||||
" WHERE new.localpart = $2" +
|
" WHERE new.localpart = $2 AND new.server_name = $3" +
|
||||||
" RETURNING new.avatar_url, old.display_name <> new.display_name"
|
" RETURNING new.avatar_url, old.display_name <> new.display_name"
|
||||||
|
|
||||||
const selectProfilesBySearchSQL = "" +
|
const selectProfilesBySearchSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
|
|
||||||
type profilesStatements struct {
|
type profilesStatements struct {
|
||||||
serverNoticesLocalpart string
|
serverNoticesLocalpart string
|
||||||
|
@ -88,18 +89,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) InsertProfile(
|
func (s *profilesStatements) InsertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -108,28 +111,34 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
avatarURL string,
|
||||||
) (*authtypes.Profile, bool, error) {
|
) (*authtypes.Profile, bool, error) {
|
||||||
profile := &authtypes.Profile{
|
profile := &authtypes.Profile{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: string(serverName),
|
||||||
AvatarURL: avatarURL,
|
AvatarURL: avatarURL,
|
||||||
}
|
}
|
||||||
var changed bool
|
var changed bool
|
||||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||||
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
|
err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed)
|
||||||
return profile, changed, err
|
return profile, changed, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
displayName string,
|
||||||
) (*authtypes.Profile, bool, error) {
|
) (*authtypes.Profile, bool, error) {
|
||||||
profile := &authtypes.Profile{
|
profile := &authtypes.Profile{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: string(serverName),
|
||||||
DisplayName: displayName,
|
DisplayName: displayName,
|
||||||
}
|
}
|
||||||
var changed bool
|
var changed bool
|
||||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||||
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
|
err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed)
|
||||||
return profile, changed, err
|
return profile, changed, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,7 +156,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if profile.Localpart != s.serverNoticesLocalpart {
|
if profile.Localpart != s.serverNoticesLocalpart {
|
||||||
|
|
|
@ -87,18 +87,21 @@ func (d *Database) GetAccountByPassword(
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||||
func (d *Database) GetProfileByLocalpart(
|
func (d *Database) GetProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
|
return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetAvatarURL(
|
func (d *Database) SetAvatarURL(
|
||||||
ctx context.Context, localpart string, avatarURL string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
avatarURL string,
|
||||||
) (profile *authtypes.Profile, changed bool, err error) {
|
) (profile *authtypes.Profile, changed bool, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -107,10 +110,12 @@ func (d *Database) SetAvatarURL(
|
||||||
// SetDisplayName updates the display name of the profile associated with the given
|
// SetDisplayName updates the display name of the profile associated with the given
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetDisplayName(
|
func (d *Database) SetDisplayName(
|
||||||
ctx context.Context, localpart string, displayName string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
displayName string,
|
||||||
) (profile *authtypes.Profile, changed bool, err error) {
|
) (profile *authtypes.Profile, changed bool, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -175,7 +180,7 @@ func (d *Database) createAccount(
|
||||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
|
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
|
@ -39,21 +40,21 @@ CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertProfileSQL = "" +
|
const insertProfileSQL = "" +
|
||||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectProfileByLocalpartSQL = "" +
|
const selectProfileByLocalpartSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const setAvatarURLSQL = "" +
|
const setAvatarURLSQL = "" +
|
||||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
|
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||||
" RETURNING display_name"
|
" RETURNING display_name"
|
||||||
|
|
||||||
const setDisplayNameSQL = "" +
|
const setDisplayNameSQL = "" +
|
||||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
|
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||||
" RETURNING avatar_url"
|
" RETURNING avatar_url"
|
||||||
|
|
||||||
const selectProfilesBySearchSQL = "" +
|
const selectProfilesBySearchSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
|
|
||||||
type profilesStatements struct {
|
type profilesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -84,18 +85,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) InsertProfile(
|
func (s *profilesStatements) InsertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -104,13 +107,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
avatarURL string,
|
||||||
) (*authtypes.Profile, bool, error) {
|
) (*authtypes.Profile, bool, error) {
|
||||||
profile := &authtypes.Profile{
|
profile := &authtypes.Profile{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: string(serverName),
|
||||||
AvatarURL: avatarURL,
|
AvatarURL: avatarURL,
|
||||||
}
|
}
|
||||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return old, false, err
|
return old, false, err
|
||||||
}
|
}
|
||||||
|
@ -118,18 +124,20 @@ func (s *profilesStatements) SetAvatarURL(
|
||||||
return old, false, nil
|
return old, false, nil
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||||
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
|
err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName)
|
||||||
return profile, true, err
|
return profile, true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
displayName string,
|
||||||
) (*authtypes.Profile, bool, error) {
|
) (*authtypes.Profile, bool, error) {
|
||||||
profile := &authtypes.Profile{
|
profile := &authtypes.Profile{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
DisplayName: displayName,
|
DisplayName: displayName,
|
||||||
}
|
}
|
||||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return old, false, err
|
return old, false, err
|
||||||
}
|
}
|
||||||
|
@ -137,7 +145,7 @@ func (s *profilesStatements) SetDisplayName(
|
||||||
return old, false, nil
|
return old, false, nil
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||||
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
|
err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL)
|
||||||
return profile, true, err
|
return profile, true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,7 +163,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if profile.Localpart != s.serverNoticesLocalpart {
|
if profile.Localpart != s.serverNoticesLocalpart {
|
||||||
|
|
|
@ -375,27 +375,27 @@ func Test_Profile(t *testing.T) {
|
||||||
_, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
_, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
|
|
||||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to get profile by localpart")
|
assert.NoError(t, err, "unable to get profile by localpart")
|
||||||
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
|
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
|
|
||||||
// set avatar & displayname
|
// set avatar & displayname
|
||||||
wantProfile.DisplayName = "Alice"
|
wantProfile.DisplayName = "Alice"
|
||||||
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
|
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice")
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
assert.NoError(t, err, "unable to set displayname")
|
assert.NoError(t, err, "unable to set displayname")
|
||||||
assert.True(t, changed)
|
assert.True(t, changed)
|
||||||
|
|
||||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
|
||||||
assert.NoError(t, err, "unable to set avatar url")
|
assert.NoError(t, err, "unable to set avatar url")
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
assert.True(t, changed)
|
assert.True(t, changed)
|
||||||
|
|
||||||
// Setting the same avatar again doesn't change anything
|
// Setting the same avatar again doesn't change anything
|
||||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
|
||||||
assert.NoError(t, err, "unable to set avatar url")
|
assert.NoError(t, err, "unable to set avatar url")
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
assert.False(t, changed)
|
assert.False(t, changed)
|
||||||
|
|
|
@ -84,10 +84,10 @@ type OpenIDTable interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProfileTable interface {
|
type ProfileTable interface {
|
||||||
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
|
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
|
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||||
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
|
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
|
||||||
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -84,10 +84,10 @@ func TestQueryProfile(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to make account: %s", err)
|
t.Fatalf("failed to make account: %s", err)
|
||||||
}
|
}
|
||||||
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
|
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil {
|
||||||
t.Fatalf("failed to set avatar url: %s", err)
|
t.Fatalf("failed to set avatar url: %s", err)
|
||||||
}
|
}
|
||||||
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
|
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil {
|
||||||
t.Fatalf("failed to set display name: %s", err)
|
t.Fatalf("failed to set display name: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue