More stuff
This commit is contained in:
parent
a0cc4c806c
commit
da7f2a7047
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// AddInternalRoutes registers HTTP handlers for internal API calls
|
||||
|
@ -74,7 +75,7 @@ func NewInternalAPI(
|
|||
// events to be sent out.
|
||||
for _, appservice := range base.Cfg.Derived.ApplicationServices {
|
||||
// Create bot account for this AS if it doesn't already exist
|
||||
if err := generateAppServiceAccount(userAPI, appservice); err != nil {
|
||||
if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"appservice": appservice.ID,
|
||||
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
||||
|
@ -101,11 +102,13 @@ func NewInternalAPI(
|
|||
func generateAppServiceAccount(
|
||||
userAPI userapi.AppserviceUserAPI,
|
||||
as config.ApplicationService,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
var accRes userapi.PerformAccountCreationResponse
|
||||
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
|
||||
AccountType: userapi.AccountTypeAppService,
|
||||
Localpart: as.SenderLocalpart,
|
||||
ServerName: serverName,
|
||||
AppServiceID: as.ID,
|
||||
OnConflict: userapi.ConflictUpdate,
|
||||
}, &accRes)
|
||||
|
@ -115,6 +118,7 @@ func generateAppServiceAccount(
|
|||
var devRes userapi.PerformDeviceCreationResponse
|
||||
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
|
||||
Localpart: as.SenderLocalpart,
|
||||
ServerName: serverName,
|
||||
AccessToken: as.ASToken,
|
||||
DeviceID: &as.SenderLocalpart,
|
||||
DeviceDisplayName: &as.SenderLocalpart,
|
||||
|
|
|
@ -175,7 +175,6 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
if serverName == "" {
|
||||
serverName = a.Config.Matrix.ServerName
|
||||
}
|
||||
// XXXX: Use the server name here
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
|
@ -245,13 +244,12 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
|||
serverName = a.Config.Matrix.ServerName
|
||||
}
|
||||
_ = serverName
|
||||
// XXXX: Use the server name here
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"localpart": req.Localpart,
|
||||
"device_id": req.DeviceID,
|
||||
"display_name": req.DeviceDisplayName,
|
||||
}).Info("PerformDeviceCreation")
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -70,7 +70,7 @@ type Device interface {
|
|||
// an error will be returned.
|
||||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
|
||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||
|
|
|
@ -17,6 +17,7 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -67,7 +68,7 @@ const selectPasswordHashSQL = "" +
|
|||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $2"
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1"
|
||||
|
||||
type accountsStatements struct {
|
||||
insertAccountStmt *sql.Stmt
|
||||
|
@ -132,7 +133,7 @@ func (s *accountsStatements) InsertAccount(
|
|||
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("insertAccountStmt: %w", err)
|
||||
}
|
||||
|
||||
return &api.Account{
|
||||
|
|
|
@ -17,6 +17,7 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
@ -70,7 +71,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_dev
|
|||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
|
||||
"INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
|
||||
" RETURNING session_id"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
|
@ -86,7 +87,7 @@ const updateDeviceNameSQL = "" +
|
|||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||
|
||||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
||||
|
@ -149,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
|
|||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||
return nil, err
|
||||
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
|
||||
}
|
||||
return &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
SessionID: sessionID,
|
||||
LastSeenTS: createdTimeMS,
|
||||
|
@ -171,10 +173,11 @@ func (s *devicesStatements) InsertDevice(
|
|||
|
||||
// deleteDevice removes a single device by id and user localpart.
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -45,13 +45,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNames(ctx, txn, serverName)
|
||||
},
|
||||
Down: deltas.DownServerNames,
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -104,6 +97,19 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNames(ctx, txn, serverName)
|
||||
},
|
||||
Down: deltas.DownServerNames,
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
|
|
|
@ -569,8 +569,8 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
|
|||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
|
@ -580,7 +580,7 @@ func (d *Database) CreateDevice(
|
|||
return err
|
||||
}
|
||||
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
|
@ -595,7 +595,7 @@ func (d *Database) CreateDevice(
|
|||
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
if returnErr == nil {
|
||||
|
|
|
@ -52,8 +52,8 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
|||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
"INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
|
||||
|
||||
const selectDevicesCountSQL = "" +
|
||||
"SELECT COUNT(access_token) FROM userapi_devices"
|
||||
|
@ -136,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
|
@ -147,7 +148,7 @@ func (s *devicesStatements) InsertDevice(
|
|||
return nil, err
|
||||
}
|
||||
sessionID++
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &api.Device{
|
||||
|
|
|
@ -152,7 +152,7 @@ func Test_Accounts(t *testing.T) {
|
|||
|
||||
func Test_Devices(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
deviceID := util.RandomString(8)
|
||||
accessToken := util.RandomString(16)
|
||||
|
@ -161,7 +161,7 @@ func Test_Devices(t *testing.T) {
|
|||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
|
||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
||||
|
@ -174,7 +174,7 @@ func Test_Devices(t *testing.T) {
|
|||
|
||||
// create a device without existing device ID
|
||||
accessToken = util.RandomString(16)
|
||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
|
||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
|
@ -213,7 +213,7 @@ func Test_Devices(t *testing.T) {
|
|||
// create one more device and remove the devices step by step
|
||||
newDeviceID := util.RandomString(16)
|
||||
accessToken = util.RandomString(16)
|
||||
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
|
||||
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create new device")
|
||||
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||
|
|
|
@ -43,7 +43,7 @@ type AccountsTable interface {
|
|||
}
|
||||
|
||||
type DevicesTable interface {
|
||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
|
||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
|
||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
|
||||
|
|
|
@ -79,6 +79,7 @@ func mustMakeAccountAndDevice(
|
|||
accDB tables.AccountsTable,
|
||||
devDB tables.DevicesTable,
|
||||
localpart string,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
accType api.AccountType,
|
||||
userAgent string,
|
||||
) {
|
||||
|
@ -93,7 +94,7 @@ func mustMakeAccountAndDevice(
|
|||
if err != nil {
|
||||
t.Fatalf("unable to create account: %v", err)
|
||||
}
|
||||
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent)
|
||||
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create device: %v", err)
|
||||
}
|
||||
|
@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Want Users", func(t *testing.T) {
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko")
|
||||
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
|
|
Loading…
Reference in a new issue