mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-22 13:33:09 -06:00
Remove util.JSONResponse from HTML API
This commit is contained in:
parent
b801c13566
commit
b852f70638
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
@ -102,30 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
|
||||||
func AuthFallback(
|
func AuthFallback(
|
||||||
w http.ResponseWriter, req *http.Request, authType string,
|
w http.ResponseWriter, req *http.Request, authType string,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) *util.JSONResponse {
|
) {
|
||||||
// We currently only support reCaptcha, so fail early if that's not requested
|
// We currently only support "m.login.recaptcha", so fail early if that's not requested
|
||||||
if authType == authtypes.LoginTypeRecaptcha {
|
if authType == authtypes.LoginTypeRecaptcha {
|
||||||
if !cfg.RecaptchaEnabled {
|
if !cfg.RecaptchaEnabled {
|
||||||
return writeHTTPMessage(w, req,
|
writeHTTPMessage(w, req,
|
||||||
"Recaptcha login is disabled on this Homeserver",
|
"Recaptcha login is disabled on this Homeserver",
|
||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_ = writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented)
|
writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented)
|
||||||
return &util.JSONResponse{
|
return
|
||||||
Code: http.StatusNotFound,
|
|
||||||
JSON: jsonerror.NotFound("Unknown auth stage type"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID := req.URL.Query().Get("session")
|
sessionID := req.URL.Query().Get("session")
|
||||||
|
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return writeHTTPMessage(w, req,
|
writeHTTPMessage(w, req,
|
||||||
"Session ID not provided",
|
"Session ID not provided",
|
||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
serveRecaptcha := func() {
|
serveRecaptcha := func() {
|
||||||
|
|
@ -148,36 +145,43 @@ func AuthFallback(
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
// Handle Recaptcha
|
// Handle Recaptcha
|
||||||
serveRecaptcha()
|
serveRecaptcha()
|
||||||
return nil
|
return
|
||||||
} else if req.Method == http.MethodPost {
|
} else if req.Method == http.MethodPost {
|
||||||
// Handle Recaptcha
|
// Handle Recaptcha
|
||||||
clientIP := req.RemoteAddr
|
clientIP := req.RemoteAddr
|
||||||
err := req.ParseForm()
|
err := req.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
|
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
|
||||||
res := jsonerror.InternalServerError()
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return &res
|
serveRecaptcha()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := req.Form.Get(cfg.RecaptchaFormField)
|
response := req.Form.Get(cfg.RecaptchaFormField)
|
||||||
if err := validateRecaptcha(cfg, response, clientIP); err != nil {
|
err = validateRecaptcha(cfg, response, clientIP)
|
||||||
util.GetLogger(req.Context()).Error(err)
|
switch err {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
case ErrMissingResponse:
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
serveRecaptcha() // serve the initial page again, instead of nothing
|
serveRecaptcha() // serve the initial page again, instead of nothing
|
||||||
return err
|
return
|
||||||
|
case ErrInvalidCaptcha:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
serveRecaptcha()
|
||||||
|
return
|
||||||
|
case nil:
|
||||||
|
default: // something else failed
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
|
||||||
|
serveRecaptcha()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Success. Add recaptcha as a completed login flow
|
// Success. Add recaptcha as a completed login flow
|
||||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||||
|
|
||||||
serveSuccess()
|
serveSuccess()
|
||||||
return nil
|
return
|
||||||
}
|
|
||||||
_ = writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed)
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusMethodNotAllowed,
|
|
||||||
JSON: jsonerror.NotFound("Bad method"),
|
|
||||||
}
|
}
|
||||||
|
writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeHTTPMessage writes the given header and message to the HTTP response writer.
|
// writeHTTPMessage writes the given header and message to the HTTP response writer.
|
||||||
|
|
@ -185,13 +189,10 @@ func AuthFallback(
|
||||||
func writeHTTPMessage(
|
func writeHTTPMessage(
|
||||||
w http.ResponseWriter, req *http.Request,
|
w http.ResponseWriter, req *http.Request,
|
||||||
message string, header int,
|
message string, header int,
|
||||||
) *util.JSONResponse {
|
) {
|
||||||
w.WriteHeader(header)
|
w.WriteHeader(header)
|
||||||
_, err := w.Write([]byte(message))
|
_, err := w.Write([]byte(message))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
|
util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
|
||||||
res := jsonerror.InternalServerError()
|
|
||||||
return &res
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -128,4 +128,22 @@ func Test_AuthFallback(t *testing.T) {
|
||||||
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing 'response' is handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ package routing
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
|
@ -331,25 +332,25 @@ func validatePassword(password string) *util.JSONResponse {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidCaptcha = errors.New("invalid captcha response")
|
||||||
|
ErrMissingResponse = errors.New("captcha response is required")
|
||||||
|
ErrCaptchaDisabled = errors.New("captcha registration is disabled")
|
||||||
|
)
|
||||||
|
|
||||||
// validateRecaptcha returns an error response if the captcha response is invalid
|
// validateRecaptcha returns an error response if the captcha response is invalid
|
||||||
func validateRecaptcha(
|
func validateRecaptcha(
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
response string,
|
response string,
|
||||||
clientip string,
|
clientip string,
|
||||||
) *util.JSONResponse {
|
) error {
|
||||||
ip, _, _ := net.SplitHostPort(clientip)
|
ip, _, _ := net.SplitHostPort(clientip)
|
||||||
if !cfg.RecaptchaEnabled {
|
if !cfg.RecaptchaEnabled {
|
||||||
return &util.JSONResponse{
|
return ErrCaptchaDisabled
|
||||||
Code: http.StatusConflict,
|
|
||||||
JSON: jsonerror.Unknown("Captcha registration is disabled"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if response == "" {
|
if response == "" {
|
||||||
return &util.JSONResponse{
|
return ErrMissingResponse
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON("Captcha response is required"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a POST request to Google's API to check the captcha response
|
// Make a POST request to Google's API to check the captcha response
|
||||||
|
|
@ -362,10 +363,7 @@ func validateRecaptcha(
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &util.JSONResponse{
|
return err
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the request once we're finishing reading from it
|
// Close the request once we're finishing reading from it
|
||||||
|
|
@ -375,25 +373,16 @@ func validateRecaptcha(
|
||||||
var r recaptchaResponse
|
var r recaptchaResponse
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &util.JSONResponse{
|
return err
|
||||||
Code: http.StatusGatewayTimeout,
|
|
||||||
JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &r)
|
err = json.Unmarshal(body, &r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &util.JSONResponse{
|
return err
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that we received a "success"
|
// Check that we received a "success"
|
||||||
if !r.Success {
|
if !r.Success {
|
||||||
return &util.JSONResponse{
|
return ErrInvalidCaptcha
|
||||||
Code: http.StatusUnauthorized,
|
|
||||||
JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -777,9 +766,18 @@ func handleRegistrationFlow(
|
||||||
switch r.Auth.Type {
|
switch r.Auth.Type {
|
||||||
case authtypes.LoginTypeRecaptcha:
|
case authtypes.LoginTypeRecaptcha:
|
||||||
// Check given captcha response
|
// Check given captcha response
|
||||||
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
||||||
if resErr != nil {
|
switch err {
|
||||||
return *resErr
|
case ErrCaptchaDisabled:
|
||||||
|
return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())}
|
||||||
|
case ErrMissingResponse:
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())}
|
||||||
|
case ErrInvalidCaptcha:
|
||||||
|
return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())}
|
||||||
|
case nil:
|
||||||
|
default:
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
|
||||||
|
return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Recaptcha to the list of completed registration stages
|
// Add Recaptcha to the list of completed registration stages
|
||||||
|
|
|
||||||
|
|
@ -631,9 +631,9 @@ func Setup(
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/auth/{authType}/fallback/web",
|
v3mux.Handle("/auth/{authType}/fallback/web",
|
||||||
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) {
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
return AuthFallback(w, req, vars["authType"], cfg)
|
AuthFallback(w, req, vars["authType"], cfg)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -198,17 +198,12 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
|
||||||
|
|
||||||
// MakeHTMLAPI adds Span metrics to the HTML Handler function
|
// MakeHTMLAPI adds Span metrics to the HTML Handler function
|
||||||
// This is used to serve HTML alongside JSON error messages
|
// This is used to serve HTML alongside JSON error messages
|
||||||
func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
|
func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request)) http.Handler {
|
||||||
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
||||||
span := opentracing.StartSpan(metricsName)
|
span := opentracing.StartSpan(metricsName)
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
|
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
|
||||||
if err := f(w, req); err != nil {
|
f(w, req)
|
||||||
h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
|
|
||||||
return *err
|
|
||||||
}))
|
|
||||||
h.ServeHTTP(w, req)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return promhttp.InstrumentHandlerCounter(
|
return promhttp.InstrumentHandlerCounter(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue