mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-12 09:23:09 -06:00
Respond to review feedback
This commit is contained in:
parent
cb011212df
commit
33f754b80c
|
|
@ -63,7 +63,7 @@ const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
|
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE device_devices SET display_name = $1 WHERE device_id = $2"
|
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||||
|
|
||||||
const deleteDeviceSQL = "" +
|
const deleteDeviceSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||||
|
|
@ -116,7 +116,8 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
||||||
// Returns an error if the user already has a device with the given device ID.
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) insertDevice(
|
func (s *devicesStatements) insertDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken, displayName string,
|
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||||
|
displayName *string,
|
||||||
) (*authtypes.Device, error) {
|
) (*authtypes.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||||
|
|
@ -147,10 +148,10 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) updateDeviceName(
|
func (s *devicesStatements) updateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, deviceID, displayName string,
|
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
|
stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
_, err := stmt.ExecContext(ctx, displayName, deviceID)
|
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,8 @@ func (d *Database) GetDevicesByLocalpart(
|
||||||
// If no device ID is given one is generated.
|
// If no device ID is given one is generated.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (d *Database) CreateDevice(
|
func (d *Database) CreateDevice(
|
||||||
ctx context.Context, localpart string, deviceID *string, accessToken, displayName string,
|
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||||
|
displayName *string,
|
||||||
) (dev *authtypes.Device, returnErr error) {
|
) (dev *authtypes.Device, returnErr error) {
|
||||||
if deviceID != nil {
|
if deviceID != nil {
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
|
@ -113,10 +114,10 @@ func (d *Database) CreateDevice(
|
||||||
// UpdateDevice updates the given device with the display name.
|
// UpdateDevice updates the given device with the display name.
|
||||||
// Returns SQL error if there are problems and nil on success.
|
// Returns SQL error if there are problems and nil on success.
|
||||||
func (d *Database) UpdateDevice(
|
func (d *Database) UpdateDevice(
|
||||||
ctx context.Context, deviceID, displayName string,
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
return d.devices.updateDeviceName(ctx, txn, deviceID, displayName)
|
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ type devicesJSON struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type deviceUpdateJSON struct {
|
type deviceUpdateJSON struct {
|
||||||
DisplayName string `json:"display_name"`
|
DisplayName *string `json:"display_name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceByID handles /device/{deviceID}
|
// GetDeviceByID handles /device/{deviceID}
|
||||||
|
|
@ -113,7 +113,28 @@ func UpdateDeviceByID(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: who should be able to update device displayName?
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := req.Context()
|
||||||
|
dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 404,
|
||||||
|
JSON: jsonerror.NotFound("Unknown device"),
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dev.UserID != device.UserID {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 403,
|
||||||
|
JSON: jsonerror.Forbidden("device not owned by current user"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
defer req.Body.Close() // nolint: errcheck
|
defer req.Body.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
|
@ -123,9 +144,7 @@ func UpdateDeviceByID(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := req.Context()
|
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil {
|
||||||
|
|
||||||
if err := deviceDB.UpdateDevice(ctx, deviceID, payload.DisplayName); err != nil {
|
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,9 +38,9 @@ type flow struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type passwordRequest struct {
|
type passwordRequest struct {
|
||||||
User string `json:"user"`
|
User string `json:"user"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
InitialDisplayName string `json:"initial_device_display_name"`
|
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type loginResponse struct {
|
type loginResponse struct {
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ type registerRequest struct {
|
||||||
// user-interactive auth params
|
// user-interactive auth params
|
||||||
Auth authDict `json:"auth"`
|
Auth authDict `json:"auth"`
|
||||||
|
|
||||||
InitialDisplayName string `json:"initial_device_display_name"`
|
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type authDict struct {
|
type authDict struct {
|
||||||
|
|
@ -272,12 +272,10 @@ func LegacyRegister(
|
||||||
return util.MessageResponse(403, "HMAC incorrect")
|
return util.MessageResponse(403, "HMAC incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: does the legacy registration request support initial
|
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil)
|
||||||
// display name?
|
|
||||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "legacy registration")
|
|
||||||
case authtypes.LoginTypeDummy:
|
case authtypes.LoginTypeDummy:
|
||||||
// there is nothing to do
|
// there is nothing to do
|
||||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "legacy registration")
|
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil)
|
||||||
default:
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 501,
|
Code: 501,
|
||||||
|
|
@ -290,7 +288,8 @@ func completeRegistration(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
accountDB *accounts.Database,
|
accountDB *accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB *devices.Database,
|
||||||
username, password, displayName string,
|
username, password string,
|
||||||
|
displayName *string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
|
|
||||||
|
|
@ -87,7 +87,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
device, err := deviceDB.CreateDevice(
|
device, err := deviceDB.CreateDevice(
|
||||||
context.Background(), *username, nil, *accessToken, "by-create-account",
|
context.Background(), *username, nil, *accessToken, nil,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err.Error())
|
fmt.Println(err.Error())
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue