diff --git a/setup/base/base.go b/setup/base/base.go index 70021bb8a..dbb236956 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -381,21 +381,6 @@ func (b *BaseDendrite) SetupAndServeHTTP( externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() internalRouter := externalRouter - notFoundHandler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte("endpoint not found")) - } - notAllowedHandler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusMethodNotAllowed) - _, _ = w.Write([]byte(fmt.Sprintf("%s not allowed on this endpoint", r.Method))) - } - - notFoundCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notFoundHandler)) - notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler)) - - externalRouter.NotFoundHandler = notFoundCORSHandler - externalRouter.MethodNotAllowedHandler = notAllowedCORSHandler - externalServ := &http.Server{ Addr: string(externalAddr), WriteTimeout: HTTPServerTimeout, @@ -415,8 +400,6 @@ func (b *BaseDendrite) SetupAndServeHTTP( // without enabling TLS. internalH2S := &http2.Server{} internalRouter = mux.NewRouter().SkipClean(true).UseEncodedPath() - internalRouter.NotFoundHandler = notFoundCORSHandler - internalRouter.MethodNotAllowedHandler = notAllowedCORSHandler internalServ = &http.Server{ Addr: string(internalAddr), Handler: h2c.NewHandler(internalRouter, internalH2S), @@ -426,6 +409,29 @@ func (b *BaseDendrite) SetupAndServeHTTP( } } + notFoundHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("endpoint not found")) + } + notAllowedHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = w.Write([]byte(fmt.Sprintf("%s not allowed on this endpoint", r.Method))) + } + + notFoundCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notFoundHandler)) + notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler)) + + for _, router := range []*mux.Router{ + internalRouter, externalRouter, + b.DendriteAdminMux, b.SynapseAdminMux, + b.InternalAPIMux, b.PublicWellKnownAPIMux, + b.PublicClientAPIMux, b.PublicFederationAPIMux, + b.PublicKeyAPIMux, b.PublicMediaAPIMux, + } { + router.NotFoundHandler = notFoundCORSHandler + router.MethodNotAllowedHandler = notAllowedCORSHandler + } + internalRouter.PathPrefix(httputil.InternalPathPrefix).Handler(b.InternalAPIMux) if b.Cfg.Global.Metrics.Enabled { internalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth))