From 33f754b80c2b777447618c925963f88c34fe0068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20T=C3=B6tterman?= Date: Mon, 13 Nov 2017 19:47:25 +0200 Subject: [PATCH] Respond to review feedback --- .../auth/storage/devices/devices_table.go | 9 +++--- .../clientapi/auth/storage/devices/storage.go | 7 +++-- .../dendrite/clientapi/routing/device.go | 29 +++++++++++++++---- .../dendrite/clientapi/routing/login.go | 6 ++-- .../dendrite/clientapi/routing/register.go | 11 ++++--- .../dendrite/cmd/create-account/main.go | 2 +- 6 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go index 80d12abe7..903471afe 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/devices_table.go @@ -63,7 +63,7 @@ const selectDevicesByLocalpartSQL = "" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" const updateDeviceNameSQL = "" + - "UPDATE device_devices SET display_name = $1 WHERE device_id = $2" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" @@ -116,7 +116,8 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) insertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken, displayName string, + ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, + displayName *string, ) (*authtypes.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := common.TxStmt(txn, s.insertDeviceStmt) @@ -147,10 +148,10 @@ func (s *devicesStatements) deleteDevicesByLocalpart( } func (s *devicesStatements) updateDeviceName( - ctx context.Context, txn *sql.Tx, deviceID, displayName string, + ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := common.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) return err } diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go index be3d34458..6ac475a66 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go @@ -74,7 +74,8 @@ func (d *Database) GetDevicesByLocalpart( // If no device ID is given one is generated. // Returns the device on success. func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken, displayName string, + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ) (dev *authtypes.Device, returnErr error) { if deviceID != nil { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { @@ -113,10 +114,10 @@ func (d *Database) CreateDevice( // UpdateDevice updates the given device with the display name. // Returns SQL error if there are problems and nil on success. func (d *Database) UpdateDevice( - ctx context.Context, deviceID, displayName string, + ctx context.Context, localpart, deviceID string, displayName *string, ) error { return common.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, deviceID, displayName) + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) }) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/device.go b/src/github.com/matrix-org/dendrite/clientapi/routing/device.go index 6531da61e..86e393be1 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/device.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/device.go @@ -37,7 +37,7 @@ type devicesJSON struct { } type deviceUpdateJSON struct { - DisplayName string `json:"display_name"` + DisplayName *string `json:"display_name"` } // GetDeviceByID handles /device/{deviceID} @@ -113,7 +113,28 @@ func UpdateDeviceByID( } } - // TODO: who should be able to update device displayName? + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return httputil.LogThenError(req, err) + } + + ctx := req.Context() + dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) + if err == sql.ErrNoRows { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound("Unknown device"), + } + } else if err != nil { + return httputil.LogThenError(req, err) + } + + if dev.UserID != device.UserID { + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("device not owned by current user"), + } + } defer req.Body.Close() // nolint: errcheck @@ -123,9 +144,7 @@ func UpdateDeviceByID( return httputil.LogThenError(req, err) } - ctx := req.Context() - - if err := deviceDB.UpdateDevice(ctx, deviceID, payload.DisplayName); err != nil { + if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil { return httputil.LogThenError(req, err) } diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/login.go b/src/github.com/matrix-org/dendrite/clientapi/routing/login.go index 3b50ed122..56c67b77d 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/login.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/login.go @@ -38,9 +38,9 @@ type flow struct { } type passwordRequest struct { - User string `json:"user"` - Password string `json:"password"` - InitialDisplayName string `json:"initial_device_display_name"` + User string `json:"user"` + Password string `json:"password"` + InitialDisplayName *string `json:"initial_device_display_name"` } type loginResponse struct { diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go index efb9bc6f3..875ceb049 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/register.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/register.go @@ -61,7 +61,7 @@ type registerRequest struct { // user-interactive auth params Auth authDict `json:"auth"` - InitialDisplayName string `json:"initial_device_display_name"` + InitialDisplayName *string `json:"initial_device_display_name"` } type authDict struct { @@ -272,12 +272,10 @@ func LegacyRegister( return util.MessageResponse(403, "HMAC incorrect") } - // TODO: does the legacy registration request support initial - // display name? - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "legacy registration") + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil) case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "legacy registration") + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil) default: return util.JSONResponse{ Code: 501, @@ -290,7 +288,8 @@ func completeRegistration( ctx context.Context, accountDB *accounts.Database, deviceDB *devices.Database, - username, password, displayName string, + username, password string, + displayName *string, ) util.JSONResponse { if username == "" { return util.JSONResponse{ diff --git a/src/github.com/matrix-org/dendrite/cmd/create-account/main.go b/src/github.com/matrix-org/dendrite/cmd/create-account/main.go index 1e163a5b0..7914a6266 100644 --- a/src/github.com/matrix-org/dendrite/cmd/create-account/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/create-account/main.go @@ -87,7 +87,7 @@ func main() { } device, err := deviceDB.CreateDevice( - context.Background(), *username, nil, *accessToken, "by-create-account", + context.Background(), *username, nil, *accessToken, nil, ) if err != nil { fmt.Println(err.Error())