diff --git a/clientapi/routing/sso.go b/clientapi/routing/sso.go index 34accda1f..3b809092e 100644 --- a/clientapi/routing/sso.go +++ b/clientapi/routing/sso.go @@ -70,6 +70,13 @@ func SSORedirect( } } + if idpID == "" { + idpID = cfg.DefaultProviderID + if idpID == "" && len(cfg.Providers) > 0 { + idpID = cfg.Providers[0].ID + } + } + callbackURL, err := buildCallbackURLFromOther(cfg, req, "/login/sso/redirect") if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to build callback URL") diff --git a/clientapi/routing/sso_test.go b/clientapi/routing/sso_test.go index a26604b13..c6ac3637e 100644 --- a/clientapi/routing/sso_test.go +++ b/clientapi/routing/sso_test.go @@ -37,7 +37,30 @@ func TestSSORedirect(t *testing.T) { }.Encode(), }, }, - WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http%3A%2F%2Fmatrix.example.com%2F_matrix%2Fv4%2Flogin%2Fsso%2Fcallback%3Fprovider%3D&nonce=.+&providerID=`, + Config: config.SSO{ + DefaultProviderID: "adefault", + }, + WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http%3A%2F%2Fmatrix.example.com%2F_matrix%2Fv4%2Flogin%2Fsso%2Fcallback%3Fprovider%3Dadefault&nonce=.+&providerID=adefault`, + WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso", + }, + { + Name: "redirectFirstProvider", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/redirect", + RawQuery: url.Values{ + "redirectUrl": []string{"http://example.com/continue"}, + }.Encode(), + }, + }, + Config: config.SSO{ + Providers: []config.IdentityProvider{ + {ID: "firstprovider"}, + {ID: "secondprovider"}, + }, + }, + WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http%3A%2F%2Fmatrix.example.com%2F_matrix%2Fv4%2Flogin%2Fsso%2Fcallback%3Fprovider%3Dfirstprovider&nonce=.+&providerID=firstprovider`, WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso", }, { @@ -468,6 +491,10 @@ type fakeSSOAuthenticator struct { } func (auth *fakeSSOAuthenticator) AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) { + if providerID == "" { + return "", errors.New("empty providerID") + } + return (&url.URL{ Scheme: "http", Host: "auth.example.com",