From 2f100c7993f7a0434fe8a770d24590bf5d3d5b62 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 1 Aug 2022 11:03:05 +0100 Subject: [PATCH] Tweak setup --- setup/base/base.go | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/setup/base/base.go b/setup/base/base.go index 4f433c8d9..70021bb8a 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -381,6 +381,21 @@ func (b *BaseDendrite) SetupAndServeHTTP( externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() internalRouter := externalRouter + notFoundHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("endpoint not found")) + } + notAllowedHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = w.Write([]byte(fmt.Sprintf("%s not allowed on this endpoint", r.Method))) + } + + notFoundCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notFoundHandler)) + notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler)) + + externalRouter.NotFoundHandler = notFoundCORSHandler + externalRouter.MethodNotAllowedHandler = notAllowedCORSHandler + externalServ := &http.Server{ Addr: string(externalAddr), WriteTimeout: HTTPServerTimeout, @@ -400,6 +415,8 @@ func (b *BaseDendrite) SetupAndServeHTTP( // without enabling TLS. internalH2S := &http2.Server{} internalRouter = mux.NewRouter().SkipClean(true).UseEncodedPath() + internalRouter.NotFoundHandler = notFoundCORSHandler + internalRouter.MethodNotAllowedHandler = notAllowedCORSHandler internalServ = &http.Server{ Addr: string(internalAddr), Handler: h2c.NewHandler(internalRouter, internalH2S), @@ -451,24 +468,6 @@ func (b *BaseDendrite) SetupAndServeHTTP( externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux) externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux) - notFoundHandler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(fmt.Sprintf("not found: %q", r.RequestURI))) - } - notAllowedHandler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusMethodNotAllowed) - _, _ = w.Write([]byte(fmt.Sprintf("%s not allowed: %q", r.Method, r.RequestURI))) - } - - notFoundCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notFoundHandler)) - notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler)) - - internalRouter.NotFoundHandler = notFoundCORSHandler - internalRouter.MethodNotAllowedHandler = notAllowedCORSHandler - - externalRouter.NotFoundHandler = notFoundCORSHandler - externalRouter.MethodNotAllowedHandler = notAllowedCORSHandler - if internalAddr != NoListener && internalAddr != externalAddr { go func() { var internalShutdown atomic.Bool // RegisterOnShutdown can be called more than once