More stuff
This commit is contained in:
parent
a0cc4c806c
commit
da7f2a7047
|
@ -32,6 +32,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AddInternalRoutes registers HTTP handlers for internal API calls
|
// AddInternalRoutes registers HTTP handlers for internal API calls
|
||||||
|
@ -74,7 +75,7 @@ func NewInternalAPI(
|
||||||
// events to be sent out.
|
// events to be sent out.
|
||||||
for _, appservice := range base.Cfg.Derived.ApplicationServices {
|
for _, appservice := range base.Cfg.Derived.ApplicationServices {
|
||||||
// Create bot account for this AS if it doesn't already exist
|
// Create bot account for this AS if it doesn't already exist
|
||||||
if err := generateAppServiceAccount(userAPI, appservice); err != nil {
|
if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"appservice": appservice.ID,
|
"appservice": appservice.ID,
|
||||||
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
||||||
|
@ -101,11 +102,13 @@ func NewInternalAPI(
|
||||||
func generateAppServiceAccount(
|
func generateAppServiceAccount(
|
||||||
userAPI userapi.AppserviceUserAPI,
|
userAPI userapi.AppserviceUserAPI,
|
||||||
as config.ApplicationService,
|
as config.ApplicationService,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
var accRes userapi.PerformAccountCreationResponse
|
var accRes userapi.PerformAccountCreationResponse
|
||||||
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
|
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
|
||||||
AccountType: userapi.AccountTypeAppService,
|
AccountType: userapi.AccountTypeAppService,
|
||||||
Localpart: as.SenderLocalpart,
|
Localpart: as.SenderLocalpart,
|
||||||
|
ServerName: serverName,
|
||||||
AppServiceID: as.ID,
|
AppServiceID: as.ID,
|
||||||
OnConflict: userapi.ConflictUpdate,
|
OnConflict: userapi.ConflictUpdate,
|
||||||
}, &accRes)
|
}, &accRes)
|
||||||
|
@ -115,6 +118,7 @@ func generateAppServiceAccount(
|
||||||
var devRes userapi.PerformDeviceCreationResponse
|
var devRes userapi.PerformDeviceCreationResponse
|
||||||
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
|
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
|
||||||
Localpart: as.SenderLocalpart,
|
Localpart: as.SenderLocalpart,
|
||||||
|
ServerName: serverName,
|
||||||
AccessToken: as.ASToken,
|
AccessToken: as.ASToken,
|
||||||
DeviceID: &as.SenderLocalpart,
|
DeviceID: &as.SenderLocalpart,
|
||||||
DeviceDisplayName: &as.SenderLocalpart,
|
DeviceDisplayName: &as.SenderLocalpart,
|
||||||
|
|
|
@ -175,7 +175,6 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
if serverName == "" {
|
if serverName == "" {
|
||||||
serverName = a.Config.Matrix.ServerName
|
serverName = a.Config.Matrix.ServerName
|
||||||
}
|
}
|
||||||
// XXXX: Use the server name here
|
|
||||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
|
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||||
|
@ -245,13 +244,12 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
||||||
serverName = a.Config.Matrix.ServerName
|
serverName = a.Config.Matrix.ServerName
|
||||||
}
|
}
|
||||||
_ = serverName
|
_ = serverName
|
||||||
// XXXX: Use the server name here
|
|
||||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
"localpart": req.Localpart,
|
"localpart": req.Localpart,
|
||||||
"device_id": req.DeviceID,
|
"device_id": req.DeviceID,
|
||||||
"display_name": req.DeviceDisplayName,
|
"display_name": req.DeviceDisplayName,
|
||||||
}).Info("PerformDeviceCreation")
|
}).Info("PerformDeviceCreation")
|
||||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ type Device interface {
|
||||||
// an error will be returned.
|
// an error will be returned.
|
||||||
// If no device ID is given one is generated.
|
// If no device ID is given one is generated.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
|
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
|
||||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||||
|
|
|
@ -17,6 +17,7 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -67,7 +68,7 @@ const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
|
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $2"
|
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
@ -132,7 +133,7 @@ func (s *accountsStatements) InsertAccount(
|
||||||
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("insertAccountStmt: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Account{
|
return &api.Account{
|
||||||
|
|
|
@ -17,6 +17,7 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
@ -70,7 +71,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_dev
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertDeviceSQL = "" +
|
const insertDeviceSQL = "" +
|
||||||
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
|
"INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
|
||||||
" RETURNING session_id"
|
" RETURNING session_id"
|
||||||
|
|
||||||
const selectDeviceByTokenSQL = "" +
|
const selectDeviceByTokenSQL = "" +
|
||||||
|
@ -86,7 +87,7 @@ const updateDeviceNameSQL = "" +
|
||||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||||
|
|
||||||
const deleteDeviceSQL = "" +
|
const deleteDeviceSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||||
|
|
||||||
const deleteDevicesByLocalpartSQL = "" +
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
||||||
|
@ -149,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
|
||||||
// Returns an error if the user already has a device with the given device ID.
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) InsertDevice(
|
func (s *devicesStatements) InsertDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
displayName *string, ipAddr, userAgent string,
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
var sessionID int64
|
var sessionID int64
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||||
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
|
||||||
}
|
}
|
||||||
return &api.Device{
|
return &api.Device{
|
||||||
ID: id,
|
ID: id,
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, serverName),
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
LastSeenTS: createdTimeMS,
|
LastSeenTS: createdTimeMS,
|
||||||
|
@ -171,10 +173,11 @@ func (s *devicesStatements) InsertDevice(
|
||||||
|
|
||||||
// deleteDevice removes a single device by id and user localpart.
|
// deleteDevice removes a single device by id and user localpart.
|
||||||
func (s *devicesStatements) DeleteDevice(
|
func (s *devicesStatements) DeleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,13 +45,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
Up: deltas.UpRenameTables,
|
Up: deltas.UpRenameTables,
|
||||||
Down: deltas.DownRenameTables,
|
Down: deltas.DownRenameTables,
|
||||||
})
|
})
|
||||||
m.AddMigrations(sqlutil.Migration{
|
|
||||||
Version: "userapi: server names",
|
|
||||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
|
||||||
return deltas.UpServerNames(ctx, txn, serverName)
|
|
||||||
},
|
|
||||||
Down: deltas.DownServerNames,
|
|
||||||
})
|
|
||||||
if err = m.Up(base.Context()); err != nil {
|
if err = m.Up(base.Context()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -104,6 +97,19 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m = sqlutil.NewMigrator(db)
|
||||||
|
m.AddMigrations(sqlutil.Migration{
|
||||||
|
Version: "userapi: server names",
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
return deltas.UpServerNames(ctx, txn, serverName)
|
||||||
|
},
|
||||||
|
Down: deltas.DownServerNames,
|
||||||
|
})
|
||||||
|
if err = m.Up(base.Context()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
AccountDatas: accountDataTable,
|
AccountDatas: accountDataTable,
|
||||||
Accounts: accountsTable,
|
Accounts: accountsTable,
|
||||||
|
|
|
@ -569,8 +569,8 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
|
||||||
// If no device ID is given one is generated.
|
// If no device ID is given one is generated.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (d *Database) CreateDevice(
|
func (d *Database) CreateDevice(
|
||||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
displayName *string, ipAddr, userAgent string,
|
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
|
||||||
) (dev *api.Device, returnErr error) {
|
) (dev *api.Device, returnErr error) {
|
||||||
if deviceID != nil {
|
if deviceID != nil {
|
||||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
@ -580,7 +580,7 @@ func (d *Database) CreateDevice(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
|
@ -595,7 +595,7 @@ func (d *Database) CreateDevice(
|
||||||
|
|
||||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if returnErr == nil {
|
if returnErr == nil {
|
||||||
|
|
|
@ -52,8 +52,8 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertDeviceSQL = "" +
|
const insertDeviceSQL = "" +
|
||||||
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
"INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
|
||||||
|
|
||||||
const selectDevicesCountSQL = "" +
|
const selectDevicesCountSQL = "" +
|
||||||
"SELECT COUNT(access_token) FROM userapi_devices"
|
"SELECT COUNT(access_token) FROM userapi_devices"
|
||||||
|
@ -136,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
// Returns an error if the user already has a device with the given device ID.
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) InsertDevice(
|
func (s *devicesStatements) InsertDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
displayName *string, ipAddr, userAgent string,
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
var sessionID int64
|
var sessionID int64
|
||||||
|
@ -147,7 +148,7 @@ func (s *devicesStatements) InsertDevice(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sessionID++
|
sessionID++
|
||||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &api.Device{
|
return &api.Device{
|
||||||
|
|
|
@ -152,7 +152,7 @@ func Test_Accounts(t *testing.T) {
|
||||||
|
|
||||||
func Test_Devices(t *testing.T) {
|
func Test_Devices(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
deviceID := util.RandomString(8)
|
deviceID := util.RandomString(8)
|
||||||
accessToken := util.RandomString(16)
|
accessToken := util.RandomString(16)
|
||||||
|
@ -161,7 +161,7 @@ func Test_Devices(t *testing.T) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
|
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
||||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||||
|
|
||||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
||||||
|
@ -174,7 +174,7 @@ func Test_Devices(t *testing.T) {
|
||||||
|
|
||||||
// create a device without existing device ID
|
// create a device without existing device ID
|
||||||
accessToken = util.RandomString(16)
|
accessToken = util.RandomString(16)
|
||||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
|
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
|
||||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
|
@ -213,7 +213,7 @@ func Test_Devices(t *testing.T) {
|
||||||
// create one more device and remove the devices step by step
|
// create one more device and remove the devices step by step
|
||||||
newDeviceID := util.RandomString(16)
|
newDeviceID := util.RandomString(16)
|
||||||
accessToken = util.RandomString(16)
|
accessToken = util.RandomString(16)
|
||||||
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
|
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
|
||||||
assert.NoError(t, err, "unable to create new device")
|
assert.NoError(t, err, "unable to create new device")
|
||||||
|
|
||||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||||
|
|
|
@ -43,7 +43,7 @@ type AccountsTable interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DevicesTable interface {
|
type DevicesTable interface {
|
||||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
|
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
|
||||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
|
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
|
||||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
|
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
|
||||||
|
|
|
@ -79,6 +79,7 @@ func mustMakeAccountAndDevice(
|
||||||
accDB tables.AccountsTable,
|
accDB tables.AccountsTable,
|
||||||
devDB tables.DevicesTable,
|
devDB tables.DevicesTable,
|
||||||
localpart string,
|
localpart string,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
accType api.AccountType,
|
accType api.AccountType,
|
||||||
userAgent string,
|
userAgent string,
|
||||||
) {
|
) {
|
||||||
|
@ -93,7 +94,7 @@ func mustMakeAccountAndDevice(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create account: %v", err)
|
t.Fatalf("unable to create account: %v", err)
|
||||||
}
|
}
|
||||||
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent)
|
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create device: %v", err)
|
t.Fatalf("unable to create device: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Want Users", func(t *testing.T) {
|
t.Run("Want Users", func(t *testing.T) {
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko")
|
||||||
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
|
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
|
Loading…
Reference in a new issue