diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index ae97b242b..f8d3684fe 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -20,7 +20,6 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/util" ) @@ -102,30 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s func AuthFallback( w http.ResponseWriter, req *http.Request, authType string, cfg *config.ClientAPI, -) *util.JSONResponse { - // We currently only support reCaptcha, so fail early if that's not requested +) { + // We currently only support "m.login.recaptcha", so fail early if that's not requested if authType == authtypes.LoginTypeRecaptcha { if !cfg.RecaptchaEnabled { - return writeHTTPMessage(w, req, + writeHTTPMessage(w, req, "Recaptcha login is disabled on this Homeserver", http.StatusBadRequest, ) + return } } else { - _ = writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented) - return &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown auth stage type"), - } + writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented) + return } sessionID := req.URL.Query().Get("session") - if sessionID == "" { - return writeHTTPMessage(w, req, + writeHTTPMessage(w, req, "Session ID not provided", http.StatusBadRequest, ) + return } serveRecaptcha := func() { @@ -148,36 +145,43 @@ func AuthFallback( if req.Method == http.MethodGet { // Handle Recaptcha serveRecaptcha() - return nil + return } else if req.Method == http.MethodPost { // Handle Recaptcha clientIP := req.RemoteAddr err := req.ParseForm() if err != nil { util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") - res := jsonerror.InternalServerError() - return &res + w.WriteHeader(http.StatusBadRequest) + serveRecaptcha() + return } response := req.Form.Get(cfg.RecaptchaFormField) - if err := validateRecaptcha(cfg, response, clientIP); err != nil { - util.GetLogger(req.Context()).Error(err) - w.WriteHeader(http.StatusUnauthorized) + err = validateRecaptcha(cfg, response, clientIP) + switch err { + case ErrMissingResponse: + w.WriteHeader(http.StatusBadRequest) serveRecaptcha() // serve the initial page again, instead of nothing - return err + return + case ErrInvalidCaptcha: + w.WriteHeader(http.StatusUnauthorized) + serveRecaptcha() + return + case nil: + default: // something else failed + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + serveRecaptcha() + return } // Success. Add recaptcha as a completed login flow sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) serveSuccess() - return nil - } - _ = writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed) - return &util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), + return } + writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed) } // writeHTTPMessage writes the given header and message to the HTTP response writer. @@ -185,13 +189,10 @@ func AuthFallback( func writeHTTPMessage( w http.ResponseWriter, req *http.Request, message string, header int, -) *util.JSONResponse { +) { w.WriteHeader(header) _, err := w.Write([]byte(message)) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("w.Write failed") - res := jsonerror.InternalServerError() - return &res } - return nil } diff --git a/clientapi/routing/auth_fallback_test.go b/clientapi/routing/auth_fallback_test.go index 637119dc6..0d77f9a01 100644 --- a/clientapi/routing/auth_fallback_test.go +++ b/clientapi/routing/auth_fallback_test.go @@ -128,4 +128,22 @@ func Test_AuthFallback(t *testing.T) { t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) } }) + + t.Run("missing session parameter is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing 'response' is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 7821f2c49..50debebab 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -18,6 +18,7 @@ package routing import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -331,25 +332,25 @@ func validatePassword(password string) *util.JSONResponse { return nil } +var ( + ErrInvalidCaptcha = errors.New("invalid captcha response") + ErrMissingResponse = errors.New("captcha response is required") + ErrCaptchaDisabled = errors.New("captcha registration is disabled") +) + // validateRecaptcha returns an error response if the captcha response is invalid func validateRecaptcha( cfg *config.ClientAPI, response string, clientip string, -) *util.JSONResponse { +) error { ip, _, _ := net.SplitHostPort(clientip) if !cfg.RecaptchaEnabled { - return &util.JSONResponse{ - Code: http.StatusConflict, - JSON: jsonerror.Unknown("Captcha registration is disabled"), - } + return ErrCaptchaDisabled } if response == "" { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Captcha response is required"), - } + return ErrMissingResponse } // Make a POST request to Google's API to check the captcha response @@ -362,10 +363,7 @@ func validateRecaptcha( ) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"), - } + return err } // Close the request once we're finishing reading from it @@ -375,25 +373,16 @@ func validateRecaptcha( var r recaptchaResponse body, err := io.ReadAll(resp.Body) if err != nil { - return &util.JSONResponse{ - Code: http.StatusGatewayTimeout, - JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()), - } + return err } err = json.Unmarshal(body, &r) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()), - } + return err } // Check that we received a "success" if !r.Success { - return &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."), - } + return ErrInvalidCaptcha } return nil } @@ -777,9 +766,18 @@ func handleRegistrationFlow( switch r.Auth.Type { case authtypes.LoginTypeRecaptcha: // Check given captcha response - resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) - if resErr != nil { - return *resErr + err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) + switch err { + case ErrCaptchaDisabled: + return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())} + case ErrMissingResponse: + return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())} + case ErrInvalidCaptcha: + return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())} + case nil: + default: + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()} } // Add Recaptcha to the list of completed registration stages diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index a510761eb..10c8a75a3 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -631,9 +631,9 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - return AuthFallback(w, req, vars["authType"], cfg) + AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 127d1fac7..da5abff71 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -198,17 +198,12 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTMLAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { +func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { span := opentracing.StartSpan(metricsName) defer span.Finish() req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) - if err := f(w, req); err != nil { - h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { - return *err - })) - h.ServeHTTP(w, req) - } + f(w, req) } return promhttp.InstrumentHandlerCounter(