Respond to review feedback

This commit is contained in:
Paul Tötterman 2017-11-13 19:47:25 +02:00
parent cb011212df
commit 33f754b80c
6 changed files with 42 additions and 22 deletions

View file

@ -63,7 +63,7 @@ const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE device_id = $2"
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
@ -116,7 +116,8 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken, displayName string,
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string,
) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := common.TxStmt(txn, s.insertDeviceStmt)
@ -147,10 +148,10 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, deviceID, displayName string,
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, deviceID)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
}

View file

@ -74,7 +74,8 @@ func (d *Database) GetDevicesByLocalpart(
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken, displayName string,
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
if deviceID != nil {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
@ -113,10 +114,10 @@ func (d *Database) CreateDevice(
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, deviceID, displayName string,
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, deviceID, displayName)
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}

View file

@ -37,7 +37,7 @@ type devicesJSON struct {
}
type deviceUpdateJSON struct {
DisplayName string `json:"display_name"`
DisplayName *string `json:"display_name"`
}
// GetDeviceByID handles /device/{deviceID}
@ -113,7 +113,28 @@ func UpdateDeviceByID(
}
}
// TODO: who should be able to update device displayName?
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return httputil.LogThenError(req, err)
}
ctx := req.Context()
dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID)
if err == sql.ErrNoRows {
return util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("Unknown device"),
}
} else if err != nil {
return httputil.LogThenError(req, err)
}
if dev.UserID != device.UserID {
return util.JSONResponse{
Code: 403,
JSON: jsonerror.Forbidden("device not owned by current user"),
}
}
defer req.Body.Close() // nolint: errcheck
@ -123,9 +144,7 @@ func UpdateDeviceByID(
return httputil.LogThenError(req, err)
}
ctx := req.Context()
if err := deviceDB.UpdateDevice(ctx, deviceID, payload.DisplayName); err != nil {
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil {
return httputil.LogThenError(req, err)
}

View file

@ -40,7 +40,7 @@ type flow struct {
type passwordRequest struct {
User string `json:"user"`
Password string `json:"password"`
InitialDisplayName string `json:"initial_device_display_name"`
InitialDisplayName *string `json:"initial_device_display_name"`
}
type loginResponse struct {

View file

@ -61,7 +61,7 @@ type registerRequest struct {
// user-interactive auth params
Auth authDict `json:"auth"`
InitialDisplayName string `json:"initial_device_display_name"`
InitialDisplayName *string `json:"initial_device_display_name"`
}
type authDict struct {
@ -272,12 +272,10 @@ func LegacyRegister(
return util.MessageResponse(403, "HMAC incorrect")
}
// TODO: does the legacy registration request support initial
// display name?
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "legacy registration")
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil)
case authtypes.LoginTypeDummy:
// there is nothing to do
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "legacy registration")
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil)
default:
return util.JSONResponse{
Code: 501,
@ -290,7 +288,8 @@ func completeRegistration(
ctx context.Context,
accountDB *accounts.Database,
deviceDB *devices.Database,
username, password, displayName string,
username, password string,
displayName *string,
) util.JSONResponse {
if username == "" {
return util.JSONResponse{

View file

@ -87,7 +87,7 @@ func main() {
}
device, err := deviceDB.CreateDevice(
context.Background(), *username, nil, *accessToken, "by-create-account",
context.Background(), *username, nil, *accessToken, nil,
)
if err != nil {
fmt.Println(err.Error())