Remove util.JSONResponse from HTML API

This commit is contained in:
Till Faelligen 2022-12-20 13:50:50 +01:00
parent b801c13566
commit b852f70638
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
5 changed files with 77 additions and 65 deletions

View file

@ -20,7 +20,6 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -102,30 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
func AuthFallback( func AuthFallback(
w http.ResponseWriter, req *http.Request, authType string, w http.ResponseWriter, req *http.Request, authType string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) *util.JSONResponse { ) {
// We currently only support reCaptcha, so fail early if that's not requested // We currently only support "m.login.recaptcha", so fail early if that's not requested
if authType == authtypes.LoginTypeRecaptcha { if authType == authtypes.LoginTypeRecaptcha {
if !cfg.RecaptchaEnabled { if !cfg.RecaptchaEnabled {
return writeHTTPMessage(w, req, writeHTTPMessage(w, req,
"Recaptcha login is disabled on this Homeserver", "Recaptcha login is disabled on this Homeserver",
http.StatusBadRequest, http.StatusBadRequest,
) )
return
} }
} else { } else {
_ = writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented) writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented)
return &util.JSONResponse{ return
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown auth stage type"),
}
} }
sessionID := req.URL.Query().Get("session") sessionID := req.URL.Query().Get("session")
if sessionID == "" { if sessionID == "" {
return writeHTTPMessage(w, req, writeHTTPMessage(w, req,
"Session ID not provided", "Session ID not provided",
http.StatusBadRequest, http.StatusBadRequest,
) )
return
} }
serveRecaptcha := func() { serveRecaptcha := func() {
@ -148,36 +145,43 @@ func AuthFallback(
if req.Method == http.MethodGet { if req.Method == http.MethodGet {
// Handle Recaptcha // Handle Recaptcha
serveRecaptcha() serveRecaptcha()
return nil return
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
// Handle Recaptcha // Handle Recaptcha
clientIP := req.RemoteAddr clientIP := req.RemoteAddr
err := req.ParseForm() err := req.ParseForm()
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
res := jsonerror.InternalServerError() w.WriteHeader(http.StatusBadRequest)
return &res serveRecaptcha()
return
} }
response := req.Form.Get(cfg.RecaptchaFormField) response := req.Form.Get(cfg.RecaptchaFormField)
if err := validateRecaptcha(cfg, response, clientIP); err != nil { err = validateRecaptcha(cfg, response, clientIP)
util.GetLogger(req.Context()).Error(err) switch err {
w.WriteHeader(http.StatusUnauthorized) case ErrMissingResponse:
w.WriteHeader(http.StatusBadRequest)
serveRecaptcha() // serve the initial page again, instead of nothing serveRecaptcha() // serve the initial page again, instead of nothing
return err return
case ErrInvalidCaptcha:
w.WriteHeader(http.StatusUnauthorized)
serveRecaptcha()
return
case nil:
default: // something else failed
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
serveRecaptcha()
return
} }
// Success. Add recaptcha as a completed login flow // Success. Add recaptcha as a completed login flow
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
serveSuccess() serveSuccess()
return nil return
}
_ = writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed)
return &util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),
} }
writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed)
} }
// writeHTTPMessage writes the given header and message to the HTTP response writer. // writeHTTPMessage writes the given header and message to the HTTP response writer.
@ -185,13 +189,10 @@ func AuthFallback(
func writeHTTPMessage( func writeHTTPMessage(
w http.ResponseWriter, req *http.Request, w http.ResponseWriter, req *http.Request,
message string, header int, message string, header int,
) *util.JSONResponse { ) {
w.WriteHeader(header) w.WriteHeader(header)
_, err := w.Write([]byte(message)) _, err := w.Write([]byte(message))
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("w.Write failed") util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
res := jsonerror.InternalServerError()
return &res
} }
return nil
} }

View file

@ -128,4 +128,22 @@ func Test_AuthFallback(t *testing.T) {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
} }
}) })
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
}
})
t.Run("missing 'response' is handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
}
})
} }

View file

@ -18,6 +18,7 @@ package routing
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -331,25 +332,25 @@ func validatePassword(password string) *util.JSONResponse {
return nil return nil
} }
var (
ErrInvalidCaptcha = errors.New("invalid captcha response")
ErrMissingResponse = errors.New("captcha response is required")
ErrCaptchaDisabled = errors.New("captcha registration is disabled")
)
// validateRecaptcha returns an error response if the captcha response is invalid // validateRecaptcha returns an error response if the captcha response is invalid
func validateRecaptcha( func validateRecaptcha(
cfg *config.ClientAPI, cfg *config.ClientAPI,
response string, response string,
clientip string, clientip string,
) *util.JSONResponse { ) error {
ip, _, _ := net.SplitHostPort(clientip) ip, _, _ := net.SplitHostPort(clientip)
if !cfg.RecaptchaEnabled { if !cfg.RecaptchaEnabled {
return &util.JSONResponse{ return ErrCaptchaDisabled
Code: http.StatusConflict,
JSON: jsonerror.Unknown("Captcha registration is disabled"),
}
} }
if response == "" { if response == "" {
return &util.JSONResponse{ return ErrMissingResponse
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Captcha response is required"),
}
} }
// Make a POST request to Google's API to check the captcha response // Make a POST request to Google's API to check the captcha response
@ -362,10 +363,7 @@ func validateRecaptcha(
) )
if err != nil { if err != nil {
return &util.JSONResponse{ return err
Code: http.StatusInternalServerError,
JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"),
}
} }
// Close the request once we're finishing reading from it // Close the request once we're finishing reading from it
@ -375,25 +373,16 @@ func validateRecaptcha(
var r recaptchaResponse var r recaptchaResponse
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return &util.JSONResponse{ return err
Code: http.StatusGatewayTimeout,
JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()),
}
} }
err = json.Unmarshal(body, &r) err = json.Unmarshal(body, &r)
if err != nil { if err != nil {
return &util.JSONResponse{ return err
Code: http.StatusInternalServerError,
JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()),
}
} }
// Check that we received a "success" // Check that we received a "success"
if !r.Success { if !r.Success {
return &util.JSONResponse{ return ErrInvalidCaptcha
Code: http.StatusUnauthorized,
JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."),
}
} }
return nil return nil
} }
@ -777,9 +766,18 @@ func handleRegistrationFlow(
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeRecaptcha: case authtypes.LoginTypeRecaptcha:
// Check given captcha response // Check given captcha response
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
if resErr != nil { switch err {
return *resErr case ErrCaptchaDisabled:
return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())}
case ErrMissingResponse:
return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())}
case ErrInvalidCaptcha:
return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())}
case nil:
default:
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}
} }
// Add Recaptcha to the list of completed registration stages // Add Recaptcha to the list of completed registration stages

View file

@ -631,9 +631,9 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
v3mux.Handle("/auth/{authType}/fallback/web", v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req) vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg) AuthFallback(w, req, vars["authType"], cfg)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)

View file

@ -198,17 +198,12 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
// MakeHTMLAPI adds Span metrics to the HTML Handler function // MakeHTMLAPI adds Span metrics to the HTML Handler function
// This is used to serve HTML alongside JSON error messages // This is used to serve HTML alongside JSON error messages
func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)) http.Handler {
withSpan := func(w http.ResponseWriter, req *http.Request) { withSpan := func(w http.ResponseWriter, req *http.Request) {
span := opentracing.StartSpan(metricsName) span := opentracing.StartSpan(metricsName)
defer span.Finish() defer span.Finish()
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
if err := f(w, req); err != nil { f(w, req)
h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
return *err
}))
h.ServeHTTP(w, req)
}
} }
return promhttp.InstrumentHandlerCounter( return promhttp.InstrumentHandlerCounter(