From 58233c790a6fbd3fd8eedf9f6d3881a98d4b69d9 Mon Sep 17 00:00:00 2001 From: Tommie Gannert Date: Wed, 25 May 2022 01:17:08 +0200 Subject: [PATCH] Split SSO into OAuth2 and OIDC. Add OIDC discovery. GitHub implements OAuth2, but not OpenID Connect. This means it needs more magic constants than those that can do OIDC discovery (and where Userinfo is in OIDC-compatible.) Fixes the HTTP client to have a timeout. --- clientapi/auth/sso/github.go | 32 ++-- clientapi/auth/sso/oauth2.go | 221 +++++++++++++++++++++++++ clientapi/auth/sso/oidc.go | 159 ++++++++++++++++++ clientapi/auth/sso/oidc_base.go | 272 ------------------------------- clientapi/auth/sso/sso.go | 80 ++++++--- clientapi/routing/login.go | 20 ++- clientapi/routing/routing.go | 16 +- clientapi/routing/sso.go | 92 +++-------- setup/config/config_clientapi.go | 39 +++-- userapi/api/api_sso.go | 4 + 10 files changed, 534 insertions(+), 401 deletions(-) create mode 100644 clientapi/auth/sso/oauth2.go create mode 100644 clientapi/auth/sso/oidc.go delete mode 100644 clientapi/auth/sso/oidc_base.go diff --git a/clientapi/auth/sso/github.go b/clientapi/auth/sso/github.go index 55f5417b6..70f1a95e0 100644 --- a/clientapi/auth/sso/github.go +++ b/clientapi/auth/sso/github.go @@ -15,23 +15,25 @@ package sso import ( + "net/http" + "github.com/matrix-org/dendrite/setup/config" ) -// GitHubIdentityProvider is a GitHub-flavored identity provider. -var GitHubIdentityProvider IdentityProvider = githubIdentityProvider{ - baseOIDCIdentityProvider: &baseOIDCIdentityProvider{ - AuthURL: mustParseURLTemplate("https://github.com/login/oauth/authorize?scope=user:email"), - AccessTokenURL: mustParseURLTemplate("https://github.com/login/oauth/access_token"), - UserInfoURL: mustParseURLTemplate("https://api.github.com/user"), - UserInfoAccept: "application/vnd.github.v3+json", - UserInfoEmailPath: "email", - UserInfoSuggestedUserIDPath: "login", - }, -} +func newGitHubIdentityProvider(cfg *config.IdentityProvider, hc *http.Client) identityProvider { + return &oauth2IdentityProvider{ + cfg: cfg, + hc: hc, -type githubIdentityProvider struct { - *baseOIDCIdentityProvider -} + authorizationURL: "https://github.com/login/oauth/authorize", + accessTokenURL: "https://github.com/login/oauth/access_token", + userInfoURL: "https://api.github.com/user", -func (githubIdentityProvider) DefaultBrand() string { return config.SSOBrandGitHub } + scopes: []string{"user:email"}, + responseMimeType: "application/vnd.github.v3+json", + subPath: "id", + emailPath: "email", + displayNamePath: "name", + suggestedUserIDPath: "login", + } +} diff --git a/clientapi/auth/sso/oauth2.go b/clientapi/auth/sso/oauth2.go new file mode 100644 index 000000000..287e187bc --- /dev/null +++ b/clientapi/auth/sso/oauth2.go @@ -0,0 +1,221 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/tidwall/gjson" +) + +type oauth2IdentityProvider struct { + cfg *config.IdentityProvider + hc *http.Client + + authorizationURL string + accessTokenURL string + userInfoURL string + + scopes []string + responseMimeType string + issPath string + subPath string + emailPath string + displayNamePath string + suggestedUserIDPath string +} + +func (p *oauth2IdentityProvider) AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) { + u, err := resolveURL(p.authorizationURL, url.Values{ + "client_id": []string{p.cfg.OIDC.ClientID}, + "response_type": []string{"code"}, + "redirect_uri": []string{callbackURL}, + "scope": []string{strings.Join(p.scopes, " ")}, + "state": []string{nonce}, + }) + if err != nil { + return "", err + } + return u.String(), nil +} + +func (p *oauth2IdentityProvider) ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) { + state := query.Get("state") + if state == "" { + return nil, jsonerror.MissingArgument("state parameter missing") + } + if state != nonce { + return nil, jsonerror.InvalidArgumentValue("state parameter not matching nonce") + } + + if error := query.Get("error"); error != "" { + if euri := query.Get("error_uri"); euri != "" { + return &CallbackResult{RedirectURL: euri}, nil + } + + desc := query.Get("error_description") + if desc == "" { + desc = error + } + switch error { + case "unauthorized_client", "access_denied": + return nil, jsonerror.Forbidden("SSO said no: " + desc) + default: + return nil, fmt.Errorf("SSO failed: %v", error) + } + } + + code := query.Get("code") + if code == "" { + return nil, jsonerror.MissingArgument("code parameter missing") + } + + at, err := p.getAccessToken(ctx, callbackURL, code) + if err != nil { + return nil, err + } + + subject, displayName, suggestedLocalpart, err := p.getUserInfo(ctx, at) + if err != nil { + return nil, err + } + + return &CallbackResult{ + Identifier: &UserIdentifier{ + Namespace: uapi.SSOIDNamespace, + Issuer: p.cfg.ID, + Subject: subject, + }, + DisplayName: displayName, + SuggestedUserID: suggestedLocalpart, + }, nil +} + +func (p *oauth2IdentityProvider) getAccessToken(ctx context.Context, callbackURL, code string) (string, error) { + body := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "redirect_uri": []string{callbackURL}, + "client_id": []string{p.cfg.OIDC.ClientID}, + "client_secret": []string{p.cfg.OIDC.ClientSecret}, + } + hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.accessTokenURL, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + hreq.Header.Set("Accept", p.responseMimeType) + + hresp, err := p.hc.Do(hreq) + if err != nil { + return "", err + } + defer hresp.Body.Close() + + var resp oauth2TokenResponse + if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil { + return "", err + } + + if resp.Error != "" { + desc := resp.ErrorDescription + if desc == "" { + desc = resp.Error + } + return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc) + } + + if strings.ToLower(resp.TokenType) != "bearer" { + return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType) + } + + return resp.AccessToken, nil +} + +type oauth2TokenResponse struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` +} + +func (p *oauth2IdentityProvider) getUserInfo(ctx context.Context, accessToken string) (subject, displayName, suggestedLocalpart string, _ error) { + hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, p.userInfoURL, nil) + if err != nil { + return "", "", "", err + } + hreq.Header.Set("Authorization", "token "+accessToken) + hreq.Header.Set("Accept", p.responseMimeType) + + hresp, err := p.hc.Do(hreq) + if err != nil { + return "", "", "", err + } + defer hresp.Body.Close() + + body, err := ioutil.ReadAll(hresp.Body) + if err != nil { + return "", "", "", err + } + + if res := gjson.GetBytes(body, p.subPath); !res.Exists() { + return "", "", "", fmt.Errorf("no %q in user info response body", p.subPath) + } else { + subject = res.String() + } + if subject == "" { + return "", "", "", fmt.Errorf("empty subject in user info") + } + + if p.suggestedUserIDPath != "" { + suggestedLocalpart = gjson.GetBytes(body, p.suggestedUserIDPath).String() + } + + if p.displayNamePath != "" { + displayName = gjson.GetBytes(body, p.displayNamePath).String() + } + + return +} + +func resolveURL(urlString string, defaultQuery url.Values) (*url.URL, error) { + u, err := url.Parse(urlString) + if err != nil { + return nil, err + } + + if defaultQuery != nil { + q := u.Query() + for k, vs := range defaultQuery { + if q.Get(k) == "" { + q[k] = vs + } + } + u.RawQuery = q.Encode() + } + + return u, nil +} diff --git a/clientapi/auth/sso/oidc.go b/clientapi/auth/sso/oidc.go new file mode 100644 index 000000000..ac9c56a57 --- /dev/null +++ b/clientapi/auth/sso/oidc.go @@ -0,0 +1,159 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "sync" + "time" + + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +type oidcIdentityProvider struct { + *oauth2IdentityProvider + + disc *oidcDiscovery + exp time.Time + mu sync.Mutex +} + +func newOIDCIdentityProvider(cfg *config.IdentityProvider, hc *http.Client) *oidcIdentityProvider { + return &oidcIdentityProvider{ + oauth2IdentityProvider: &oauth2IdentityProvider{ + cfg: cfg, + hc: hc, + + scopes: []string{"openid", "profile", "email"}, + responseMimeType: "application/json", + subPath: "sub", + emailPath: "email", + displayNamePath: "name", + suggestedUserIDPath: "preferred_username", + }, + } +} + +func (p *oidcIdentityProvider) AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) { + oauth2p, _, err := p.get(ctx) + if err != nil { + return "", err + } + return oauth2p.AuthorizationURL(ctx, callbackURL, nonce) +} + +func (p *oidcIdentityProvider) ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) { + oauth2p, disc, err := p.get(ctx) + if err != nil { + return nil, err + } + res, err := oauth2p.ProcessCallback(ctx, callbackURL, nonce, query) + if err != nil { + return nil, err + } + + // OIDC has the notion of issuer URL, which will be more + // stable than our configuration ID. + res.Identifier.Namespace = uapi.OIDCNamespace + res.Identifier.Issuer = disc.Issuer + + return res, nil +} + +func (p *oidcIdentityProvider) get(ctx context.Context) (*oauth2IdentityProvider, *oidcDiscovery, error) { + p.mu.Lock() + defer p.mu.Unlock() + + now := time.Now() + if p.exp.Before(now) || p.disc == nil { + disc, err := oidcDiscover(ctx, p.cfg.OIDC.DiscoveryURL) + if err != nil { + if p.disc != nil { + // Prefers returning a stale entry. + return p.oauth2IdentityProvider, p.disc, nil + } + return nil, nil, err + } + + p.exp = now.Add(24 * time.Hour) + newProvider := *p.oauth2IdentityProvider + newProvider.authorizationURL = disc.AuthorizationEndpoint + newProvider.accessTokenURL = disc.TokenEndpoint + newProvider.userInfoURL = disc.UserinfoEndpoint + + p.oauth2IdentityProvider = &newProvider + p.disc = disc + } + + return p.oauth2IdentityProvider, p.disc, nil +} + +type oidcDiscovery struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + ScopesSupported []string `json:"scopes_supported"` + ClaimsSupported []string `json:"claims_supported"` +} + +func oidcDiscover(ctx context.Context, url string) (*oidcDiscovery, error) { + hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + hreq.Header.Set("Accept", "application/jrd+json,application/json;q=0.9") + + hresp, err := http.DefaultClient.Do(hreq) + if err != nil { + return nil, err + } + defer hresp.Body.Close() + + var disc oidcDiscovery + if err := json.NewDecoder(hresp.Body).Decode(&disc); err != nil { + return nil, err + } + + if disc.ScopesSupported != nil { + if !stringSliceContains(disc.ScopesSupported, "openid") { + return nil, fmt.Errorf("scope 'openid' is missing in %q", url) + } + } + + if disc.ClaimsSupported != nil { + for _, claim := range []string{"iss", "sub"} { + if !stringSliceContains(disc.ClaimsSupported, claim) { + return nil, fmt.Errorf("claim %q is not supported in %q", claim, url) + } + } + } + + return &disc, nil +} + +func stringSliceContains(ss []string, s string) bool { + for _, s2 := range ss { + if s2 == s { + return true + } + } + return false +} diff --git a/clientapi/auth/sso/oidc_base.go b/clientapi/auth/sso/oidc_base.go deleted file mode 100644 index 4275e5733..000000000 --- a/clientapi/auth/sso/oidc_base.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sso - -import ( - "context" - "encoding/json" - "fmt" - "io/ioutil" - "mime" - "net/http" - "net/url" - "strings" - "text/template" - - "github.com/matrix-org/dendrite/clientapi/jsonerror" - uapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/tidwall/gjson" -) - -type baseOIDCIdentityProvider struct { - AuthURL *urlTemplate - AccessTokenURL *urlTemplate - UserInfoURL *urlTemplate - UserInfoAccept string - UserInfoEmailPath string - UserInfoSuggestedUserIDPath string -} - -func (p *baseOIDCIdentityProvider) AuthorizationURL(ctx context.Context, req *IdentityProviderRequest) (string, error) { - u, err := p.AuthURL.Execute(map[string]interface{}{ - "Config": req.System, - "State": req.DendriteNonce, - "RedirectURI": req.CallbackURL, - }, url.Values{ - "client_id": []string{req.System.OIDC.ClientID}, - "response_type": []string{"code"}, - "redirect_uri": []string{req.CallbackURL}, - "state": []string{req.DendriteNonce}, - }) - if err != nil { - return "", err - } - return u.String(), nil -} - -func (p *baseOIDCIdentityProvider) ProcessCallback(ctx context.Context, req *IdentityProviderRequest, values url.Values) (*CallbackResult, error) { - state := values.Get("state") - if state == "" { - return nil, jsonerror.MissingArgument("state parameter missing") - } - if state != req.DendriteNonce { - return nil, jsonerror.InvalidArgumentValue("state parameter not matching nonce") - } - - if error := values.Get("error"); error != "" { - if euri := values.Get("error_uri"); euri != "" { - return &CallbackResult{RedirectURL: euri}, nil - } - - desc := values.Get("error_description") - if desc == "" { - desc = error - } - switch error { - case "unauthorized_client", "access_denied": - return nil, jsonerror.Forbidden("SSO said no: " + desc) - default: - return nil, fmt.Errorf("SSO failed: %v", error) - } - } - - code := values.Get("code") - if code == "" { - return nil, jsonerror.MissingArgument("code parameter missing") - } - - oidcAccessToken, err := p.getOIDCAccessToken(ctx, req, code) - if err != nil { - return nil, err - } - - id, userID, err := p.getUserInfo(ctx, req, oidcAccessToken) - if err != nil { - return nil, err - } - - return &CallbackResult{Identifier: id, SuggestedUserID: userID}, nil -} - -func (p *baseOIDCIdentityProvider) getOIDCAccessToken(ctx context.Context, req *IdentityProviderRequest, code string) (string, error) { - u, err := p.AccessTokenURL.Execute(nil, nil) - if err != nil { - return "", err - } - - body := url.Values{ - "grant_type": []string{"authorization_code"}, - "code": []string{code}, - "redirect_uri": []string{req.CallbackURL}, - "client_id": []string{req.System.OIDC.ClientID}, - } - - hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), strings.NewReader(body.Encode())) - if err != nil { - return "", err - } - hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - hreq.Header.Set("Accept", "application/x-www-form-urlencoded") - - hresp, err := http.DefaultClient.Do(hreq) - if err != nil { - return "", err - } - defer hresp.Body.Close() - - ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type")) - if err != nil { - return "", err - } - if ctype != "application/json" { - return "", fmt.Errorf("expected URL encoded response, got content type %q", ctype) - } - - var resp struct { - TokenType string `json:"token_type"` - AccessToken string `json:"access_token"` - - Error string `json:"error"` - ErrorDescription string `json:"error_description"` - ErrorURI string `json:"error_uri"` - } - if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil { - return "", err - } - - if resp.Error != "" { - desc := resp.ErrorDescription - if desc == "" { - desc = resp.Error - } - return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc) - } - - if strings.ToLower(resp.TokenType) != "bearer" { - return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType) - } - - return resp.AccessToken, nil -} - -func (p *baseOIDCIdentityProvider) getUserInfo(ctx context.Context, req *IdentityProviderRequest, oidcAccessToken string) (ssoUser *UserIdentifier, suggestedUserID string, _ error) { - u, err := p.UserInfoURL.Execute(map[string]interface{}{ - "Config": req.System, - }, nil) - if err != nil { - return nil, "", err - } - - hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) - if err != nil { - return nil, "", err - } - hreq.Header.Set("Authorization", "token "+oidcAccessToken) - hreq.Header.Set("Accept", p.UserInfoAccept) - - hresp, err := http.DefaultClient.Do(hreq) - if err != nil { - return nil, "", err - } - defer hresp.Body.Close() - - ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type")) - if err != nil { - return nil, "", err - } - - if ctype != "application/json" { - return nil, "", fmt.Errorf("got unknown content type %q for user info", ctype) - } - - body, err := ioutil.ReadAll(hresp.Body) - if err != nil { - return nil, "", err - } - - issRes := gjson.GetBytes(body, "iss") - if !issRes.Exists() { - return nil, "", fmt.Errorf("no iss in user info response body") - } - iss := issRes.String() - - subRes := gjson.GetBytes(body, "sub") - if !subRes.Exists() { - return nil, "", fmt.Errorf("no sub in user info response body") - } - sub := subRes.String() - - if iss == "" { - return nil, "", fmt.Errorf("no iss in user info") - } - - if sub == "" { - return nil, "", fmt.Errorf("no sub in user info") - } - - // This is optional. - userIDRes := gjson.GetBytes(body, p.UserInfoSuggestedUserIDPath) - suggestedUserID = userIDRes.String() - - return &UserIdentifier{ - Namespace: uapi.OIDCNamespace, - Issuer: iss, - Subject: sub, - }, suggestedUserID, nil -} - -type urlTemplate struct { - base *template.Template -} - -func parseURLTemplate(s string) (*urlTemplate, error) { - t, err := template.New("").Parse(s) - if err != nil { - return nil, err - } - return &urlTemplate{base: t}, nil -} - -func mustParseURLTemplate(s string) *urlTemplate { - t, err := parseURLTemplate(s) - if err != nil { - panic(err) - } - return t -} - -func (t *urlTemplate) Execute(params interface{}, defaultQuery url.Values) (*url.URL, error) { - var sb strings.Builder - err := t.base.Execute(&sb, params) - if err != nil { - return nil, err - } - - u, err := url.Parse(sb.String()) - if err != nil { - return nil, err - } - - if defaultQuery != nil { - q := u.Query() - for k, vs := range defaultQuery { - if q.Get(k) == "" { - q[k] = vs - } - } - u.RawQuery = q.Encode() - } - return u, nil -} diff --git a/clientapi/auth/sso/sso.go b/clientapi/auth/sso/sso.go index 37e3842a8..5923d8491 100644 --- a/clientapi/auth/sso/sso.go +++ b/clientapi/auth/sso/sso.go @@ -16,46 +16,78 @@ package sso import ( "context" + "fmt" + "net/http" "net/url" + "time" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" ) -type IdentityProvider interface { - DefaultBrand() string - - AuthorizationURL(context.Context, *IdentityProviderRequest) (string, error) - ProcessCallback(context.Context, *IdentityProviderRequest, url.Values) (*CallbackResult, error) +type Authenticator struct { + providers map[string]identityProvider } -type IdentityProviderRequest struct { - System *config.IdentityProvider - CallbackURL string - DendriteNonce string +func NewAuthenticator(cfg *config.SSO) (*Authenticator, error) { + hc := &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + Proxy: http.ProxyFromEnvironment, + }, + } + + a := &Authenticator{ + providers: make(map[string]identityProvider, len(cfg.Providers)), + } + for _, pcfg := range cfg.Providers { + typ := pcfg.Type + if typ == "" { + typ = config.IdentityProviderType(pcfg.ID) + } + + switch typ { + case config.SSOTypeOIDC: + a.providers[pcfg.ID] = newOIDCIdentityProvider(&pcfg, hc) + case config.SSOTypeGitHub: + a.providers[pcfg.ID] = newGitHubIdentityProvider(&pcfg, hc) + default: + return nil, fmt.Errorf("unknown SSO provider type: %s", typ) + } + } + + return a, nil +} + +func (auth *Authenticator) AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) { + p := auth.providers[providerID] + if p == nil { + return "", fmt.Errorf("unknown identity provider: %s", providerID) + } + return p.AuthorizationURL(ctx, callbackURL, nonce) +} + +func (auth *Authenticator) ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*CallbackResult, error) { + p := auth.providers[providerID] + if p == nil { + return nil, fmt.Errorf("unknown identity provider: %s", providerID) + } + return p.ProcessCallback(ctx, callbackURL, nonce, query) +} + +type identityProvider interface { + AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) + ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) } type CallbackResult struct { RedirectURL string Identifier *UserIdentifier + DisplayName string SuggestedUserID string } -type IdentityProviderType string - -const ( - TypeGitHub IdentityProviderType = config.SSOBrandGitHub -) - -func GetIdentityProvider(t IdentityProviderType) IdentityProvider { - switch t { - case TypeGitHub: - return GitHubIdentityProvider - default: - return nil - } -} - type UserIdentifier struct { Namespace uapi.SSOIssuerNamespace Issuer, Subject string diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index ad4aca29c..f0600ceb1 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/sso" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -46,10 +45,10 @@ type stage struct { } type identityProvider struct { - ID string `json:"id"` - Name string `json:"name"` - Brand string `json:"brand,omitempty"` - Icon string `json:"icon,omitempty"` + ID string `json:"id"` + Name string `json:"name"` + Brand config.SSOBrand `json:"brand,omitempty"` + Icon string `json:"icon,omitempty"` } func passwordLogin() []stage { @@ -69,11 +68,14 @@ func ssoLogin(cfg *config.ClientAPI) []stage { if brand == "" { typ := idp.Type if typ == "" { - typ = idp.ID + typ = config.IdentityProviderType(idp.ID) } - idpType := sso.GetIdentityProvider(sso.IdentityProviderType(typ)) - if idpType != nil { - brand = idpType.DefaultBrand() + switch typ { + case config.SSOTypeGitHub: + brand = config.SSOBrandGitHub + + default: + brand = config.SSOBrand(idp.ID) } } idps = append(idps, identityProvider{ diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index a833e5217..ebc7003d7 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -23,6 +23,7 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/auth/sso" clientutil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" @@ -67,6 +68,15 @@ func Setup( rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) + var ssoAuthenticator *sso.Authenticator + if cfg.Login.SSO.Enabled { + var err error + ssoAuthenticator, err = sso.NewAuthenticator(&cfg.Login.SSO) + if err != nil { + logrus.WithError(err).Fatal("failed to create SSO authenticator") + } + } + unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true, } @@ -565,20 +575,20 @@ func Setup( v3mux.Handle("/login/sso/callback", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - return SSOCallback(req, userAPI, cfg) + return SSOCallback(req, userAPI, ssoAuthenticator, cfg.Matrix.ServerName) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/login/sso/redirect", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - return SSORedirect(req, "", cfg) + return SSORedirect(req, "", ssoAuthenticator) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/login/sso/redirect/{idpID}", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { vars := mux.Vars(req) - return SSORedirect(req, vars["idpID"], cfg) + return SSORedirect(req, vars["idpID"], ssoAuthenticator) }), ).Methods(http.MethodGet, http.MethodOptions) diff --git a/clientapi/routing/sso.go b/clientapi/routing/sso.go index d1b224fbe..6e2e1d967 100644 --- a/clientapi/routing/sso.go +++ b/clientapi/routing/sso.go @@ -25,7 +25,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/sso" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -36,11 +35,11 @@ import ( func SSORedirect( req *http.Request, idpID string, - cfg *config.ClientAPI, + auth *sso.Authenticator, ) util.JSONResponse { - if !cfg.Login.SSO.Enabled { + if auth == nil { return util.JSONResponse{ - Code: http.StatusNotImplemented, + Code: http.StatusNotFound, JSON: jsonerror.NotFound("authentication method disabled"), } } @@ -60,28 +59,9 @@ func SSORedirect( } } - if idpID == "" { - // Check configuration if the client didn't provide an ID. - idpID = cfg.Login.SSO.DefaultProviderID - } - if idpID == "" && len(cfg.Login.SSO.Providers) > 0 { - // Fall back to the first provider. If there are no providers, getProvider("") will fail. - idpID = cfg.Login.SSO.Providers[0].ID - } - idpCfg, idpType := getProvider(cfg, idpID) - if idpType == nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("unknown identity provider"), - } - } - - idpReq := &sso.IdentityProviderRequest{ - System: idpCfg, - CallbackURL: req.URL.ResolveReference(&url.URL{Path: "../callback", RawQuery: url.Values{"provider": []string{idpID}}.Encode()}).String(), - DendriteNonce: formatNonce(redirectURL), - } - u, err := idpType.AuthorizationURL(req.Context(), idpReq) + callbackURL := req.URL.ResolveReference(&url.URL{Path: "../callback", RawQuery: url.Values{"provider": []string{idpID}}.Encode()}) + nonce := formatNonce(redirectURL) + u, err := auth.AuthorizationURL(req.Context(), idpID, callbackURL.String(), nonce) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -92,7 +72,7 @@ func SSORedirect( resp := util.RedirectResponse(u) resp.Headers["Set-Cookie"] = (&http.Cookie{ Name: "oidc_nonce", - Value: idpReq.DendriteNonce, + Value: nonce, Expires: time.Now().Add(10 * time.Minute), Secure: true, SameSite: http.SameSiteStrictMode, @@ -105,8 +85,16 @@ func SSORedirect( func SSOCallback( req *http.Request, userAPI userAPIForSSO, - cfg *config.ClientAPI, + auth *sso.Authenticator, + serverName gomatrixserverlib.ServerName, ) util.JSONResponse { + if auth == nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("authentication method disabled"), + } + } + ctx := req.Context() query := req.URL.Query() @@ -117,13 +105,6 @@ func SSOCallback( JSON: jsonerror.MissingArgument("provider parameter missing"), } } - idpCfg, idpType := getProvider(cfg, idpID) - if idpType == nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("unknown identity provider"), - } - } nonce, err := req.Cookie("oidc_nonce") if err != nil { @@ -140,19 +121,15 @@ func SSOCallback( } } - idpReq := &sso.IdentityProviderRequest{ - System: idpCfg, - CallbackURL: (&url.URL{ - Scheme: req.URL.Scheme, - Host: req.URL.Host, - Path: req.URL.Path, - RawQuery: url.Values{ - "provider": []string{idpID}, - }.Encode(), - }).String(), - DendriteNonce: nonce.Value, + callbackURL := &url.URL{ + Scheme: req.URL.Scheme, + Host: req.URL.Host, + Path: req.URL.Path, + RawQuery: url.Values{ + "provider": []string{idpID}, + }.Encode(), } - result, err := idpType.ProcessCallback(ctx, idpReq, query) + result, err := auth.ProcessCallback(ctx, idpID, callbackURL.String(), nonce.Value, query) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -165,7 +142,7 @@ func SSOCallback( return util.RedirectResponse(result.RedirectURL) } - localpart, err := verifySSOUserIdentifier(ctx, userAPI, result.Identifier, cfg.Matrix.ServerName) + localpart, err := verifySSOUserIdentifier(ctx, userAPI, result.Identifier, serverName) if err != nil { util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier).Error("failed to find user") return util.JSONResponse{ @@ -184,7 +161,7 @@ func SSOCallback( } } - token, err := createLoginToken(ctx, userAPI, userutil.MakeUserID(localpart, cfg.Matrix.ServerName)) + token, err := createLoginToken(ctx, userAPI, userutil.MakeUserID(localpart, serverName)) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("PerformLoginTokenCreation failed") return jsonerror.InternalServerError() @@ -210,23 +187,6 @@ type userAPIForSSO interface { QueryLocalpartForSSO(ctx context.Context, req *uapi.QueryLocalpartForSSORequest, res *uapi.QueryLocalpartForSSOResponse) error } -// getProvider looks up the given provider in the -// configuration. Returns nil if it wasn't found or was of unknown -// type. -func getProvider(cfg *config.ClientAPI, id string) (*config.IdentityProvider, sso.IdentityProvider) { - for _, idp := range cfg.Login.SSO.Providers { - if idp.ID == id { - switch sso.IdentityProviderType(id) { - case sso.TypeGitHub: - return &idp, sso.GitHubIdentityProvider - default: - return nil, nil - } - } - } - return nil, nil -} - // formatNonce creates a random nonce that also contains the URL. func formatNonce(redirectURL string) string { return util.RandomString(16) + "." + base64.RawURLEncoding.EncodeToString([]byte(redirectURL)) diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 4f529dd8e..2106c341c 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -161,7 +161,7 @@ type IdentityProvider struct { // Brand is a hint on how to display the IdP to the user. If this is empty, a default // based on the type is used. - Brand string `yaml:"brand"` + Brand SSOBrand `yaml:"brand"` // Icon is an MXC URI describing how to display the IdP to the user. Prefer using `brand`. Icon string `yaml:"icon"` @@ -169,18 +169,19 @@ type IdentityProvider struct { // Type describes how this provider is implemented. It must match "github". If this is // empty, the ID is used, which means there is a weak expectation that ID is also a // valid type, unless you have a complicated setup. - Type string `yaml:"type"` + Type IdentityProviderType `yaml:"type"` // OIDC contains settings for providers based on OpenID Connect (OAuth 2). OIDC struct { ClientID string `yaml:"client_id"` ClientSecret string `yaml:"client_secret"` + DiscoveryURL string `yaml:"discovery_url"` } `yaml:"oidc"` } func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) { checkNotEmpty(configErrs, "client_api.sso.providers.id", idp.ID) - if !checkIdentityProviderBrand(idp.ID) { + if !checkIdentityProviderBrand(SSOBrand(idp.ID)) { configErrs.Add(fmt.Sprintf("unrecognized ID config key %q: %s", "client_api.sso.providers", idp.ID)) } checkNotEmpty(configErrs, "client_api.sso.providers.name", idp.Name) @@ -192,11 +193,16 @@ func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) { } typ := idp.Type if idp.Type == "" { - typ = idp.ID + typ = IdentityProviderType(idp.ID) } switch typ { - case "github": + case SSOTypeOIDC: + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_id", idp.OIDC.ClientID) + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_secret", idp.OIDC.ClientSecret) + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.discovery_url", idp.OIDC.DiscoveryURL) + + case SSOTypeGitHub: checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_id", idp.OIDC.ClientID) checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_secret", idp.OIDC.ClientSecret) @@ -206,7 +212,7 @@ func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) { } // See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md. -func checkIdentityProviderBrand(s string) bool { +func checkIdentityProviderBrand(s SSOBrand) bool { switch s { case SSOBrandApple, SSOBrandFacebook, SSOBrandGitHub, SSOBrandGitLab, SSOBrandGoogle, SSOBrandTwitter: return true @@ -215,13 +221,22 @@ func checkIdentityProviderBrand(s string) bool { } } +type SSOBrand string + const ( - SSOBrandApple = "apple" - SSOBrandFacebook = "facebook" - SSOBrandGitHub = "github" - SSOBrandGitLab = "gitlab" - SSOBrandGoogle = "google" - SSOBrandTwitter = "twitter" + SSOBrandApple SSOBrand = "apple" + SSOBrandFacebook SSOBrand = "facebook" + SSOBrandGitHub SSOBrand = "github" + SSOBrandGitLab SSOBrand = "gitlab" + SSOBrandGoogle SSOBrand = "google" + SSOBrandTwitter SSOBrand = "twitter" +) + +type IdentityProviderType string + +const ( + SSOTypeOIDC IdentityProviderType = "oidc" + SSOTypeGitHub IdentityProviderType = "github" ) type TURN struct { diff --git a/userapi/api/api_sso.go b/userapi/api/api_sso.go index 56a3686b9..58204d8cb 100644 --- a/userapi/api/api_sso.go +++ b/userapi/api/api_sso.go @@ -47,6 +47,10 @@ type SSOIssuerNamespace string const ( UnknownNamespace SSOIssuerNamespace = "" + // SSOIDNamespace indicates the issuer is an ID key matching a + // Dendrite SSO provider configuration. + SSOIDNamespace SSOIssuerNamespace = "sso" + // OIDCNamespace indicates the issuer is a full URL, as defined in // https://openid.net/specs/openid-connect-core-1_0.html#Terminology. OIDCNamespace SSOIssuerNamespace = "oidc"