check available uses and expiry time, and update available uses

This commit is contained in:
Givralix 2024-07-14 12:23:14 +02:00
parent 618f3eaff6
commit 3c5943bef4
3 changed files with 24 additions and 10 deletions

View file

@ -332,8 +332,8 @@ func validateRecaptcha(
return nil return nil
} }
// validateToken returns an error response if the token is invalid // authenticateToken returns an error response if the token is invalid
func validateToken( func authenticateToken(
req *http.Request, req *http.Request,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI, cfg *config.ClientAPI,
@ -347,16 +347,25 @@ func validateToken(
return ErrMissingToken return ErrMissingToken
} }
exists, err := userAPI.ValidateRegistrationToken(req.Context(), token) registrationToken, err := userAPI.ValidateRegistrationToken(req.Context(), token)
if err != nil { if err != nil {
return err return err
} }
if !exists { if registrationToken == nil {
return ErrInvalidToken return ErrInvalidToken
} }
// Decrease available uses
newAttributes := make(map[string]interface{})
newAttributes["usesAllowed"] = *registrationToken.UsesAllowed - 1
_, updateErr := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), token, newAttributes)
if updateErr != nil {
return updateErr
}
return nil return nil
} }
@ -769,7 +778,7 @@ func handleRegistrationFlow(
case authtypes.LoginTypeRegistrationToken: case authtypes.LoginTypeRegistrationToken:
// Check given token response // Check given token response
err := validateToken(req, userAPI, cfg, r.Auth.Token) err := authenticateToken(req, userAPI, cfg, r.Auth.Token)
switch err { switch err {
case ErrRegistrationTokenDisabled: case ErrRegistrationTokenDisabled:
return util.JSONResponse{Code: http.StatusForbidden, JSON: spec.Unknown(err.Error())} return util.JSONResponse{Code: http.StatusForbidden, JSON: spec.Unknown(err.Error())}

View file

@ -117,7 +117,7 @@ type ClientUserAPI interface {
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error PerformForgetThreePID(ctx context.Context, req *PerformForgetThreePIDRequest, res *struct{}) error
PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error PerformSaveThreePIDAssociation(ctx context.Context, req *PerformSaveThreePIDAssociationRequest, res *struct{}) error
ValidateRegistrationToken(ctx context.Context, registrationToken string) (bool, error) ValidateRegistrationToken(ctx context.Context, registrationToken string) (*clientapi.RegistrationToken, error)
} }
type KeyBackupAPI interface { type KeyBackupAPI interface {

View file

@ -979,10 +979,15 @@ func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, re
const pushRulesAccountDataType = "m.push_rules" const pushRulesAccountDataType = "m.push_rules"
func (a *UserInternalAPI) ValidateRegistrationToken(ctx context.Context, registrationToken string) (bool, error) { func (a *UserInternalAPI) ValidateRegistrationToken(ctx context.Context, token string) (*clientapi.RegistrationToken, error) {
exists, err := a.DB.RegistrationTokenExists(ctx, registrationToken) registrationToken, err := a.DB.GetRegistrationToken(ctx, token)
if err != nil { if err != nil {
return false, err return nil, err
} }
return exists, nil if registrationToken == nil || *registrationToken.UsesAllowed == 0 || *registrationToken.ExpiryTime > int64(spec.AsTimestamp(time.Now())) {
return nil, nil
}
return registrationToken, nil
} }