Extend device_devices table (#1471)

* Add last_used_ts and IP to database

* Add migrations

* Rename column
Prepare statements

* Add interface method and implement it

Signed-off-by: Till Faelligen <tfaelligen@gmail.com>

* Rename struct fields

* Add user_agent to database

* Add userAgent to registration calls

* Add missing "IF NOT EXISTS"

* use txn writer

* Add UserAgent to Device

Co-authored-by: Kegsay <kegan@matrix.org>
This commit is contained in:
S7evinK 2020-10-09 10:17:23 +02:00 committed by GitHub
parent c4c8bfd027
commit 1cd525ef0d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 154 additions and 24 deletions

View file

@ -79,7 +79,7 @@ func Login(
return *authErr return *authErr
} }
// make a device/access token // make a device/access token
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login) return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent())
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusMethodNotAllowed, Code: http.StatusMethodNotAllowed,
@ -89,6 +89,7 @@ func Login(
func completeAuth( func completeAuth(
ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login, ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login,
ipAddr, userAgent string,
) util.JSONResponse { ) util.JSONResponse {
token, err := auth.GenerateAccessToken() token, err := auth.GenerateAccessToken()
if err != nil { if err != nil {
@ -108,6 +109,8 @@ func completeAuth(
DeviceID: login.DeviceID, DeviceID: login.DeviceID,
AccessToken: token, AccessToken: token,
Localpart: localpart, Localpart: localpart,
IPAddr: ipAddr,
UserAgent: userAgent,
}, &performRes) }, &performRes)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -543,6 +543,8 @@ func handleGuestRegistration(
Localpart: res.Account.Localpart, Localpart: res.Account.Localpart,
DeviceDisplayName: r.InitialDisplayName, DeviceDisplayName: r.InitialDisplayName,
AccessToken: token, AccessToken: token,
IPAddr: req.RemoteAddr,
UserAgent: req.UserAgent(),
}, &devRes) }, &devRes)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -691,7 +693,7 @@ func handleApplicationServiceRegistration(
// Don't need to worry about appending to registration stages as // Don't need to worry about appending to registration stages as
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(),
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -710,7 +712,7 @@ func checkAndCompleteFlow(
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(),
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -762,10 +764,10 @@ func LegacyRegister(
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
} }
return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), false, nil, nil)
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), false, nil, nil)
default: default:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotImplemented, Code: http.StatusNotImplemented,
@ -812,7 +814,7 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
username, password, appserviceID string, username, password, appserviceID, ipAddr, userAgent string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
) util.JSONResponse { ) util.JSONResponse {
@ -880,6 +882,8 @@ func completeRegistration(
AccessToken: token, AccessToken: token,
DeviceDisplayName: displayName, DeviceDisplayName: displayName,
DeviceID: deviceID, DeviceID: deviceID,
IPAddr: ipAddr,
UserAgent: userAgent,
}, &devRes) }, &devRes)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -92,7 +92,7 @@ func main() {
} }
device, err := deviceDB.CreateDevice( device, err := deviceDB.CreateDevice(
context.Background(), *username, nil, *accessToken, nil, context.Background(), *username, nil, *accessToken, nil, "127.0.0.1", "",
) )
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())

View file

@ -192,6 +192,10 @@ type PerformDeviceCreationRequest struct {
DeviceID *string DeviceID *string
// optional: if nil no display name will be associated with this device. // optional: if nil no display name will be associated with this device.
DeviceDisplayName *string DeviceDisplayName *string
// IP address of this device
IPAddr string
// Useragent for this device
UserAgent string
} }
// PerformDeviceCreationResponse is the response for PerformDeviceCreation // PerformDeviceCreationResponse is the response for PerformDeviceCreation
@ -222,6 +226,9 @@ type Device struct {
// associated with access tokens. // associated with access tokens.
SessionID int64 SessionID int64
DisplayName string DisplayName string
LastSeenTS int64
LastSeenIP string
UserAgent string
} }
// Account represents a Matrix account on this home server. // Account represents a Matrix account on this home server.

View file

@ -113,7 +113,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID, "device_id": req.DeviceID,
"display_name": req.DeviceDisplayName, "display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation") }).Info("PerformDeviceCreation")
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName) dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil { if err != nil {
return err return err
} }

View file

@ -31,10 +31,11 @@ type Database interface {
// an error will be returned. // an error will be returned.
// If no device ID is given one is generated. // If no device ID is given one is generated.
// Returns the device on success. // Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *api.Device, returnErr error) CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted. // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
UpdateDeviceLastSeen(ctx context.Context, deviceID, ipAddr string) error
} }

View file

@ -0,0 +1,13 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE device_devices DROP COLUMN last_seen_ts;
ALTER TABLE device_devices DROP COLUMN ip;
ALTER TABLE device_devices DROP COLUMN user_agent;
-- +goose StatementEnd

View file

@ -51,8 +51,15 @@ CREATE TABLE IF NOT EXISTS device_devices (
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution). -- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
-- The display name, human friendlier than device_id and updatable -- The display name, human friendlier than device_id and updatable
display_name TEXT display_name TEXT,
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app) -- The time the device was last used, as a unix timestamp (ms resolution).
last_seen_ts BIGINT NOT NULL,
-- The last seen IP address of this device
ip TEXT,
-- User agent of this device
user_agent TEXT
-- TODO: device keys, device display names, token restrictions (if 3rd-party OAuth app)
); );
-- Device IDs must be unique for a given user. -- Device IDs must be unique for a given user.
@ -60,7 +67,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
` `
const insertDeviceSQL = "" + const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
" RETURNING session_id" " RETURNING session_id"
const selectDeviceByTokenSQL = "" + const selectDeviceByTokenSQL = "" +
@ -87,6 +94,9 @@ const deleteDevicesSQL = "" +
const selectDevicesByIDSQL = "" + const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE device_id = $3"
type devicesStatements struct { type devicesStatements struct {
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
@ -94,6 +104,7 @@ type devicesStatements struct {
selectDevicesByLocalpartStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt selectDevicesByIDStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt
updateDeviceLastSeenStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt
deleteDevicesStmt *sql.Stmt deleteDevicesStmt *sql.Stmt
@ -132,6 +143,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return return
} }
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -141,12 +155,12 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice( func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64 var sessionID int64
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil { if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, err return nil, err
} }
return &api.Device{ return &api.Device{
@ -154,6 +168,9 @@ func (s *devicesStatements) insertDevice(
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken, AccessToken: accessToken,
SessionID: sessionID, SessionID: sessionID,
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
}, nil }, nil
} }
@ -280,3 +297,10 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, deviceID)
return err
}

View file

@ -83,7 +83,7 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
// Returns the device on success. // Returns the device on success.
func (d *Database) CreateDevice( func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string, ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) { ) (dev *api.Device, returnErr error) {
if deviceID != nil { if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
@ -93,7 +93,7 @@ func (d *Database) CreateDevice(
return err return err
} }
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err return err
}) })
} else { } else {
@ -108,7 +108,7 @@ func (d *Database) CreateDevice(
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err return err
}) })
if returnErr == nil { if returnErr == nil {
@ -189,3 +189,10 @@ func (d *Database) RemoveAllDevices(
}) })
return return
} }
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, deviceID, ipAddr string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, deviceID, ipAddr)
})
}

View file

@ -0,0 +1,44 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
access_token, session_id, device_id, localpart, created_ts, display_name, last_seen_ts, ip, user_agent
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name, created_ts, '', ''
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
UNIQUE (localpart, device_id)
);
INSERT
INTO device_devices (
access_token, session_id, device_id, localpart, created_ts, display_name
) SELECT
access_token, session_id, device_id, localpart, created_ts, display_name
FROM device_devices_tmp;
DROP TABLE device_devices_tmp;
-- +goose StatementEnd

View file

@ -40,14 +40,17 @@ CREATE TABLE IF NOT EXISTS device_devices (
localpart TEXT , localpart TEXT ,
created_ts BIGINT, created_ts BIGINT,
display_name TEXT, display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
UNIQUE (localpart, device_id) UNIQUE (localpart, device_id)
); );
` `
const insertDeviceSQL = "" + const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" + "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6)" " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
const selectDevicesCountSQL = "" + const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices" "SELECT COUNT(access_token) FROM device_devices"
@ -76,6 +79,9 @@ const deleteDevicesSQL = "" +
const selectDevicesByIDSQL = "" + const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE device_id = $3"
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer writer sqlutil.Writer
@ -86,6 +92,7 @@ type devicesStatements struct {
selectDevicesByIDStmt *sql.Stmt selectDevicesByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt
updateDeviceLastSeenStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
@ -125,6 +132,9 @@ func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return return
} }
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -134,7 +144,7 @@ func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice( func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64 var sessionID int64
@ -144,7 +154,7 @@ func (s *devicesStatements) insertDevice(
return nil, err return nil, err
} }
sessionID++ sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err return nil, err
} }
return &api.Device{ return &api.Device{
@ -152,6 +162,9 @@ func (s *devicesStatements) insertDevice(
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken, AccessToken: accessToken,
SessionID: sessionID, SessionID: sessionID,
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
}, nil }, nil
} }
@ -288,3 +301,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
} }
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, deviceID)
return err
}

View file

@ -87,7 +87,7 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
// Returns the device on success. // Returns the device on success.
func (d *Database) CreateDevice( func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string, ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) { ) (dev *api.Device, returnErr error) {
if deviceID != nil { if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
@ -97,7 +97,7 @@ func (d *Database) CreateDevice(
return err return err
} }
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err return err
}) })
} else { } else {
@ -112,7 +112,7 @@ func (d *Database) CreateDevice(
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err return err
}) })
if returnErr == nil { if returnErr == nil {
@ -193,3 +193,10 @@ func (d *Database) RemoveAllDevices(
}) })
return return
} }
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, deviceID, ipAddr)
})
}