diff --git a/clientapi/auth/sso/github.go b/clientapi/auth/sso/github.go index 55f5417b6..70f1a95e0 100644 --- a/clientapi/auth/sso/github.go +++ b/clientapi/auth/sso/github.go @@ -15,23 +15,25 @@ package sso import ( + "net/http" + "github.com/matrix-org/dendrite/setup/config" ) -// GitHubIdentityProvider is a GitHub-flavored identity provider. -var GitHubIdentityProvider IdentityProvider = githubIdentityProvider{ - baseOIDCIdentityProvider: &baseOIDCIdentityProvider{ - AuthURL: mustParseURLTemplate("https://github.com/login/oauth/authorize?scope=user:email"), - AccessTokenURL: mustParseURLTemplate("https://github.com/login/oauth/access_token"), - UserInfoURL: mustParseURLTemplate("https://api.github.com/user"), - UserInfoAccept: "application/vnd.github.v3+json", - UserInfoEmailPath: "email", - UserInfoSuggestedUserIDPath: "login", - }, -} +func newGitHubIdentityProvider(cfg *config.IdentityProvider, hc *http.Client) identityProvider { + return &oauth2IdentityProvider{ + cfg: cfg, + hc: hc, -type githubIdentityProvider struct { - *baseOIDCIdentityProvider -} + authorizationURL: "https://github.com/login/oauth/authorize", + accessTokenURL: "https://github.com/login/oauth/access_token", + userInfoURL: "https://api.github.com/user", -func (githubIdentityProvider) DefaultBrand() string { return config.SSOBrandGitHub } + scopes: []string{"user:email"}, + responseMimeType: "application/vnd.github.v3+json", + subPath: "id", + emailPath: "email", + displayNamePath: "name", + suggestedUserIDPath: "login", + } +} diff --git a/clientapi/auth/sso/oauth2.go b/clientapi/auth/sso/oauth2.go new file mode 100644 index 000000000..287e187bc --- /dev/null +++ b/clientapi/auth/sso/oauth2.go @@ -0,0 +1,221 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/tidwall/gjson" +) + +type oauth2IdentityProvider struct { + cfg *config.IdentityProvider + hc *http.Client + + authorizationURL string + accessTokenURL string + userInfoURL string + + scopes []string + responseMimeType string + issPath string + subPath string + emailPath string + displayNamePath string + suggestedUserIDPath string +} + +func (p *oauth2IdentityProvider) AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) { + u, err := resolveURL(p.authorizationURL, url.Values{ + "client_id": []string{p.cfg.OIDC.ClientID}, + "response_type": []string{"code"}, + "redirect_uri": []string{callbackURL}, + "scope": []string{strings.Join(p.scopes, " ")}, + "state": []string{nonce}, + }) + if err != nil { + return "", err + } + return u.String(), nil +} + +func (p *oauth2IdentityProvider) ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) { + state := query.Get("state") + if state == "" { + return nil, jsonerror.MissingArgument("state parameter missing") + } + if state != nonce { + return nil, jsonerror.InvalidArgumentValue("state parameter not matching nonce") + } + + if error := query.Get("error"); error != "" { + if euri := query.Get("error_uri"); euri != "" { + return &CallbackResult{RedirectURL: euri}, nil + } + + desc := query.Get("error_description") + if desc == "" { + desc = error + } + switch error { + case "unauthorized_client", "access_denied": + return nil, jsonerror.Forbidden("SSO said no: " + desc) + default: + return nil, fmt.Errorf("SSO failed: %v", error) + } + } + + code := query.Get("code") + if code == "" { + return nil, jsonerror.MissingArgument("code parameter missing") + } + + at, err := p.getAccessToken(ctx, callbackURL, code) + if err != nil { + return nil, err + } + + subject, displayName, suggestedLocalpart, err := p.getUserInfo(ctx, at) + if err != nil { + return nil, err + } + + return &CallbackResult{ + Identifier: &UserIdentifier{ + Namespace: uapi.SSOIDNamespace, + Issuer: p.cfg.ID, + Subject: subject, + }, + DisplayName: displayName, + SuggestedUserID: suggestedLocalpart, + }, nil +} + +func (p *oauth2IdentityProvider) getAccessToken(ctx context.Context, callbackURL, code string) (string, error) { + body := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "redirect_uri": []string{callbackURL}, + "client_id": []string{p.cfg.OIDC.ClientID}, + "client_secret": []string{p.cfg.OIDC.ClientSecret}, + } + hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.accessTokenURL, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + hreq.Header.Set("Accept", p.responseMimeType) + + hresp, err := p.hc.Do(hreq) + if err != nil { + return "", err + } + defer hresp.Body.Close() + + var resp oauth2TokenResponse + if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil { + return "", err + } + + if resp.Error != "" { + desc := resp.ErrorDescription + if desc == "" { + desc = resp.Error + } + return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc) + } + + if strings.ToLower(resp.TokenType) != "bearer" { + return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType) + } + + return resp.AccessToken, nil +} + +type oauth2TokenResponse struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` +} + +func (p *oauth2IdentityProvider) getUserInfo(ctx context.Context, accessToken string) (subject, displayName, suggestedLocalpart string, _ error) { + hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, p.userInfoURL, nil) + if err != nil { + return "", "", "", err + } + hreq.Header.Set("Authorization", "token "+accessToken) + hreq.Header.Set("Accept", p.responseMimeType) + + hresp, err := p.hc.Do(hreq) + if err != nil { + return "", "", "", err + } + defer hresp.Body.Close() + + body, err := ioutil.ReadAll(hresp.Body) + if err != nil { + return "", "", "", err + } + + if res := gjson.GetBytes(body, p.subPath); !res.Exists() { + return "", "", "", fmt.Errorf("no %q in user info response body", p.subPath) + } else { + subject = res.String() + } + if subject == "" { + return "", "", "", fmt.Errorf("empty subject in user info") + } + + if p.suggestedUserIDPath != "" { + suggestedLocalpart = gjson.GetBytes(body, p.suggestedUserIDPath).String() + } + + if p.displayNamePath != "" { + displayName = gjson.GetBytes(body, p.displayNamePath).String() + } + + return +} + +func resolveURL(urlString string, defaultQuery url.Values) (*url.URL, error) { + u, err := url.Parse(urlString) + if err != nil { + return nil, err + } + + if defaultQuery != nil { + q := u.Query() + for k, vs := range defaultQuery { + if q.Get(k) == "" { + q[k] = vs + } + } + u.RawQuery = q.Encode() + } + + return u, nil +} diff --git a/clientapi/auth/sso/oidc.go b/clientapi/auth/sso/oidc.go new file mode 100644 index 000000000..ac9c56a57 --- /dev/null +++ b/clientapi/auth/sso/oidc.go @@ -0,0 +1,159 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "sync" + "time" + + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +type oidcIdentityProvider struct { + *oauth2IdentityProvider + + disc *oidcDiscovery + exp time.Time + mu sync.Mutex +} + +func newOIDCIdentityProvider(cfg *config.IdentityProvider, hc *http.Client) *oidcIdentityProvider { + return &oidcIdentityProvider{ + oauth2IdentityProvider: &oauth2IdentityProvider{ + cfg: cfg, + hc: hc, + + scopes: []string{"openid", "profile", "email"}, + responseMimeType: "application/json", + subPath: "sub", + emailPath: "email", + displayNamePath: "name", + suggestedUserIDPath: "preferred_username", + }, + } +} + +func (p *oidcIdentityProvider) AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) { + oauth2p, _, err := p.get(ctx) + if err != nil { + return "", err + } + return oauth2p.AuthorizationURL(ctx, callbackURL, nonce) +} + +func (p *oidcIdentityProvider) ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) { + oauth2p, disc, err := p.get(ctx) + if err != nil { + return nil, err + } + res, err := oauth2p.ProcessCallback(ctx, callbackURL, nonce, query) + if err != nil { + return nil, err + } + + // OIDC has the notion of issuer URL, which will be more + // stable than our configuration ID. + res.Identifier.Namespace = uapi.OIDCNamespace + res.Identifier.Issuer = disc.Issuer + + return res, nil +} + +func (p *oidcIdentityProvider) get(ctx context.Context) (*oauth2IdentityProvider, *oidcDiscovery, error) { + p.mu.Lock() + defer p.mu.Unlock() + + now := time.Now() + if p.exp.Before(now) || p.disc == nil { + disc, err := oidcDiscover(ctx, p.cfg.OIDC.DiscoveryURL) + if err != nil { + if p.disc != nil { + // Prefers returning a stale entry. + return p.oauth2IdentityProvider, p.disc, nil + } + return nil, nil, err + } + + p.exp = now.Add(24 * time.Hour) + newProvider := *p.oauth2IdentityProvider + newProvider.authorizationURL = disc.AuthorizationEndpoint + newProvider.accessTokenURL = disc.TokenEndpoint + newProvider.userInfoURL = disc.UserinfoEndpoint + + p.oauth2IdentityProvider = &newProvider + p.disc = disc + } + + return p.oauth2IdentityProvider, p.disc, nil +} + +type oidcDiscovery struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + ScopesSupported []string `json:"scopes_supported"` + ClaimsSupported []string `json:"claims_supported"` +} + +func oidcDiscover(ctx context.Context, url string) (*oidcDiscovery, error) { + hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + hreq.Header.Set("Accept", "application/jrd+json,application/json;q=0.9") + + hresp, err := http.DefaultClient.Do(hreq) + if err != nil { + return nil, err + } + defer hresp.Body.Close() + + var disc oidcDiscovery + if err := json.NewDecoder(hresp.Body).Decode(&disc); err != nil { + return nil, err + } + + if disc.ScopesSupported != nil { + if !stringSliceContains(disc.ScopesSupported, "openid") { + return nil, fmt.Errorf("scope 'openid' is missing in %q", url) + } + } + + if disc.ClaimsSupported != nil { + for _, claim := range []string{"iss", "sub"} { + if !stringSliceContains(disc.ClaimsSupported, claim) { + return nil, fmt.Errorf("claim %q is not supported in %q", claim, url) + } + } + } + + return &disc, nil +} + +func stringSliceContains(ss []string, s string) bool { + for _, s2 := range ss { + if s2 == s { + return true + } + } + return false +} diff --git a/clientapi/auth/sso/oidc_base.go b/clientapi/auth/sso/oidc_base.go deleted file mode 100644 index 4275e5733..000000000 --- a/clientapi/auth/sso/oidc_base.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sso - -import ( - "context" - "encoding/json" - "fmt" - "io/ioutil" - "mime" - "net/http" - "net/url" - "strings" - "text/template" - - "github.com/matrix-org/dendrite/clientapi/jsonerror" - uapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/tidwall/gjson" -) - -type baseOIDCIdentityProvider struct { - AuthURL *urlTemplate - AccessTokenURL *urlTemplate - UserInfoURL *urlTemplate - UserInfoAccept string - UserInfoEmailPath string - UserInfoSuggestedUserIDPath string -} - -func (p *baseOIDCIdentityProvider) AuthorizationURL(ctx context.Context, req *IdentityProviderRequest) (string, error) { - u, err := p.AuthURL.Execute(map[string]interface{}{ - "Config": req.System, - "State": req.DendriteNonce, - "RedirectURI": req.CallbackURL, - }, url.Values{ - "client_id": []string{req.System.OIDC.ClientID}, - "response_type": []string{"code"}, - "redirect_uri": []string{req.CallbackURL}, - "state": []string{req.DendriteNonce}, - }) - if err != nil { - return "", err - } - return u.String(), nil -} - -func (p *baseOIDCIdentityProvider) ProcessCallback(ctx context.Context, req *IdentityProviderRequest, values url.Values) (*CallbackResult, error) { - state := values.Get("state") - if state == "" { - return nil, jsonerror.MissingArgument("state parameter missing") - } - if state != req.DendriteNonce { - return nil, jsonerror.InvalidArgumentValue("state parameter not matching nonce") - } - - if error := values.Get("error"); error != "" { - if euri := values.Get("error_uri"); euri != "" { - return &CallbackResult{RedirectURL: euri}, nil - } - - desc := values.Get("error_description") - if desc == "" { - desc = error - } - switch error { - case "unauthorized_client", "access_denied": - return nil, jsonerror.Forbidden("SSO said no: " + desc) - default: - return nil, fmt.Errorf("SSO failed: %v", error) - } - } - - code := values.Get("code") - if code == "" { - return nil, jsonerror.MissingArgument("code parameter missing") - } - - oidcAccessToken, err := p.getOIDCAccessToken(ctx, req, code) - if err != nil { - return nil, err - } - - id, userID, err := p.getUserInfo(ctx, req, oidcAccessToken) - if err != nil { - return nil, err - } - - return &CallbackResult{Identifier: id, SuggestedUserID: userID}, nil -} - -func (p *baseOIDCIdentityProvider) getOIDCAccessToken(ctx context.Context, req *IdentityProviderRequest, code string) (string, error) { - u, err := p.AccessTokenURL.Execute(nil, nil) - if err != nil { - return "", err - } - - body := url.Values{ - "grant_type": []string{"authorization_code"}, - "code": []string{code}, - "redirect_uri": []string{req.CallbackURL}, - "client_id": []string{req.System.OIDC.ClientID}, - } - - hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), strings.NewReader(body.Encode())) - if err != nil { - return "", err - } - hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - hreq.Header.Set("Accept", "application/x-www-form-urlencoded") - - hresp, err := http.DefaultClient.Do(hreq) - if err != nil { - return "", err - } - defer hresp.Body.Close() - - ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type")) - if err != nil { - return "", err - } - if ctype != "application/json" { - return "", fmt.Errorf("expected URL encoded response, got content type %q", ctype) - } - - var resp struct { - TokenType string `json:"token_type"` - AccessToken string `json:"access_token"` - - Error string `json:"error"` - ErrorDescription string `json:"error_description"` - ErrorURI string `json:"error_uri"` - } - if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil { - return "", err - } - - if resp.Error != "" { - desc := resp.ErrorDescription - if desc == "" { - desc = resp.Error - } - return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc) - } - - if strings.ToLower(resp.TokenType) != "bearer" { - return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType) - } - - return resp.AccessToken, nil -} - -func (p *baseOIDCIdentityProvider) getUserInfo(ctx context.Context, req *IdentityProviderRequest, oidcAccessToken string) (ssoUser *UserIdentifier, suggestedUserID string, _ error) { - u, err := p.UserInfoURL.Execute(map[string]interface{}{ - "Config": req.System, - }, nil) - if err != nil { - return nil, "", err - } - - hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) - if err != nil { - return nil, "", err - } - hreq.Header.Set("Authorization", "token "+oidcAccessToken) - hreq.Header.Set("Accept", p.UserInfoAccept) - - hresp, err := http.DefaultClient.Do(hreq) - if err != nil { - return nil, "", err - } - defer hresp.Body.Close() - - ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type")) - if err != nil { - return nil, "", err - } - - if ctype != "application/json" { - return nil, "", fmt.Errorf("got unknown content type %q for user info", ctype) - } - - body, err := ioutil.ReadAll(hresp.Body) - if err != nil { - return nil, "", err - } - - issRes := gjson.GetBytes(body, "iss") - if !issRes.Exists() { - return nil, "", fmt.Errorf("no iss in user info response body") - } - iss := issRes.String() - - subRes := gjson.GetBytes(body, "sub") - if !subRes.Exists() { - return nil, "", fmt.Errorf("no sub in user info response body") - } - sub := subRes.String() - - if iss == "" { - return nil, "", fmt.Errorf("no iss in user info") - } - - if sub == "" { - return nil, "", fmt.Errorf("no sub in user info") - } - - // This is optional. - userIDRes := gjson.GetBytes(body, p.UserInfoSuggestedUserIDPath) - suggestedUserID = userIDRes.String() - - return &UserIdentifier{ - Namespace: uapi.OIDCNamespace, - Issuer: iss, - Subject: sub, - }, suggestedUserID, nil -} - -type urlTemplate struct { - base *template.Template -} - -func parseURLTemplate(s string) (*urlTemplate, error) { - t, err := template.New("").Parse(s) - if err != nil { - return nil, err - } - return &urlTemplate{base: t}, nil -} - -func mustParseURLTemplate(s string) *urlTemplate { - t, err := parseURLTemplate(s) - if err != nil { - panic(err) - } - return t -} - -func (t *urlTemplate) Execute(params interface{}, defaultQuery url.Values) (*url.URL, error) { - var sb strings.Builder - err := t.base.Execute(&sb, params) - if err != nil { - return nil, err - } - - u, err := url.Parse(sb.String()) - if err != nil { - return nil, err - } - - if defaultQuery != nil { - q := u.Query() - for k, vs := range defaultQuery { - if q.Get(k) == "" { - q[k] = vs - } - } - u.RawQuery = q.Encode() - } - return u, nil -} diff --git a/clientapi/auth/sso/sso.go b/clientapi/auth/sso/sso.go index 37e3842a8..5923d8491 100644 --- a/clientapi/auth/sso/sso.go +++ b/clientapi/auth/sso/sso.go @@ -16,46 +16,78 @@ package sso import ( "context" + "fmt" + "net/http" "net/url" + "time" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" ) -type IdentityProvider interface { - DefaultBrand() string - - AuthorizationURL(context.Context, *IdentityProviderRequest) (string, error) - ProcessCallback(context.Context, *IdentityProviderRequest, url.Values) (*CallbackResult, error) +type Authenticator struct { + providers map[string]identityProvider } -type IdentityProviderRequest struct { - System *config.IdentityProvider - CallbackURL string - DendriteNonce string +func NewAuthenticator(cfg *config.SSO) (*Authenticator, error) { + hc := &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + Proxy: http.ProxyFromEnvironment, + }, + } + + a := &Authenticator{ + providers: make(map[string]identityProvider, len(cfg.Providers)), + } + for _, pcfg := range cfg.Providers { + typ := pcfg.Type + if typ == "" { + typ = config.IdentityProviderType(pcfg.ID) + } + + switch typ { + case config.SSOTypeOIDC: + a.providers[pcfg.ID] = newOIDCIdentityProvider(&pcfg, hc) + case config.SSOTypeGitHub: + a.providers[pcfg.ID] = newGitHubIdentityProvider(&pcfg, hc) + default: + return nil, fmt.Errorf("unknown SSO provider type: %s", typ) + } + } + + return a, nil +} + +func (auth *Authenticator) AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) { + p := auth.providers[providerID] + if p == nil { + return "", fmt.Errorf("unknown identity provider: %s", providerID) + } + return p.AuthorizationURL(ctx, callbackURL, nonce) +} + +func (auth *Authenticator) ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*CallbackResult, error) { + p := auth.providers[providerID] + if p == nil { + return nil, fmt.Errorf("unknown identity provider: %s", providerID) + } + return p.ProcessCallback(ctx, callbackURL, nonce, query) +} + +type identityProvider interface { + AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) + ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) } type CallbackResult struct { RedirectURL string Identifier *UserIdentifier + DisplayName string SuggestedUserID string } -type IdentityProviderType string - -const ( - TypeGitHub IdentityProviderType = config.SSOBrandGitHub -) - -func GetIdentityProvider(t IdentityProviderType) IdentityProvider { - switch t { - case TypeGitHub: - return GitHubIdentityProvider - default: - return nil - } -} - type UserIdentifier struct { Namespace uapi.SSOIssuerNamespace Issuer, Subject string diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index ad4aca29c..f0600ceb1 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/sso" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -46,10 +45,10 @@ type stage struct { } type identityProvider struct { - ID string `json:"id"` - Name string `json:"name"` - Brand string `json:"brand,omitempty"` - Icon string `json:"icon,omitempty"` + ID string `json:"id"` + Name string `json:"name"` + Brand config.SSOBrand `json:"brand,omitempty"` + Icon string `json:"icon,omitempty"` } func passwordLogin() []stage { @@ -69,11 +68,14 @@ func ssoLogin(cfg *config.ClientAPI) []stage { if brand == "" { typ := idp.Type if typ == "" { - typ = idp.ID + typ = config.IdentityProviderType(idp.ID) } - idpType := sso.GetIdentityProvider(sso.IdentityProviderType(typ)) - if idpType != nil { - brand = idpType.DefaultBrand() + switch typ { + case config.SSOTypeGitHub: + brand = config.SSOBrandGitHub + + default: + brand = config.SSOBrand(idp.ID) } } idps = append(idps, identityProvider{ diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index a833e5217..ebc7003d7 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -23,6 +23,7 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/auth/sso" clientutil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" @@ -67,6 +68,15 @@ func Setup( rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) + var ssoAuthenticator *sso.Authenticator + if cfg.Login.SSO.Enabled { + var err error + ssoAuthenticator, err = sso.NewAuthenticator(&cfg.Login.SSO) + if err != nil { + logrus.WithError(err).Fatal("failed to create SSO authenticator") + } + } + unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true, } @@ -565,20 +575,20 @@ func Setup( v3mux.Handle("/login/sso/callback", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - return SSOCallback(req, userAPI, cfg) + return SSOCallback(req, userAPI, ssoAuthenticator, cfg.Matrix.ServerName) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/login/sso/redirect", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { - return SSORedirect(req, "", cfg) + return SSORedirect(req, "", ssoAuthenticator) }), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/login/sso/redirect/{idpID}", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { vars := mux.Vars(req) - return SSORedirect(req, vars["idpID"], cfg) + return SSORedirect(req, vars["idpID"], ssoAuthenticator) }), ).Methods(http.MethodGet, http.MethodOptions) diff --git a/clientapi/routing/sso.go b/clientapi/routing/sso.go index d1b224fbe..6e2e1d967 100644 --- a/clientapi/routing/sso.go +++ b/clientapi/routing/sso.go @@ -25,7 +25,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/sso" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -36,11 +35,11 @@ import ( func SSORedirect( req *http.Request, idpID string, - cfg *config.ClientAPI, + auth *sso.Authenticator, ) util.JSONResponse { - if !cfg.Login.SSO.Enabled { + if auth == nil { return util.JSONResponse{ - Code: http.StatusNotImplemented, + Code: http.StatusNotFound, JSON: jsonerror.NotFound("authentication method disabled"), } } @@ -60,28 +59,9 @@ func SSORedirect( } } - if idpID == "" { - // Check configuration if the client didn't provide an ID. - idpID = cfg.Login.SSO.DefaultProviderID - } - if idpID == "" && len(cfg.Login.SSO.Providers) > 0 { - // Fall back to the first provider. If there are no providers, getProvider("") will fail. - idpID = cfg.Login.SSO.Providers[0].ID - } - idpCfg, idpType := getProvider(cfg, idpID) - if idpType == nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("unknown identity provider"), - } - } - - idpReq := &sso.IdentityProviderRequest{ - System: idpCfg, - CallbackURL: req.URL.ResolveReference(&url.URL{Path: "../callback", RawQuery: url.Values{"provider": []string{idpID}}.Encode()}).String(), - DendriteNonce: formatNonce(redirectURL), - } - u, err := idpType.AuthorizationURL(req.Context(), idpReq) + callbackURL := req.URL.ResolveReference(&url.URL{Path: "../callback", RawQuery: url.Values{"provider": []string{idpID}}.Encode()}) + nonce := formatNonce(redirectURL) + u, err := auth.AuthorizationURL(req.Context(), idpID, callbackURL.String(), nonce) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -92,7 +72,7 @@ func SSORedirect( resp := util.RedirectResponse(u) resp.Headers["Set-Cookie"] = (&http.Cookie{ Name: "oidc_nonce", - Value: idpReq.DendriteNonce, + Value: nonce, Expires: time.Now().Add(10 * time.Minute), Secure: true, SameSite: http.SameSiteStrictMode, @@ -105,8 +85,16 @@ func SSORedirect( func SSOCallback( req *http.Request, userAPI userAPIForSSO, - cfg *config.ClientAPI, + auth *sso.Authenticator, + serverName gomatrixserverlib.ServerName, ) util.JSONResponse { + if auth == nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("authentication method disabled"), + } + } + ctx := req.Context() query := req.URL.Query() @@ -117,13 +105,6 @@ func SSOCallback( JSON: jsonerror.MissingArgument("provider parameter missing"), } } - idpCfg, idpType := getProvider(cfg, idpID) - if idpType == nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue("unknown identity provider"), - } - } nonce, err := req.Cookie("oidc_nonce") if err != nil { @@ -140,19 +121,15 @@ func SSOCallback( } } - idpReq := &sso.IdentityProviderRequest{ - System: idpCfg, - CallbackURL: (&url.URL{ - Scheme: req.URL.Scheme, - Host: req.URL.Host, - Path: req.URL.Path, - RawQuery: url.Values{ - "provider": []string{idpID}, - }.Encode(), - }).String(), - DendriteNonce: nonce.Value, + callbackURL := &url.URL{ + Scheme: req.URL.Scheme, + Host: req.URL.Host, + Path: req.URL.Path, + RawQuery: url.Values{ + "provider": []string{idpID}, + }.Encode(), } - result, err := idpType.ProcessCallback(ctx, idpReq, query) + result, err := auth.ProcessCallback(ctx, idpID, callbackURL.String(), nonce.Value, query) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -165,7 +142,7 @@ func SSOCallback( return util.RedirectResponse(result.RedirectURL) } - localpart, err := verifySSOUserIdentifier(ctx, userAPI, result.Identifier, cfg.Matrix.ServerName) + localpart, err := verifySSOUserIdentifier(ctx, userAPI, result.Identifier, serverName) if err != nil { util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier).Error("failed to find user") return util.JSONResponse{ @@ -184,7 +161,7 @@ func SSOCallback( } } - token, err := createLoginToken(ctx, userAPI, userutil.MakeUserID(localpart, cfg.Matrix.ServerName)) + token, err := createLoginToken(ctx, userAPI, userutil.MakeUserID(localpart, serverName)) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("PerformLoginTokenCreation failed") return jsonerror.InternalServerError() @@ -210,23 +187,6 @@ type userAPIForSSO interface { QueryLocalpartForSSO(ctx context.Context, req *uapi.QueryLocalpartForSSORequest, res *uapi.QueryLocalpartForSSOResponse) error } -// getProvider looks up the given provider in the -// configuration. Returns nil if it wasn't found or was of unknown -// type. -func getProvider(cfg *config.ClientAPI, id string) (*config.IdentityProvider, sso.IdentityProvider) { - for _, idp := range cfg.Login.SSO.Providers { - if idp.ID == id { - switch sso.IdentityProviderType(id) { - case sso.TypeGitHub: - return &idp, sso.GitHubIdentityProvider - default: - return nil, nil - } - } - } - return nil, nil -} - // formatNonce creates a random nonce that also contains the URL. func formatNonce(redirectURL string) string { return util.RandomString(16) + "." + base64.RawURLEncoding.EncodeToString([]byte(redirectURL)) diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 4f529dd8e..2106c341c 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -161,7 +161,7 @@ type IdentityProvider struct { // Brand is a hint on how to display the IdP to the user. If this is empty, a default // based on the type is used. - Brand string `yaml:"brand"` + Brand SSOBrand `yaml:"brand"` // Icon is an MXC URI describing how to display the IdP to the user. Prefer using `brand`. Icon string `yaml:"icon"` @@ -169,18 +169,19 @@ type IdentityProvider struct { // Type describes how this provider is implemented. It must match "github". If this is // empty, the ID is used, which means there is a weak expectation that ID is also a // valid type, unless you have a complicated setup. - Type string `yaml:"type"` + Type IdentityProviderType `yaml:"type"` // OIDC contains settings for providers based on OpenID Connect (OAuth 2). OIDC struct { ClientID string `yaml:"client_id"` ClientSecret string `yaml:"client_secret"` + DiscoveryURL string `yaml:"discovery_url"` } `yaml:"oidc"` } func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) { checkNotEmpty(configErrs, "client_api.sso.providers.id", idp.ID) - if !checkIdentityProviderBrand(idp.ID) { + if !checkIdentityProviderBrand(SSOBrand(idp.ID)) { configErrs.Add(fmt.Sprintf("unrecognized ID config key %q: %s", "client_api.sso.providers", idp.ID)) } checkNotEmpty(configErrs, "client_api.sso.providers.name", idp.Name) @@ -192,11 +193,16 @@ func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) { } typ := idp.Type if idp.Type == "" { - typ = idp.ID + typ = IdentityProviderType(idp.ID) } switch typ { - case "github": + case SSOTypeOIDC: + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_id", idp.OIDC.ClientID) + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_secret", idp.OIDC.ClientSecret) + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.discovery_url", idp.OIDC.DiscoveryURL) + + case SSOTypeGitHub: checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_id", idp.OIDC.ClientID) checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_secret", idp.OIDC.ClientSecret) @@ -206,7 +212,7 @@ func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) { } // See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md. -func checkIdentityProviderBrand(s string) bool { +func checkIdentityProviderBrand(s SSOBrand) bool { switch s { case SSOBrandApple, SSOBrandFacebook, SSOBrandGitHub, SSOBrandGitLab, SSOBrandGoogle, SSOBrandTwitter: return true @@ -215,13 +221,22 @@ func checkIdentityProviderBrand(s string) bool { } } +type SSOBrand string + const ( - SSOBrandApple = "apple" - SSOBrandFacebook = "facebook" - SSOBrandGitHub = "github" - SSOBrandGitLab = "gitlab" - SSOBrandGoogle = "google" - SSOBrandTwitter = "twitter" + SSOBrandApple SSOBrand = "apple" + SSOBrandFacebook SSOBrand = "facebook" + SSOBrandGitHub SSOBrand = "github" + SSOBrandGitLab SSOBrand = "gitlab" + SSOBrandGoogle SSOBrand = "google" + SSOBrandTwitter SSOBrand = "twitter" +) + +type IdentityProviderType string + +const ( + SSOTypeOIDC IdentityProviderType = "oidc" + SSOTypeGitHub IdentityProviderType = "github" ) type TURN struct { diff --git a/userapi/api/api_sso.go b/userapi/api/api_sso.go index 56a3686b9..58204d8cb 100644 --- a/userapi/api/api_sso.go +++ b/userapi/api/api_sso.go @@ -47,6 +47,10 @@ type SSOIssuerNamespace string const ( UnknownNamespace SSOIssuerNamespace = "" + // SSOIDNamespace indicates the issuer is an ID key matching a + // Dendrite SSO provider configuration. + SSOIDNamespace SSOIssuerNamespace = "sso" + // OIDCNamespace indicates the issuer is a full URL, as defined in // https://openid.net/specs/openid-connect-core-1_0.html#Terminology. OIDCNamespace SSOIssuerNamespace = "oidc"