diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 657954adb..21bec3307 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -577,7 +577,7 @@ func Setup( v3mux.Handle("/login/sso/callback", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - return SSOCallback(req, userAPI, ssoAuthenticator, cfg.Matrix.ServerName) + return SSOCallback(req, userAPI, ssoAuthenticator, &cfg.Login.SSO, cfg.Matrix.ServerName) }), ).Methods(http.MethodGet, http.MethodOptions) diff --git a/clientapi/routing/sso.go b/clientapi/routing/sso.go index cf8271abb..e5031b345 100644 --- a/clientapi/routing/sso.go +++ b/clientapi/routing/sso.go @@ -65,7 +65,7 @@ func SSORedirect( } } - callbackURL, err := buildCallbackURLFromRedirect(cfg, req) + callbackURL, err := buildCallbackURLFromOther(cfg, req, "/login/sso/redirect") if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to build callback URL") return util.JSONResponse{ @@ -107,9 +107,9 @@ func SSORedirect( return resp } -// buildCallbackURLFromRedirect builds a callback URL from a redirect +// buildCallbackURLFromOther builds a callback URL from another SSO // request and configuration. -func buildCallbackURLFromRedirect(cfg *config.SSO, req *http.Request) (*url.URL, error) { +func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath string) (*url.URL, error) { u := &url.URL{ Scheme: "https", User: req.URL.User, @@ -122,10 +122,9 @@ func buildCallbackURLFromRedirect(cfg *config.SSO, req *http.Request) (*url.URL, // Find the v3mux base, handling both `redirect` and // `redirect/{idp}` and not hard-coding the Matrix version. - const redirectPath = "/login/sso/redirect" - i := strings.Index(u.Path, redirectPath) + i := strings.Index(u.Path, expectedPath) if i < 0 { - return nil, fmt.Errorf("cannot find %q to replace in URL %q", redirectPath, u.Path) + return nil, fmt.Errorf("cannot find %q to replace in URL %q", expectedPath, u.Path) } u.Path = u.Path[:i] + "/login/sso/callback" @@ -142,6 +141,7 @@ func SSOCallback( req *http.Request, userAPI userAPIForSSO, auth *sso.Authenticator, + cfg *config.SSO, serverName gomatrixserverlib.ServerName, ) util.JSONResponse { if auth == nil { @@ -177,14 +177,18 @@ func SSOCallback( } } - callbackURL := &url.URL{ - Scheme: req.URL.Scheme, - Host: req.URL.Host, - Path: req.URL.Path, - RawQuery: url.Values{ - "provider": []string{idpID}, - }.Encode(), + callbackURL, err := buildCallbackURLFromOther(cfg, req, "/login/sso/callback") + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to build callback URL") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } } + + callbackURL = callbackURL.ResolveReference(&url.URL{ + RawQuery: url.Values{"provider": []string{idpID}}.Encode(), + }) result, err := auth.ProcessCallback(ctx, idpID, callbackURL.String(), nonce.Value, query) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to process callback")