More stuff

This commit is contained in:
Neil Alexander 2022-11-04 15:11:28 +00:00
parent a0cc4c806c
commit da7f2a7047
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
11 changed files with 58 additions and 44 deletions

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
// AddInternalRoutes registers HTTP handlers for internal API calls
@ -74,7 +75,7 @@ func NewInternalAPI(
// events to be sent out.
for _, appservice := range base.Cfg.Derived.ApplicationServices {
// Create bot account for this AS if it doesn't already exist
if err := generateAppServiceAccount(userAPI, appservice); err != nil {
if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
logrus.WithFields(logrus.Fields{
"appservice": appservice.ID,
}).WithError(err).Panicf("failed to generate bot account for appservice")
@ -101,11 +102,13 @@ func NewInternalAPI(
func generateAppServiceAccount(
userAPI userapi.AppserviceUserAPI,
as config.ApplicationService,
serverName gomatrixserverlib.ServerName,
) error {
var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeAppService,
Localpart: as.SenderLocalpart,
ServerName: serverName,
AppServiceID: as.ID,
OnConflict: userapi.ConflictUpdate,
}, &accRes)
@ -115,6 +118,7 @@ func generateAppServiceAccount(
var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
Localpart: as.SenderLocalpart,
ServerName: serverName,
AccessToken: as.ASToken,
DeviceID: &as.SenderLocalpart,
DeviceDisplayName: &as.SenderLocalpart,

View file

@ -175,7 +175,6 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
// XXXX: Use the server name here
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
@ -245,13 +244,12 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
serverName = a.Config.Matrix.ServerName
}
_ = serverName
// XXXX: Use the server name here
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}

View file

@ -70,7 +70,7 @@ type Device interface {
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error

View file

@ -17,6 +17,7 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
@ -67,7 +68,7 @@ const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $2"
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1"
type accountsStatements struct {
insertAccountStmt *sql.Stmt
@ -132,7 +133,7 @@ func (s *accountsStatements) InsertAccount(
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
return nil, fmt.Errorf("insertAccountStmt: %w", err)
}
return &api.Account{

View file

@ -17,6 +17,7 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/lib/pq"
@ -70,7 +71,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_dev
`
const insertDeviceSQL = "" +
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
"INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
@ -86,7 +87,7 @@ const updateDeviceNameSQL = "" +
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
@ -149,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
accessToken string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, err
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
}
return &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
SessionID: sessionID,
LastSeenTS: createdTimeMS,
@ -171,10 +173,11 @@ func (s *devicesStatements) InsertDevice(
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
return err
}

View file

@ -45,13 +45,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
Down: deltas.DownServerNames,
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
@ -104,6 +97,19 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
}
m = sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
Down: deltas.DownServerNames,
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,

View file

@ -569,8 +569,8 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -580,7 +580,7 @@ func (d *Database) CreateDevice(
return err
}
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
@ -595,7 +595,7 @@ func (d *Database) CreateDevice(
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {

View file

@ -52,8 +52,8 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
`
const insertDeviceSQL = "" +
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
"INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM userapi_devices"
@ -136,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
accessToken string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
@ -147,7 +148,7 @@ func (s *devicesStatements) InsertDevice(
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{

View file

@ -152,7 +152,7 @@ func Test_Accounts(t *testing.T) {
func Test_Devices(t *testing.T) {
alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
deviceID := util.RandomString(8)
accessToken := util.RandomString(16)
@ -161,7 +161,7 @@ func Test_Devices(t *testing.T) {
db, close := mustCreateDatabase(t, dbType)
defer close()
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
@ -174,7 +174,7 @@ func Test_Devices(t *testing.T) {
// create a device without existing device ID
accessToken = util.RandomString(16)
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
assert.NoError(t, err, "unable to get device by id")
@ -213,7 +213,7 @@ func Test_Devices(t *testing.T) {
// create one more device and remove the devices step by step
newDeviceID := util.RandomString(16)
accessToken = util.RandomString(16)
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create new device")
devices, err = db.GetDevicesByLocalpart(ctx, localpart)

View file

@ -43,7 +43,7 @@ type AccountsTable interface {
}
type DevicesTable interface {
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error

View file

@ -79,6 +79,7 @@ func mustMakeAccountAndDevice(
accDB tables.AccountsTable,
devDB tables.DevicesTable,
localpart string,
serverName gomatrixserverlib.ServerName,
accType api.AccountType,
userAgent string,
) {
@ -93,7 +94,7 @@ func mustMakeAccountAndDevice(
if err != nil {
t.Fatalf("unable to create account: %v", err)
}
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent)
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent)
if err != nil {
t.Fatalf("unable to create device: %v", err)
}
@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) {
})
t.Run("Want Users", func(t *testing.T) {
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko")
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)