diff --git a/appservice/appservice.go b/appservice/appservice.go index 0c778b6ca..b3c28dbde 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -32,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) // AddInternalRoutes registers HTTP handlers for internal API calls @@ -74,7 +75,7 @@ func NewInternalAPI( // events to be sent out. for _, appservice := range base.Cfg.Derived.ApplicationServices { // Create bot account for this AS if it doesn't already exist - if err := generateAppServiceAccount(userAPI, appservice); err != nil { + if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil { logrus.WithFields(logrus.Fields{ "appservice": appservice.ID, }).WithError(err).Panicf("failed to generate bot account for appservice") @@ -101,11 +102,13 @@ func NewInternalAPI( func generateAppServiceAccount( userAPI userapi.AppserviceUserAPI, as config.ApplicationService, + serverName gomatrixserverlib.ServerName, ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeAppService, Localpart: as.SenderLocalpart, + ServerName: serverName, AppServiceID: as.ID, OnConflict: userapi.ConflictUpdate, }, &accRes) @@ -115,6 +118,7 @@ func generateAppServiceAccount( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ Localpart: as.SenderLocalpart, + ServerName: serverName, AccessToken: as.ASToken, DeviceID: &as.SenderLocalpart, DeviceDisplayName: &as.SenderLocalpart, diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 1c0de2f0c..300067243 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -175,7 +175,6 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P if serverName == "" { serverName = a.Config.Matrix.ServerName } - // XXXX: Use the server name here acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists @@ -245,13 +244,12 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe serverName = a.Config.Matrix.ServerName } _ = serverName - // XXXX: Use the server name here util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") - dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) + dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 6413f88fb..d68a8b57d 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -70,7 +70,7 @@ type Device interface { // an error will be returned. // If no device ID is given one is generated. // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 754880a06..9b30b5d8f 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/matrix-org/gomatrixserverlib" @@ -67,7 +68,7 @@ const selectPasswordHashSQL = "" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $2" + "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1" type accountsStatements struct { insertAccountStmt *sql.Stmt @@ -132,7 +133,7 @@ func (s *accountsStatements) InsertAccount( _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { - return nil, err + return nil, fmt.Errorf("insertAccountStmt: %w", err) } return &api.Account{ diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 3982d1f21..2b90a69b2 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/lib/pq" @@ -70,7 +71,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_dev ` const insertDeviceSQL = "" + - "INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + + "INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + " RETURNING session_id" const selectDeviceByTokenSQL = "" + @@ -86,7 +87,7 @@ const updateDeviceNameSQL = "" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" const deleteDeviceSQL = "" + - "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" @@ -149,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) InsertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { - return nil, err + if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { + return nil, fmt.Errorf("insertDeviceStmt: %w", err) } return &api.Device{ ID: id, - UserID: userutil.MakeUserID(localpart, s.serverName), + UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, @@ -171,10 +173,11 @@ func (s *devicesStatements) InsertDevice( // deleteDevice removes a single device by id and user localpart. func (s *devicesStatements) DeleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + _, err := stmt.ExecContext(ctx, id, localpart, serverName) return err } diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 4f72cd85f..45607b3f4 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -45,13 +45,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) - m.AddMigrations(sqlutil.Migration{ - Version: "userapi: server names", - Up: func(ctx context.Context, txn *sql.Tx) error { - return deltas.UpServerNames(ctx, txn, serverName) - }, - Down: deltas.DownServerNames, - }) if err = m.Up(base.Context()); err != nil { return nil, err } @@ -104,6 +97,19 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) } + + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + Down: deltas.DownServerNames, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index f2cfe021d..88e821532 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -569,8 +569,8 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap // If no device ID is given one is generated. // Returns the device on success. func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -580,7 +580,7 @@ func (d *Database) CreateDevice( return err } - dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) return err }) } else { @@ -595,7 +595,7 @@ func (d *Database) CreateDevice( returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error - dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) return err }) if returnErr == nil { diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 69bb86060..832abf36d 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -52,8 +52,8 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( ` const insertDeviceSQL = "" + - "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + "INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" const selectDevicesCountSQL = "" + "SELECT COUNT(access_token) FROM userapi_devices" @@ -136,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) InsertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 @@ -147,7 +148,7 @@ func (s *devicesStatements) InsertDevice( return nil, err } sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { + if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } return &api.Device{ diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index ada077ce1..22798f029 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -152,7 +152,7 @@ func Test_Accounts(t *testing.T) { func Test_Devices(t *testing.T) { alice := test.NewUser(t) - localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) deviceID := util.RandomString(8) accessToken := util.RandomString(16) @@ -161,7 +161,7 @@ func Test_Devices(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() - deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "") + deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID) @@ -174,7 +174,7 @@ func Test_Devices(t *testing.T) { // create a device without existing device ID accessToken = util.RandomString(16) - deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "") + deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID) assert.NoError(t, err, "unable to get device by id") @@ -213,7 +213,7 @@ func Test_Devices(t *testing.T) { // create one more device and remove the devices step by step newDeviceID := util.RandomString(16) accessToken = util.RandomString(16) - _, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "") + _, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create new device") devices, err = db.GetDevicesByLocalpart(ctx, localpart) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 32b7c30c6..1747e9256 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -43,7 +43,7 @@ type AccountsTable interface { } type DevicesTable interface { - InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index ae2915bc4..5ade9351c 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -79,6 +79,7 @@ func mustMakeAccountAndDevice( accDB tables.AccountsTable, devDB tables.DevicesTable, localpart string, + serverName gomatrixserverlib.ServerName, accType api.AccountType, userAgent string, ) { @@ -93,7 +94,7 @@ func mustMakeAccountAndDevice( if err != nil { t.Fatalf("unable to create account: %v", err) } - _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent) + _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent) if err != nil { t.Fatalf("unable to create device: %v", err) } @@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) { }) t.Run("Want Users", func(t *testing.T) { - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko") gotStats, _, err := statsDB.UserStatistics(ctx, nil) if err != nil { t.Fatalf("unexpected error: %v", err)