diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index f01e48f80..6e08d9735 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -11,4 +11,5 @@ const ( LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" LoginTypeToken = "m.login.token" + LoginTypeRegistrationToken = "m.login.registration_token" ) diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 5235e9092..2a0db6aed 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -233,6 +233,9 @@ type authDict struct { // Recaptcha Response string `json:"response"` + + // Registration token + Token string `json:"token"` // TODO: Lots of custom keys depending on the type } @@ -272,9 +275,12 @@ type recaptchaResponse struct { } var ( - ErrInvalidCaptcha = errors.New("invalid captcha response") - ErrMissingResponse = errors.New("captcha response is required") - ErrCaptchaDisabled = errors.New("captcha registration is disabled") + ErrInvalidCaptcha = errors.New("invalid captcha response") + ErrMissingResponse = errors.New("captcha response is required") + ErrCaptchaDisabled = errors.New("captcha registration is disabled") + ErrRegistrationTokenDisabled = errors.New("token registration is disabled") + ErrMissingToken = errors.New("registration token is required") + ErrInvalidToken = errors.New("invalid registration token") ) // validateRecaptcha returns an error response if the captcha response is invalid @@ -326,6 +332,34 @@ func validateRecaptcha( return nil } +// validateToken returns an error response if the token is invalid +func validateToken( + req *http.Request, + userAPI userapi.ClientUserAPI, + cfg *config.ClientAPI, + token string, +) error { + if !cfg.RegistrationRequiresToken { + return ErrRegistrationTokenDisabled + } + + if token == "" { + return ErrMissingToken + } + + exists, err := userAPI.ValidateRegistrationToken(req.Context(), token) + + if err != nil { + return err + } + + if !exists { + return ErrInvalidToken + } + + return nil +} + // UserIDIsWithinApplicationServiceNamespace checks to see if a given userID // falls within any of the namespaces of a given Application Service. If no // Application Service is given, it will check to see if it matches any @@ -733,6 +767,25 @@ func handleRegistrationFlow( // Add Recaptcha to the list of completed registration stages sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + case authtypes.LoginTypeRegistrationToken: + // Check given token response + err := validateToken(req, userAPI, cfg, r.Auth.Token) + switch err { + case ErrRegistrationTokenDisabled: + return util.JSONResponse{Code: http.StatusForbidden, JSON: spec.Unknown(err.Error())} + case ErrMissingToken: + return util.JSONResponse{Code: http.StatusBadRequest, JSON: spec.BadJSON(err.Error())} + case ErrInvalidToken: + return util.JSONResponse{Code: http.StatusUnauthorized, JSON: spec.BadJSON(err.Error())} + case nil: + default: + util.GetLogger(req.Context()).WithError(err).Error("failed to validate token") + return util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}} + } + + // Add RegistrationToken to the list of completed registration stages + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRegistrationToken) + case authtypes.LoginTypeDummy: // there is nothing to do // Add Dummy to the list of completed registration stages diff --git a/setup/config/config.go b/setup/config/config.go index 41396ae36..9daa6b119 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -291,6 +291,10 @@ func (config *Dendrite) Derive() error { config.Derived.Registration.Flows = []authtypes.Flow{ {Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}, } + } else if config.ClientAPI.RegistrationRequiresToken { + config.Derived.Registration.Flows = []authtypes.Flow{ + {Stages: []authtypes.LoginType{authtypes.LoginTypeRegistrationToken}}, + } } else { config.Derived.Registration.Flows = []authtypes.Flow{ {Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}, diff --git a/userapi/api/api.go b/userapi/api/api.go index d4daec820..02e4f0d22 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -117,6 +117,7 @@ type ClientUserAPI interface { QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error + ValidateRegistrationToken(ctx context.Context, registrationToken string) (bool, error) } type KeyBackupAPI interface { diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index a126dc871..f8556c109 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -978,3 +978,11 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re } const pushRulesAccountDataType = "m.push_rules" + +func (a *UserInternalAPI) ValidateRegistrationToken(ctx context.Context, registrationToken string) (bool, error) { + exists, err := a.DB.RegistrationTokenExists(ctx, registrationToken) + if err != nil { + return false, err + } + return exists, nil +}