Account data
This commit is contained in:
parent
7fd2c10975
commit
d1c61f5f95
|
@ -491,7 +491,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
|||
}
|
||||
|
||||
// Get accountdata to check if the event.Sender() is ignored by mem.LocalPart
|
||||
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list")
|
||||
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -506,7 +506,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
|||
return nil, fmt.Errorf("user %s is ignored", sender)
|
||||
}
|
||||
}
|
||||
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart)
|
||||
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
|||
if req.DataType == "" {
|
||||
return fmt.Errorf("data type must not be empty")
|
||||
}
|
||||
if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil {
|
||||
if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
|
||||
return fmt.Errorf("failed to save account data: %w", err)
|
||||
}
|
||||
|
@ -176,7 +176,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
serverName = a.Config.Matrix.ServerName
|
||||
}
|
||||
// XXXX: Use the server name here
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
switch req.OnConflict {
|
||||
|
@ -476,7 +476,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||
}
|
||||
if req.DataType != "" {
|
||||
var data json.RawMessage
|
||||
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||
data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -494,7 +494,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||
}
|
||||
return nil
|
||||
}
|
||||
global, rooms, err := a.DB.GetAccountData(ctx, local)
|
||||
global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -864,11 +864,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
|
||||
}
|
||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart)
|
||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query push rules: %w", err)
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ type Account interface {
|
|||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||
|
@ -49,14 +49,14 @@ type Account interface {
|
|||
}
|
||||
|
||||
type AccountData interface {
|
||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
// localpart, room ID and type.
|
||||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error)
|
||||
GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
|
||||
QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error)
|
||||
}
|
||||
|
||||
type Device interface {
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
@ -42,15 +43,15 @@ CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
|||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
|
||||
|
||||
type accountDataStatements struct {
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
|
@ -72,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -118,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var serverNamesTables = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_devices",
|
||||
"userapi_notifications",
|
||||
"userapi_openid_tokens",
|
||||
"userapi_profiles",
|
||||
"userapi_pushers",
|
||||
"userapi_threepids",
|
||||
}
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("add server name to %q error: %w", table, err)
|
||||
}
|
||||
q = fmt.Sprintf(
|
||||
"UPDATE %s SET server_name = %s WHERE server_name = '';",
|
||||
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("write server names to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownServerNames(ctx context.Context, tx *sql.Tx) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s DELETE COLUMN server_name;",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("remove server name from %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -15,6 +15,8 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -43,6 +45,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNames(ctx, txn, serverName)
|
||||
},
|
||||
Down: deltas.DownServerNames,
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -132,7 +132,8 @@ func (d *Database) SetPassword(
|
|||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (acc *api.Account, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
// For guest accounts, we create a new numeric local part
|
||||
|
@ -146,7 +147,7 @@ func (d *Database) CreateAccount(
|
|||
plaintextPassword = ""
|
||||
appserviceID = ""
|
||||
}
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
@ -155,7 +156,9 @@ func (d *Database) CreateAccount(
|
|||
// WARNING! This function assumes that the relevant mutexes have already
|
||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
var err error
|
||||
var account *api.Account
|
||||
|
@ -178,7 +181,7 @@ func (d *Database) createAccount(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
|
@ -186,9 +189,9 @@ func (d *Database) createAccount(
|
|||
|
||||
func (d *Database) QueryPushRules(
|
||||
ctx context.Context,
|
||||
localpart string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*pushrules.AccountRuleSets, error) {
|
||||
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules")
|
||||
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -202,7 +205,7 @@ func (d *Database) QueryPushRules(
|
|||
return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
|
||||
}
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil {
|
||||
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil {
|
||||
return fmt.Errorf("failed to save default push rules: %w", dbErr)
|
||||
}
|
||||
return nil
|
||||
|
@ -225,22 +228,23 @@ func (d *Database) QueryPushRules(
|
|||
// update the corresponding row with the new content
|
||||
// Returns a SQL error if there was an issue with the insertion/update
|
||||
func (d *Database) SaveAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content)
|
||||
})
|
||||
}
|
||||
|
||||
// GetAccountData returns account data related to a given localpart
|
||||
// If no account data could be found, returns an empty arrays
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
|
||||
global map[string]json.RawMessage,
|
||||
rooms map[string]map[string]json.RawMessage,
|
||||
err error,
|
||||
) {
|
||||
return d.AccountDatas.SelectAccountData(ctx, localpart)
|
||||
return d.AccountDatas.SelectAccountData(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
|
@ -248,10 +252,11 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
return d.AccountDatas.SelectAccountDataByType(
|
||||
ctx, localpart, roomID, dataType,
|
||||
ctx, localpart, serverName, roomID, dataType,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
@ -41,15 +42,15 @@ CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
|||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -74,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -118,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var serverNamesTables = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_devices",
|
||||
"userapi_notifications",
|
||||
"userapi_openid_tokens",
|
||||
"userapi_profiles",
|
||||
"userapi_pushers",
|
||||
"userapi_threepids",
|
||||
}
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("add server name to %q error: %w", table, err)
|
||||
}
|
||||
q = fmt.Sprintf(
|
||||
"UPDATE %s SET server_name = %s WHERE server_name = '';",
|
||||
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("write server names to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownServerNames(ctx context.Context, tx *sql.Tx) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s DELETE COLUMN server_name;",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("remove server name from %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -15,6 +15,8 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -41,6 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNames(ctx, txn, serverName)
|
||||
},
|
||||
Down: deltas.DownServerNames,
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) {
|
|||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
room := test.NewRoom(t, alice)
|
||||
events := room.Events()
|
||||
|
||||
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
|
||||
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
|
||||
err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
|
||||
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||
err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
|
||||
accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read")
|
||||
assert.NoError(t, err, "unable to get account data by type")
|
||||
assert.Equal(t, contentRoom, accountData)
|
||||
|
||||
globalData, roomData, err := db.GetAccountData(ctx, localpart)
|
||||
globalData, roomData, err := db.GetAccountData(ctx, localpart, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
|
||||
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
|
||||
|
@ -81,10 +81,10 @@ func Test_Accounts(t *testing.T) {
|
|||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
// verify the newly create account is the same as returned by CreateAccount
|
||||
var accGet *api.Account
|
||||
|
@ -108,7 +108,7 @@ func Test_Accounts(t *testing.T) {
|
|||
first, err := db.GetNewNumericLocalpart(ctx)
|
||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||
// Create a new account to verify the numeric localpart is updated
|
||||
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
second, err := db.GetNewNumericLocalpart(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
@ -133,19 +133,19 @@ func Test_Accounts(t *testing.T) {
|
|||
|
||||
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
|
||||
// if there's already a user without a localpart in the database
|
||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test getting a numeric localpart, with an existing user without a localpart
|
||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
|
||||
_, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser)
|
||||
_, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Now try to create a new guest user
|
||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) {
|
|||
|
||||
func Test_Profile(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
|
@ -372,7 +372,7 @@ func Test_Profile(t *testing.T) {
|
|||
defer close()
|
||||
|
||||
// create account, which also creates a profile
|
||||
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
_, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
||||
|
|
|
@ -28,9 +28,9 @@ import (
|
|||
)
|
||||
|
||||
type AccountDataTable interface {
|
||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||
SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||
SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
|
||||
}
|
||||
|
||||
type AccountsTable interface {
|
||||
|
|
|
@ -80,7 +80,7 @@ func TestQueryProfile(t *testing.T) {
|
|||
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService)
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) {
|
|||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue