diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 772775aa0..6b4754530 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -79,7 +79,7 @@ func Login( return *authErr } // make a device/access token - return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login) + return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr) } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, @@ -88,7 +88,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login, + ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login, ipAddr string, ) util.JSONResponse { token, err := auth.GenerateAccessToken() if err != nil { @@ -108,6 +108,7 @@ func completeAuth( DeviceID: login.DeviceID, AccessToken: token, Localpart: localpart, + IPAddr: ipAddr, }, &performRes) if err != nil { return util.JSONResponse{ diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 937abc83d..dff970923 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -543,6 +543,7 @@ func handleGuestRegistration( Localpart: res.Account.Localpart, DeviceDisplayName: r.InitialDisplayName, AccessToken: token, + IPAddr: req.RemoteAddr, }, &devRes) if err != nil { return util.JSONResponse{ @@ -691,7 +692,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, "", appserviceID, + req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -710,7 +711,7 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.Password, "", + req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -762,10 +763,10 @@ func LegacyRegister( return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") } - return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) + return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, false, nil, nil) case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) + return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, false, nil, nil) default: return util.JSONResponse{ Code: http.StatusNotImplemented, @@ -812,7 +813,7 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u func completeRegistration( ctx context.Context, userAPI userapi.UserInternalAPI, - username, password, appserviceID string, + username, password, appserviceID, ipAddr string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, ) util.JSONResponse { @@ -880,6 +881,7 @@ func completeRegistration( AccessToken: token, DeviceDisplayName: displayName, DeviceID: deviceID, + IPAddr: ipAddr, }, &devRes) if err != nil { return util.JSONResponse{ diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 73e223d61..d3e7f3fb3 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -92,7 +92,7 @@ func main() { } device, err := deviceDB.CreateDevice( - context.Background(), *username, nil, *accessToken, nil, + context.Background(), *username, nil, *accessToken, nil, "127.0.0.1", ) if err != nil { fmt.Println(err.Error()) diff --git a/userapi/api/api.go b/userapi/api/api.go index 3baaa1002..365081cef 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -191,6 +191,8 @@ type PerformDeviceCreationRequest struct { DeviceID *string // optional: if nil no display name will be associated with this device. DeviceDisplayName *string + // IP address of this device + IPAddr string } // PerformDeviceCreationResponse is the response for PerformDeviceCreation @@ -211,6 +213,8 @@ type Device struct { // associated with access tokens. SessionID int64 DisplayName string + LastSeen int64 + IPAddr string } // Account represents a Matrix account on this home server. diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 461c548cc..3dd294c40 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -113,7 +113,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") - dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName) + dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr) if err != nil { return err } diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 168c84c5c..12607e2e8 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -31,7 +31,7 @@ type Database interface { // an error will be returned. // If no device ID is given one is generated. // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *api.Device, returnErr error) + CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr string) (dev *api.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index c06af7549..ce17b27cc 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -51,8 +51,13 @@ CREATE TABLE IF NOT EXISTS device_devices ( -- When this devices was first recognised on the network, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The display name, human friendlier than device_id and updatable - display_name TEXT - -- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app) + display_name TEXT, + -- The time the device was last used, as a unix timestamp (ms resolution). + last_used_ts BIGINT NOT NULL, + -- The last seen IP address of this device + ip TEXT + + -- TODO: device keys, device display names, token restrictions (if 3rd-party OAuth app) ); -- Device IDs must be unique for a given user. @@ -60,7 +65,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca ` const insertDeviceSQL = "" + - "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name, last_used_ts, ip) VALUES ($1, $2, $3, $4, $5, $6, $7)" + " RETURNING session_id" const selectDeviceByTokenSQL = "" + @@ -141,12 +146,12 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN // Returns the device on success. func (s *devicesStatements) insertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, + displayName *string, ipAddr string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil { + if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr).Scan(&sessionID); err != nil { return nil, err } return &api.Device{ @@ -154,6 +159,8 @@ func (s *devicesStatements) insertDevice( UserID: userutil.MakeUserID(localpart, s.serverName), AccessToken: accessToken, SessionID: sessionID, + LastSeen: createdTimeMS, + IPAddr: ipAddr, }, nil } diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index c5bd5b6cf..9d2beba92 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -83,7 +83,7 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap // Returns the device on success. func (d *Database) CreateDevice( ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, + displayName *string, ipAddr string, ) (dev *api.Device, returnErr error) { if deviceID != nil { returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { @@ -93,7 +93,7 @@ func (d *Database) CreateDevice( return err } - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr) return err }) } else { @@ -108,7 +108,7 @@ func (d *Database) CreateDevice( returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr) return err }) if returnErr == nil { diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index c75e19825..215dec0ea 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -40,14 +40,16 @@ CREATE TABLE IF NOT EXISTS device_devices ( localpart TEXT , created_ts BIGINT, display_name TEXT, + last_used_ts BIGINT, + ip TEXT, UNIQUE (localpart, device_id) ); ` const insertDeviceSQL = "" + - "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_used_ts, ip)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const selectDevicesCountSQL = "" + "SELECT COUNT(access_token) FROM device_devices" @@ -134,7 +136,7 @@ func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go // Returns the device on success. func (s *devicesStatements) insertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, + displayName *string, ipAddr string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 @@ -144,7 +146,7 @@ func (s *devicesStatements) insertDevice( return nil, err } sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr); err != nil { return nil, err } return &api.Device{ @@ -152,6 +154,8 @@ func (s *devicesStatements) insertDevice( UserID: userutil.MakeUserID(localpart, s.serverName), AccessToken: accessToken, SessionID: sessionID, + LastSeen: createdTimeMS, + IPAddr: ipAddr, }, nil } diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 7c6645dd6..d18d9b3af 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -87,7 +87,7 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap // Returns the device on success. func (d *Database) CreateDevice( ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, + displayName *string, ipAddr string, ) (dev *api.Device, returnErr error) { if deviceID != nil { returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { @@ -97,7 +97,7 @@ func (d *Database) CreateDevice( return err } - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr) return err }) } else { @@ -112,7 +112,7 @@ func (d *Database) CreateDevice( returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr) return err }) if returnErr == nil {