mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-06 13:43:09 -06:00
263 lines
6.8 KiB
Go
263 lines
6.8 KiB
Go
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package sso
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"mime"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"text/template"
|
|
|
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
type baseOIDCIdentityProvider struct {
|
|
AuthURL *urlTemplate
|
|
AccessTokenURL *urlTemplate
|
|
UserInfoURL *urlTemplate
|
|
UserInfoAccept string
|
|
UserInfoEmailPath string
|
|
UserInfoSuggestedUserIDPath string
|
|
}
|
|
|
|
func (p *baseOIDCIdentityProvider) AuthorizationURL(ctx context.Context, req *IdentityProviderRequest) (string, error) {
|
|
u, err := p.AuthURL.Execute(map[string]interface{}{
|
|
"Config": req.System,
|
|
"State": req.DendriteNonce,
|
|
"RedirectURI": req.CallbackURL,
|
|
}, url.Values{
|
|
"client_id": []string{req.System.OIDC.ClientID},
|
|
"response_type": []string{"code"},
|
|
"redirect_uri": []string{req.CallbackURL},
|
|
"state": []string{req.DendriteNonce},
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return u.String(), nil
|
|
}
|
|
|
|
func (p *baseOIDCIdentityProvider) ProcessCallback(ctx context.Context, req *IdentityProviderRequest, values url.Values) (*CallbackResult, error) {
|
|
state := values.Get("state")
|
|
if state == "" {
|
|
return nil, jsonerror.MissingArgument("state parameter missing")
|
|
}
|
|
if state != req.DendriteNonce {
|
|
return nil, jsonerror.InvalidArgumentValue("state parameter not matching nonce")
|
|
}
|
|
|
|
if error := values.Get("error"); error != "" {
|
|
if euri := values.Get("error_uri"); euri != "" {
|
|
return &CallbackResult{RedirectURL: euri}, nil
|
|
}
|
|
|
|
desc := values.Get("error_description")
|
|
if desc == "" {
|
|
desc = error
|
|
}
|
|
switch error {
|
|
case "unauthorized_client", "access_denied":
|
|
return nil, jsonerror.Forbidden("SSO said no: " + desc)
|
|
default:
|
|
return nil, fmt.Errorf("SSO failed: %v", error)
|
|
}
|
|
}
|
|
|
|
code := values.Get("code")
|
|
if code == "" {
|
|
return nil, jsonerror.MissingArgument("code parameter missing")
|
|
}
|
|
|
|
oidcAccessToken, err := p.getOIDCAccessToken(ctx, req, code)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
id, userID, err := p.getUserInfo(ctx, req, oidcAccessToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &CallbackResult{Identifier: id, SuggestedUserID: userID}, nil
|
|
}
|
|
|
|
func (p *baseOIDCIdentityProvider) getOIDCAccessToken(ctx context.Context, req *IdentityProviderRequest, code string) (string, error) {
|
|
u, err := p.AccessTokenURL.Execute(nil, nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
body := url.Values{
|
|
"grant_type": []string{"authorization_code"},
|
|
"code": []string{code},
|
|
"redirect_uri": []string{req.CallbackURL},
|
|
"client_id": []string{req.System.OIDC.ClientID},
|
|
}
|
|
|
|
hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), strings.NewReader(body.Encode()))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
hreq.Header.Set("Accept", "application/x-www-form-urlencoded")
|
|
|
|
hresp, err := http.DefaultClient.Do(hreq)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer hresp.Body.Close()
|
|
|
|
ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type"))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if ctype != "application/json" {
|
|
return "", fmt.Errorf("expected URL encoded response, got content type %q", ctype)
|
|
}
|
|
|
|
var resp struct {
|
|
TokenType string `json:"token_type"`
|
|
AccessToken string `json:"access_token"`
|
|
|
|
Error string `json:"error"`
|
|
ErrorDescription string `json:"error_description"`
|
|
ErrorURI string `json:"error_uri"`
|
|
}
|
|
if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if resp.Error != "" {
|
|
desc := resp.ErrorDescription
|
|
if desc == "" {
|
|
desc = resp.Error
|
|
}
|
|
return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc)
|
|
}
|
|
|
|
if strings.ToLower(resp.TokenType) != "bearer" {
|
|
return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType)
|
|
}
|
|
|
|
return resp.AccessToken, nil
|
|
}
|
|
|
|
func (p *baseOIDCIdentityProvider) getUserInfo(ctx context.Context, req *IdentityProviderRequest, oidcAccessToken string) (*userutil.ThirdPartyIdentifier, string, error) {
|
|
u, err := p.UserInfoURL.Execute(map[string]interface{}{
|
|
"Config": req.System,
|
|
}, nil)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
hreq.Header.Set("Authorization", "token "+oidcAccessToken)
|
|
hreq.Header.Set("Accept", p.UserInfoAccept)
|
|
|
|
hresp, err := http.DefaultClient.Do(hreq)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
defer hresp.Body.Close()
|
|
|
|
ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type"))
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
var email string
|
|
var suggestedUserID string
|
|
switch ctype {
|
|
case "application/json":
|
|
body, err := ioutil.ReadAll(hresp.Body)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
emailRes := gjson.GetBytes(body, p.UserInfoEmailPath)
|
|
if !emailRes.Exists() {
|
|
return nil, "", fmt.Errorf("no email in user info response body")
|
|
}
|
|
email = emailRes.String()
|
|
|
|
// This is optional.
|
|
userIDRes := gjson.GetBytes(body, p.UserInfoSuggestedUserIDPath)
|
|
suggestedUserID = userIDRes.String()
|
|
|
|
default:
|
|
return nil, "", fmt.Errorf("got unknown content type %q for user info", ctype)
|
|
}
|
|
|
|
if email == "" {
|
|
return nil, "", fmt.Errorf("no email address in user info")
|
|
}
|
|
|
|
return &userutil.ThirdPartyIdentifier{Medium: "email", Address: email}, suggestedUserID, nil
|
|
}
|
|
|
|
// urlTemplate is a URL whose text is produced by executing a text/template.
type urlTemplate struct {
	base *template.Template
}

// parseURLTemplate parses s as a text/template and wraps it in a
// urlTemplate. It returns the template parse error, if any.
func parseURLTemplate(s string) (*urlTemplate, error) {
	parsed, err := template.New("").Parse(s)
	if err != nil {
		return nil, err
	}
	return &urlTemplate{base: parsed}, nil
}

// mustParseURLTemplate is like parseURLTemplate but panics if s cannot be
// parsed.
func mustParseURLTemplate(s string) *urlTemplate {
	ut, err := parseURLTemplate(s)
	if err == nil {
		return ut
	}
	panic(err)
}

// Execute renders the template with params and parses the output as a URL.
//
// When defaultQuery is non-nil, each of its keys for which the rendered URL
// has no non-empty first value is set to the default values, and the query
// string is then re-encoded. When defaultQuery is nil the rendered query
// string is left exactly as the template produced it.
func (t *urlTemplate) Execute(params interface{}, defaultQuery url.Values) (*url.URL, error) {
	var rendered strings.Builder
	if err := t.base.Execute(&rendered, params); err != nil {
		return nil, err
	}

	result, err := url.Parse(rendered.String())
	if err != nil {
		return nil, err
	}
	if defaultQuery == nil {
		return result, nil
	}

	query := result.Query()
	for key, values := range defaultQuery {
		if query.Get(key) != "" {
			// The template already supplied a value; keep it.
			continue
		}
		query[key] = values
	}
	result.RawQuery = query.Encode()
	return result, nil
}
|