Add user_agent to database

This commit is contained in:
Till Faelligen 2020-10-02 17:16:14 +02:00
parent d5fef50690
commit 57c12aacc0
9 changed files with 26 additions and 18 deletions

View file

@ -193,6 +193,8 @@ type PerformDeviceCreationRequest struct {
DeviceDisplayName *string
// IP address of this device
IPAddr string
// Useragent for this device
UserAgent string
}
// PerformDeviceCreationResponse is the response for PerformDeviceCreation

View file

@ -113,7 +113,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr)
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}

View file

@ -31,7 +31,7 @@ type Database interface {
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr string) (dev *api.Device, returnErr error)
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error

View file

@ -2,10 +2,12 @@
-- +goose StatementBegin
ALTER TABLE device_devices ADD COLUMN last_seen_ts BIGINT NOT NULL;
ALTER TABLE device_devices ADD COLUMN ip TEXT;
ALTER TABLE device_devices ADD COLUMN user_agent TEXT;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE device_devices DROP COLUMN last_seen_ts;
ALTER TABLE device_devices DROP COLUMN ip;
ALTER TABLE device_devices DROP COLUMN user_agent;
-- +goose StatementEnd

View file

@ -55,7 +55,9 @@ CREATE TABLE IF NOT EXISTS device_devices (
-- The time the device was last used, as a unix timestamp (ms resolution).
last_seen_ts BIGINT NOT NULL,
-- The last seen IP address of this device
ip TEXT
ip TEXT,
-- User agent of this device
user_agent TEXT
-- TODO: device keys, device display names, token restrictions (if 3rd-party OAuth app)
);
@ -65,7 +67,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip) VALUES ($1, $2, $3, $4, $5, $6, $7)" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
@ -153,12 +155,12 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
// Returns the device on success.
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr string,
displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr).Scan(&sessionID); err != nil {
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, err
}
return &api.Device{

View file

@ -83,7 +83,7 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
@ -93,7 +93,7 @@ func (d *Database) CreateDevice(
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr)
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
@ -108,7 +108,7 @@ func (d *Database) CreateDevice(
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr)
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {

View file

@ -10,13 +10,14 @@ CREATE TABLE device_devices (
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, ''
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;
-- +goose StatementEnd

View file

@ -42,14 +42,15 @@ CREATE TABLE IF NOT EXISTS device_devices (
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
UNIQUE (localpart, device_id)
);
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices"
@ -143,7 +144,7 @@ func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go
// Returns the device on success.
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr string,
displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
@ -153,7 +154,7 @@ func (s *devicesStatements) insertDevice(
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr); err != nil {
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{

View file

@ -87,7 +87,7 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
@ -97,7 +97,7 @@ func (d *Database) CreateDevice(
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr)
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
@ -112,7 +112,7 @@ func (d *Database) CreateDevice(
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr)
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {