Account data

This commit is contained in:
Neil Alexander 2022-11-04 12:09:49 +00:00
parent 7fd2c10975
commit d1c61f5f95
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
13 changed files with 217 additions and 64 deletions

View file

@ -491,7 +491,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
} }
// Get accountdata to check if the event.Sender() is ignored by mem.LocalPart // Get accountdata to check if the event.Sender() is ignored by mem.LocalPart
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list") data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -506,7 +506,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
return nil, fmt.Errorf("user %s is ignored", sender) return nil, fmt.Errorf("user %s is ignored", sender)
} }
} }
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart) ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" { if req.DataType == "" {
return fmt.Errorf("data type must not be empty") return fmt.Errorf("data type must not be empty")
} }
if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil { if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed") util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
return fmt.Errorf("failed to save account data: %w", err) return fmt.Errorf("failed to save account data: %w", err)
} }
@ -176,7 +176,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
serverName = a.Config.Matrix.ServerName serverName = a.Config.Matrix.ServerName
} }
// XXXX: Use the server name here // XXXX: Use the server name here
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict { switch req.OnConflict {
@ -476,7 +476,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
if req.DataType != "" { if req.DataType != "" {
var data json.RawMessage var data json.RawMessage
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
if err != nil { if err != nil {
return err return err
} }
@ -494,7 +494,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
return nil return nil
} }
global, rooms, err := a.DB.GetAccountData(ctx, local) global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
if err != nil { if err != nil {
return err return err
} }
@ -864,11 +864,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
} }
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
} }
pushRules, err := a.DB.QueryPushRules(ctx, localpart) pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
if err != nil { if err != nil {
return fmt.Errorf("failed to query push rules: %w", err) return fmt.Errorf("failed to query push rules: %w", err)
} }

View file

@ -39,7 +39,7 @@ type Account interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context) (int64, error) GetNewNumericLocalpart(ctx context.Context) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
@ -49,14 +49,14 @@ type Account interface {
} }
type AccountData interface { type AccountData interface {
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
// localpart, room ID and type. // localpart, room ID and type.
// If no account data could be found, returns nil // If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error) QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error)
} }
type Device interface { type Device interface {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -42,15 +43,15 @@ CREATE TABLE IF NOT EXISTS userapi_account_datas (
` `
const insertAccountDataSQL = ` const insertAccountDataSQL = `
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
` `
const selectAccountDataSQL = "" + const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
const selectAccountDataByTypeSQL = "" + const selectAccountDataByTypeSQL = "" +
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
type accountDataStatements struct { type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt
@ -72,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
} }
func (s *accountDataStatements) InsertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
return return
} }
func (s *accountDataStatements) SelectAccountData( func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) ( ) (
/* global */ map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage,
error, error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -118,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
} }
func (s *accountDataStatements) SelectAccountDataByType( func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte
stmt := s.selectAccountDataByTypeStmt stmt := s.selectAccountDataByTypeStmt
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }

View file

@ -0,0 +1,59 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
var serverNamesTables = []string{
"userapi_accounts",
"userapi_account_datas",
"userapi_devices",
"userapi_notifications",
"userapi_openid_tokens",
"userapi_profiles",
"userapi_pushers",
"userapi_threepids",
}
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
pq.QuoteIdentifier(table),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("add server name to %q error: %w", table, err)
}
q = fmt.Sprintf(
"UPDATE %s SET server_name = %s WHERE server_name = '';",
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("write server names to %q error: %w", table, err)
}
}
return nil
}
func DownServerNames(ctx context.Context, tx *sql.Tx) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s DELETE COLUMN server_name;",
pq.QuoteIdentifier(table),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("remove server name from %q error: %w", table, err)
}
}
return nil
}

View file

@ -15,6 +15,8 @@
package postgres package postgres
import ( import (
"context"
"database/sql"
"fmt" "fmt"
"time" "time"
@ -43,6 +45,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables, Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables, Down: deltas.DownRenameTables,
}) })
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
Down: deltas.DownServerNames,
})
if err = m.Up(base.Context()); err != nil { if err = m.Up(base.Context()); err != nil {
return nil, err return nil, err
} }

View file

@ -132,7 +132,8 @@ func (d *Database) SetPassword(
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part // For guest accounts, we create a new numeric local part
@ -146,7 +147,7 @@ func (d *Database) CreateAccount(
plaintextPassword = "" plaintextPassword = ""
appserviceID = "" appserviceID = ""
} }
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType)
return err return err
}) })
return return
@ -155,7 +156,9 @@ func (d *Database) CreateAccount(
// WARNING! This function assumes that the relevant mutexes have already // WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount( func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
var err error var err error
var account *api.Account var account *api.Account
@ -178,7 +181,7 @@ func (d *Database) createAccount(
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err return nil, err
} }
return account, nil return account, nil
@ -186,9 +189,9 @@ func (d *Database) createAccount(
func (d *Database) QueryPushRules( func (d *Database) QueryPushRules(
ctx context.Context, ctx context.Context,
localpart string, localpart string, serverName gomatrixserverlib.ServerName,
) (*pushrules.AccountRuleSets, error) { ) (*pushrules.AccountRuleSets, error) {
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules") data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -202,7 +205,7 @@ func (d *Database) QueryPushRules(
return nil, fmt.Errorf("failed to marshal default push rules: %w", err) return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
} }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil { if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil {
return fmt.Errorf("failed to save default push rules: %w", dbErr) return fmt.Errorf("failed to save default push rules: %w", dbErr)
} }
return nil return nil
@ -225,22 +228,23 @@ func (d *Database) QueryPushRules(
// update the corresponding row with the new content // update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update // Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string, content json.RawMessage,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content) return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content)
}) })
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays // If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) ( func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
global map[string]json.RawMessage, global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage,
err error, err error,
) { ) {
return d.AccountDatas.SelectAccountData(ctx, localpart) return d.AccountDatas.SelectAccountData(ctx, localpart, serverName)
} }
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
@ -248,10 +252,11 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// If no account data could be found, returns nil // If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType( func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
return d.AccountDatas.SelectAccountDataByType( return d.AccountDatas.SelectAccountDataByType(
ctx, localpart, roomID, dataType, ctx, localpart, serverName, roomID, dataType,
) )
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -41,15 +42,15 @@ CREATE TABLE IF NOT EXISTS userapi_account_datas (
` `
const insertAccountDataSQL = ` const insertAccountDataSQL = `
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
` `
const selectAccountDataSQL = "" + const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
const selectAccountDataByTypeSQL = "" + const selectAccountDataByTypeSQL = "" +
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB db *sql.DB
@ -74,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
} }
func (s *accountDataStatements) InsertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string, content json.RawMessage,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
return err return err
} }
func (s *accountDataStatements) SelectAccountData( func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) ( ) (
/* global */ map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage,
error, error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -118,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
} }
func (s *accountDataStatements) SelectAccountDataByType( func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte
stmt := s.selectAccountDataByTypeStmt stmt := s.selectAccountDataByTypeStmt
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }

View file

@ -0,0 +1,59 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
var serverNamesTables = []string{
"userapi_accounts",
"userapi_account_datas",
"userapi_devices",
"userapi_notifications",
"userapi_openid_tokens",
"userapi_profiles",
"userapi_pushers",
"userapi_threepids",
}
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
pq.QuoteIdentifier(table),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("add server name to %q error: %w", table, err)
}
q = fmt.Sprintf(
"UPDATE %s SET server_name = %s WHERE server_name = '';",
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("write server names to %q error: %w", table, err)
}
}
return nil
}
func DownServerNames(ctx context.Context, tx *sql.Tx) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s DELETE COLUMN server_name;",
pq.QuoteIdentifier(table),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("remove server name from %q error: %w", table, err)
}
}
return nil
}

View file

@ -15,6 +15,8 @@
package sqlite3 package sqlite3
import ( import (
"context"
"database/sql"
"fmt" "fmt"
"time" "time"
@ -41,6 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables, Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables, Down: deltas.DownRenameTables,
}) })
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
Down: deltas.DownServerNames,
})
if err = m.Up(base.Context()); err != nil { if err = m.Up(base.Context()); err != nil {
return nil, err return nil, err
} }

View file

@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser(t) alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
events := room.Events() events := room.Events()
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID())) contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom) err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom)
assert.NoError(t, err, "unable to save account data") assert.NoError(t, err, "unable to save account data")
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID)) contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal) err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal)
assert.NoError(t, err, "unable to save account data") assert.NoError(t, err, "unable to save account data")
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read") accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read")
assert.NoError(t, err, "unable to get account data by type") assert.NoError(t, err, "unable to get account data by type")
assert.Equal(t, contentRoom, accountData) assert.Equal(t, contentRoom, accountData)
globalData, roomData, err := db.GetAccountData(ctx, localpart) globalData, roomData, err := db.GetAccountData(ctx, localpart, domain)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"]) assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"]) assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
@ -81,10 +81,10 @@ func Test_Accounts(t *testing.T) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount // verify the newly create account is the same as returned by CreateAccount
var accGet *api.Account var accGet *api.Account
@ -108,7 +108,7 @@ func Test_Accounts(t *testing.T) {
first, err := db.GetNewNumericLocalpart(ctx) first, err := db.GetNewNumericLocalpart(ctx)
assert.NoError(t, err, "failed to get new numeric localpart") assert.NoError(t, err, "failed to get new numeric localpart")
// Create a new account to verify the numeric localpart is updated // Create a new account to verify the numeric localpart is updated
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) _, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
second, err := db.GetNewNumericLocalpart(ctx) second, err := db.GetNewNumericLocalpart(ctx)
assert.NoError(t, err) assert.NoError(t, err)
@ -133,19 +133,19 @@ func Test_Accounts(t *testing.T) {
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart // create an empty localpart; this should never happen, but is required to test getting a numeric localpart
// if there's already a user without a localpart in the database // if there's already a user without a localpart in the database
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser) _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser)
assert.NoError(t, err) assert.NoError(t, err)
// test getting a numeric localpart, with an existing user without a localpart // test getting a numeric localpart, with an existing user without a localpart
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
assert.NoError(t, err) assert.NoError(t, err)
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type // Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
_, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser) _, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser)
assert.NoError(t, err) assert.NoError(t, err)
// Now try to create a new guest user // Now try to create a new guest user
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) {
func Test_Profile(t *testing.T) { func Test_Profile(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -372,7 +372,7 @@ func Test_Profile(t *testing.T) {
defer close() defer close()
// create account, which also creates a profile // create account, which also creates a profile
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) _, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)

View file

@ -28,9 +28,9 @@ import (
) )
type AccountDataTable interface { type AccountDataTable interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
} }
type AccountsTable interface { type AccountsTable interface {

View file

@ -80,7 +80,7 @@ func TestQueryProfile(t *testing.T) {
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
defer close() defer close()
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }
@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close() defer close()
_, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService) _, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }
@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close() defer close()
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) _, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }