diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 2a0db6aed..1172b9d74 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -332,8 +332,8 @@ func validateRecaptcha( return nil } -// validateToken returns an error response if the token is invalid -func validateToken( +// authenticateToken returns an error response if the token is invalid +func authenticateToken( req *http.Request, userAPI userapi.ClientUserAPI, cfg *config.ClientAPI, @@ -347,16 +347,25 @@ func validateToken( return ErrMissingToken } - exists, err := userAPI.ValidateRegistrationToken(req.Context(), token) + registrationToken, err := userAPI.ValidateRegistrationToken(req.Context(), token) if err != nil { return err } - if !exists { + if registrationToken == nil { return ErrInvalidToken } + // Decrease available uses + newAttributes := make(map[string]interface{}) + newAttributes["usesAllowed"] = *registrationToken.UsesAllowed - 1 + _, updateErr := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), token, newAttributes) + + if updateErr != nil { + return updateErr + } + return nil } @@ -769,7 +778,7 @@ func handleRegistrationFlow( case authtypes.LoginTypeRegistrationToken: // Check given token response - err := validateToken(req, userAPI, cfg, r.Auth.Token) + err := authenticateToken(req, userAPI, cfg, r.Auth.Token) switch err { case ErrRegistrationTokenDisabled: return util.JSONResponse{Code: http.StatusForbidden, JSON: spec.Unknown(err.Error())} diff --git a/userapi/api/api.go b/userapi/api/api.go index 02e4f0d22..d8227ddfd 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -117,7 +117,7 @@ type ClientUserAPI interface { QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error - ValidateRegistrationToken(ctx context.Context, registrationToken string) (bool, error) + ValidateRegistrationToken(ctx context.Context, registrationToken string) (*clientapi.RegistrationToken, error) } type KeyBackupAPI interface { diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index f8556c109..7c4604386 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -979,10 +979,15 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re const pushRulesAccountDataType = "m.push_rules" -func (a *UserInternalAPI) ValidateRegistrationToken(ctx context.Context, registrationToken string) (bool, error) { - exists, err := a.DB.RegistrationTokenExists(ctx, registrationToken) +func (a *UserInternalAPI) ValidateRegistrationToken(ctx context.Context, token string) (*clientapi.RegistrationToken, error) { + registrationToken, err := a.DB.GetRegistrationToken(ctx, token) + if err != nil { - return false, err + return nil, err } - return exists, nil + if registrationToken == nil || *registrationToken.UsesAllowed == 0 || *registrationToken.ExpiryTime > int64(spec.AsTimestamp(time.Now())) { + return nil, nil + } + + return registrationToken, nil }