Correctly create new device when device_id is passed to /login (#753)
Fixes https://github.com/matrix-org/dendrite/issues/401 Currently when passing a `device_id` parameter to `/login`, which is [supposed](https://matrix.org/docs/spec/client_server/unstable#post-matrix-client-r0-login) to return a device with that ID set, it instead just generates a random `device_id` and hands that back to you. The code was already there to do this correctly, it looks like it had just been broken during some change. Hopefully sytest will prevent this from becoming broken again.
This commit is contained in:
parent
bdd1a87d4d
commit
78032b3f4c
|
@ -169,6 +169,8 @@ func (s *devicesStatements) selectDeviceByToken(
|
||||||
return &dev, err
|
return &dev, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// selectDeviceByID retrieves a device from the database with the given user
|
||||||
|
// localpart and deviceID
|
||||||
func (s *devicesStatements) selectDeviceByID(
|
func (s *devicesStatements) selectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context, localpart, deviceID string,
|
||||||
) (*authtypes.Device, error) {
|
) (*authtypes.Device, error) {
|
||||||
|
|
|
@ -84,7 +84,7 @@ func (d *Database) CreateDevice(
|
||||||
if deviceID != nil {
|
if deviceID != nil {
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
// Revoke existing token for this device
|
// Revoke existing tokens for this device
|
||||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
@ -42,10 +41,12 @@ type flow struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type passwordRequest struct {
|
type passwordRequest struct {
|
||||||
User string `json:"user"`
|
User string `json:"user"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
|
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
|
||||||
|
// Thus a pointer is needed to differentiate between the two
|
||||||
InitialDisplayName *string `json:"initial_device_display_name"`
|
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||||
DeviceID string `json:"device_id"`
|
DeviceID *string `json:"device_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type loginResponse struct {
|
type loginResponse struct {
|
||||||
|
@ -110,7 +111,7 @@ func Login(
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dev, err := getDevice(req.Context(), r, deviceDB, acc, localpart, token)
|
dev, err := getDevice(req.Context(), r, deviceDB, acc, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusInternalServerError,
|
||||||
|
@ -134,20 +135,16 @@ func Login(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if device exists else create one
|
// getDevice returns a new or existing device
|
||||||
func getDevice(
|
func getDevice(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
r passwordRequest,
|
r passwordRequest,
|
||||||
deviceDB *devices.Database,
|
deviceDB *devices.Database,
|
||||||
acc *authtypes.Account,
|
acc *authtypes.Account,
|
||||||
localpart, token string,
|
token string,
|
||||||
) (dev *authtypes.Device, err error) {
|
) (dev *authtypes.Device, err error) {
|
||||||
dev, err = deviceDB.GetDeviceByID(ctx, localpart, r.DeviceID)
|
dev, err = deviceDB.CreateDevice(
|
||||||
if err == sql.ErrNoRows {
|
ctx, acc.Localpart, r.DeviceID, token, r.InitialDisplayName,
|
||||||
// device doesn't exist, create one
|
)
|
||||||
dev, err = deviceDB.CreateDevice(
|
|
||||||
ctx, acc.Localpart, nil, token, r.InitialDisplayName,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,7 +121,10 @@ type registerRequest struct {
|
||||||
// user-interactive auth params
|
// user-interactive auth params
|
||||||
Auth authDict `json:"auth"`
|
Auth authDict `json:"auth"`
|
||||||
|
|
||||||
|
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
|
||||||
|
// Thus a pointer is needed to differentiate between the two
|
||||||
InitialDisplayName *string `json:"initial_device_display_name"`
|
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||||
|
DeviceID *string `json:"device_id"`
|
||||||
|
|
||||||
// Prevent this user from logging in
|
// Prevent this user from logging in
|
||||||
InhibitLogin common.WeakBoolean `json:"inhibit_login"`
|
InhibitLogin common.WeakBoolean `json:"inhibit_login"`
|
||||||
|
@ -626,7 +629,7 @@ func handleApplicationServiceRegistration(
|
||||||
// application service registration is entirely separate.
|
// application service registration is entirely separate.
|
||||||
return completeRegistration(
|
return completeRegistration(
|
||||||
req.Context(), accountDB, deviceDB, r.Username, "", appserviceID,
|
req.Context(), accountDB, deviceDB, r.Username, "", appserviceID,
|
||||||
r.InhibitLogin, r.InitialDisplayName,
|
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -646,7 +649,7 @@ func checkAndCompleteFlow(
|
||||||
// This flow was completed, registration can continue
|
// This flow was completed, registration can continue
|
||||||
return completeRegistration(
|
return completeRegistration(
|
||||||
req.Context(), accountDB, deviceDB, r.Username, r.Password, "",
|
req.Context(), accountDB, deviceDB, r.Username, r.Password, "",
|
||||||
r.InhibitLogin, r.InitialDisplayName,
|
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -697,10 +700,10 @@ func LegacyRegister(
|
||||||
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
|
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil)
|
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
|
||||||
case authtypes.LoginTypeDummy:
|
case authtypes.LoginTypeDummy:
|
||||||
// there is nothing to do
|
// there is nothing to do
|
||||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil)
|
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
|
||||||
default:
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusNotImplemented,
|
Code: http.StatusNotImplemented,
|
||||||
|
@ -738,13 +741,19 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// completeRegistration runs some rudimentary checks against the submitted
|
||||||
|
// input, then if successful creates an account and a newly associated device
|
||||||
|
// We pass in each individual part of the request here instead of just passing a
|
||||||
|
// registerRequest, as this function serves requests encoded as both
|
||||||
|
// registerRequests and legacyRegisterRequests, which share some attributes but
|
||||||
|
// not all
|
||||||
func completeRegistration(
|
func completeRegistration(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
accountDB *accounts.Database,
|
accountDB *accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB *devices.Database,
|
||||||
username, password, appserviceID string,
|
username, password, appserviceID string,
|
||||||
inhibitLogin common.WeakBoolean,
|
inhibitLogin common.WeakBoolean,
|
||||||
displayName *string,
|
displayName, deviceID *string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -773,6 +782,9 @@ func completeRegistration(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Increment prometheus counter for created users
|
||||||
|
amtRegUsers.Inc()
|
||||||
|
|
||||||
// Check whether inhibit_login option is set. If so, don't create an access
|
// Check whether inhibit_login option is set. If so, don't create an access
|
||||||
// token or a device for this user
|
// token or a device for this user
|
||||||
if inhibitLogin {
|
if inhibitLogin {
|
||||||
|
@ -793,8 +805,7 @@ func completeRegistration(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Use the device ID in the request.
|
dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName)
|
||||||
dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusInternalServerError,
|
||||||
|
@ -802,9 +813,6 @@ func completeRegistration(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment prometheus counter for created users
|
|
||||||
amtRegUsers.Inc()
|
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: registerResponse{
|
JSON: registerResponse{
|
||||||
|
|
2
testfile
2
testfile
|
@ -149,3 +149,5 @@ Typing events appear in incremental sync
|
||||||
Typing events appear in gapped sync
|
Typing events appear in gapped sync
|
||||||
Inbound federation of state requires event_id as a mandatory paramater
|
Inbound federation of state requires event_id as a mandatory paramater
|
||||||
Inbound federation of state_ids requires event_id as a mandatory paramater
|
Inbound federation of state_ids requires event_id as a mandatory paramater
|
||||||
|
POST /register returns the same device_id as that in the request
|
||||||
|
POST /login returns the same device_id as that in the request
|
||||||
|
|
Loading…
Reference in a new issue