Fix device_id on registration as well

This commit is contained in:
Andrew Morgan 2019-07-19 15:58:02 +01:00
parent e7bd091e9b
commit 0da8d365bf
2 changed files with 19 additions and 11 deletions

View file

@ -44,7 +44,7 @@ type passwordRequest struct {
User string `json:"user"` User string `json:"user"`
Password string `json:"password"` Password string `json:"password"`
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
// Thus a pointer is needed to differentiate between them two // Thus a pointer is needed to differentiate between the two
InitialDisplayName *string `json:"initial_device_display_name"` InitialDisplayName *string `json:"initial_device_display_name"`
DeviceID *string `json:"device_id"` DeviceID *string `json:"device_id"`
} }

View file

@ -121,7 +121,10 @@ type registerRequest struct {
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
// Thus a pointer is needed to differentiate between the two
InitialDisplayName *string `json:"initial_device_display_name"` InitialDisplayName *string `json:"initial_device_display_name"`
DeviceID *string `json:"device_id"`
// Prevent this user from logging in // Prevent this user from logging in
InhibitLogin common.WeakBoolean `json:"inhibit_login"` InhibitLogin common.WeakBoolean `json:"inhibit_login"`
@ -626,7 +629,7 @@ func handleApplicationServiceRegistration(
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, "", appserviceID, req.Context(), accountDB, deviceDB, r.Username, "", appserviceID,
r.InhibitLogin, r.InitialDisplayName, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -646,7 +649,7 @@ func checkAndCompleteFlow(
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), accountDB, deviceDB, r.Username, r.Password, "", req.Context(), accountDB, deviceDB, r.Username, r.Password, "",
r.InhibitLogin, r.InitialDisplayName, r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
) )
} }
@ -697,10 +700,10 @@ func LegacyRegister(
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
} }
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
default: default:
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusNotImplemented, Code: http.StatusNotImplemented,
@ -738,13 +741,19 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
return nil return nil
} }
// completeRegistration runs some rudimentary checks against the submitted
// input, then if successful creates an account and a newly associated device
// We pass in each individual part of the request here instead of just passing a
// registerRequest, as this function serves requests encoded as both
// registerRequests and legacyRegisterRequests, which share some attributes but
// not all
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
accountDB *accounts.Database, accountDB *accounts.Database,
deviceDB *devices.Database, deviceDB *devices.Database,
username, password, appserviceID string, username, password, appserviceID string,
inhibitLogin common.WeakBoolean, inhibitLogin common.WeakBoolean,
displayName *string, displayName, deviceID *string,
) util.JSONResponse { ) util.JSONResponse {
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
@ -773,6 +782,9 @@ func completeRegistration(
} }
} }
// Increment prometheus counter for created users
amtRegUsers.Inc()
// Check whether inhibit_login option is set. If so, don't create an access // Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user // token or a device for this user
if inhibitLogin { if inhibitLogin {
@ -793,8 +805,7 @@ func completeRegistration(
} }
} }
// TODO: Use the device ID in the request. dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName)
dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
@ -802,9 +813,6 @@ func completeRegistration(
} }
} }
// Increment prometheus counter for created users
amtRegUsers.Inc()
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: registerResponse{ JSON: registerResponse{