From 3c23beb8040fc201835ba6267482dfbd16387fd6 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Sat, 21 Dec 2019 22:46:49 +0100 Subject: [PATCH] Implement missing device management features Signed-off-by: Till Faelligen --- .../auth/storage/devices/devices_table.go | 24 ++++++++- clientapi/auth/storage/devices/storage.go | 15 ++++++ clientapi/routing/device.go | 51 +++++++++++++++++++ clientapi/routing/login.go | 9 +++- clientapi/routing/routing.go | 16 ++++++ 5 files changed, 111 insertions(+), 4 deletions(-) diff --git a/clientapi/auth/storage/devices/devices_table.go b/clientapi/auth/storage/devices/devices_table.go index d011d25c9..c5773ce39 100644 --- a/clientapi/auth/storage/devices/devices_table.go +++ b/clientapi/auth/storage/devices/devices_table.go @@ -19,10 +19,10 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/common" - + "github.com/lib/pq" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/gomatrixserverlib" ) @@ -80,6 +80,9 @@ const deleteDeviceSQL = "" + const deleteDevicesByLocalpartSQL = "" + "DELETE FROM device_devices WHERE localpart = $1" +const deleteDevicesSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" + type devicesStatements struct { insertDeviceStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -88,6 +91,7 @@ type devicesStatements struct { updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt + deleteDevicesStmt *sql.Stmt serverName gomatrixserverlib.ServerName } @@ -117,6 +121,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { return } + if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { + return + } s.serverName = server return } @@ -142,6 +149,7 @@ func (s *devicesStatements) insertDevice( }, nil } +// deleteDevice removes a single device by id and user localpart. func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { @@ -150,6 +158,18 @@ func (s *devicesStatements) deleteDevice( return err } +// deleteDevices removes a single or multiple devices by ids and user localpart. +// Returns an error if the execution failed. +func (s *devicesStatements) deleteDevices( + ctx context.Context, txn *sql.Tx, localpart string, devices []string, +) error { + stmt := common.TxStmt(txn, s.deleteDevicesStmt) + _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) + return err +} + +// deleteDevicesByLocalpart removes all devices for the +// given user localpart. func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { diff --git a/clientapi/auth/storage/devices/storage.go b/clientapi/auth/storage/devices/storage.go index 82c8e97a2..150180c1e 100644 --- a/clientapi/auth/storage/devices/storage.go +++ b/clientapi/auth/storage/devices/storage.go @@ -152,6 +152,21 @@ func (d *Database) RemoveDevice( }) } +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + // RemoveAllDevices revokes devices by deleting the entry in the // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index c858e88aa..23f802529 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -146,3 +146,54 @@ func UpdateDeviceByID( JSON: struct{}{}, } } + +func DeleteDeviceById( + req *http.Request, deviceDB *devices.Database, device *authtypes.Device, + deviceID string, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return httputil.LogThenError(req, err) + } + ctx := req.Context() + + defer req.Body.Close() // nolint: errcheck + + if err := deviceDB.RemoveDevice(ctx, deviceID, localpart); err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +func DeleteDevices( + req *http.Request, deviceDB *devices.Database, device *authtypes.Device, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return httputil.LogThenError(req, err) + } + + ctx := req.Context() + d := struct { + Devices []string `json:"devices"` + }{} + + if err := json.NewDecoder(req.Body).Decode(&d); err != nil { + return httputil.LogThenError(req, err) + } + + defer req.Body.Close() // nolint: errcheck + + if err := deviceDB.RemoveDevices(ctx, localpart, d.Devices); err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 02d958152..25b003032 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -47,6 +47,9 @@ type passwordRequest struct { // Thus a pointer is needed to differentiate between the two InitialDisplayName *string `json:"initial_device_display_name"` DeviceID *string `json:"device_id"` + Identifier struct { + User string `json:"user"` + } `json:"identifier"` } type loginResponse struct { @@ -79,13 +82,15 @@ func Login( if resErr != nil { return *resErr } - if r.User == "" { + if r.User == "" && r.Identifier.User == "" { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("'user' must be supplied."), } } - + if r.User == "" { + r.User = r.Identifier.User + } util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request") localpart, err := userutil.ParseUsernameParam(r.User, &cfg.Matrix.ServerName) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 4a36661db..b42b2eecf 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -483,6 +483,22 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/devices/{deviceID}", + common.MakeAuthAPI("delete_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := common.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return DeleteDeviceById(req, deviceDB, device, vars["deviceID"]) + }), + ).Methods(http.MethodDelete, http.MethodOptions) + + r0mux.Handle("/delete_devices", + common.MakeAuthAPI("delete_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + return DeleteDevices(req, deviceDB, device) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Stub implementations for sytest r0mux.Handle("/events", common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {