mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-06 13:43:09 -06:00
Support for m.login.sso.
This is forked from @anandv96's #1374. Closes #1297.
This commit is contained in:
parent
1d6501ae30
commit
43989aa017
|
|
@ -43,6 +43,7 @@ type AccountDatabase interface {
|
||||||
// Look up the account matching the given localpart.
|
// Look up the account matching the given localpart.
|
||||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||||
GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error)
|
GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error)
|
||||||
|
GetLocalpartForThreePID(ctx context.Context, address, medium string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyUserFromRequest authenticates the HTTP request,
|
// VerifyUserFromRequest authenticates the HTTP request,
|
||||||
|
|
|
||||||
|
|
@ -10,5 +10,6 @@ const (
|
||||||
LoginTypeSharedSecret = "org.matrix.login.shared_secret"
|
LoginTypeSharedSecret = "org.matrix.login.shared_secret"
|
||||||
LoginTypeRecaptcha = "m.login.recaptcha"
|
LoginTypeRecaptcha = "m.login.recaptcha"
|
||||||
LoginTypeApplicationService = "m.login.application_service"
|
LoginTypeApplicationService = "m.login.application_service"
|
||||||
|
LoginTypeSSO = "m.login.sso"
|
||||||
LoginTypeToken = "m.login.token"
|
LoginTypeToken = "m.login.token"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
37
clientapi/auth/sso/github.go
Normal file
37
clientapi/auth/sso/github.go
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GitHubIdentityProvider is a GitHub-flavored identity provider.
|
||||||
|
var GitHubIdentityProvider IdentityProvider = githubIdentityProvider{
|
||||||
|
baseOIDCIdentityProvider: &baseOIDCIdentityProvider{
|
||||||
|
AuthURL: mustParseURLTemplate("https://github.com/login/oauth/authorize?scope=user:email"),
|
||||||
|
AccessTokenURL: mustParseURLTemplate("https://github.com/login/oauth/access_token"),
|
||||||
|
UserInfoURL: mustParseURLTemplate("https://api.github.com/user"),
|
||||||
|
UserInfoAccept: "application/vnd.github.v3+json",
|
||||||
|
UserInfoEmailPath: "email",
|
||||||
|
UserInfoSuggestedUserIDPath: "login",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type githubIdentityProvider struct {
|
||||||
|
*baseOIDCIdentityProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func (githubIdentityProvider) DefaultBrand() string { return config.SSOBrandGitHub }
|
||||||
262
clientapi/auth/sso/oidc_base.go
Normal file
262
clientapi/auth/sso/oidc_base.go
Normal file
|
|
@ -0,0 +1,262 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"mime"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
type baseOIDCIdentityProvider struct {
|
||||||
|
AuthURL *urlTemplate
|
||||||
|
AccessTokenURL *urlTemplate
|
||||||
|
UserInfoURL *urlTemplate
|
||||||
|
UserInfoAccept string
|
||||||
|
UserInfoEmailPath string
|
||||||
|
UserInfoSuggestedUserIDPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *baseOIDCIdentityProvider) AuthorizationURL(ctx context.Context, req *IdentityProviderRequest) (string, error) {
|
||||||
|
u, err := p.AuthURL.Execute(map[string]interface{}{
|
||||||
|
"Config": req.System,
|
||||||
|
"State": req.DendriteNonce,
|
||||||
|
"RedirectURI": req.CallbackURL,
|
||||||
|
}, url.Values{
|
||||||
|
"client_id": []string{req.System.OIDC.ClientID},
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"redirect_uri": []string{req.CallbackURL},
|
||||||
|
"state": []string{req.DendriteNonce},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return u.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *baseOIDCIdentityProvider) ProcessCallback(ctx context.Context, req *IdentityProviderRequest, values url.Values) (*CallbackResult, error) {
|
||||||
|
state := values.Get("state")
|
||||||
|
if state == "" {
|
||||||
|
return nil, jsonerror.MissingArgument("state parameter missing")
|
||||||
|
}
|
||||||
|
if state != req.DendriteNonce {
|
||||||
|
return nil, jsonerror.InvalidArgumentValue("state parameter not matching nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
if error := values.Get("error"); error != "" {
|
||||||
|
if euri := values.Get("error_uri"); euri != "" {
|
||||||
|
return &CallbackResult{RedirectURL: euri}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
desc := values.Get("error_description")
|
||||||
|
if desc == "" {
|
||||||
|
desc = error
|
||||||
|
}
|
||||||
|
switch error {
|
||||||
|
case "unauthorized_client", "access_denied":
|
||||||
|
return nil, jsonerror.Forbidden("SSO said no: " + desc)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("SSO failed: %v", error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
code := values.Get("code")
|
||||||
|
if code == "" {
|
||||||
|
return nil, jsonerror.MissingArgument("code parameter missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcAccessToken, err := p.getOIDCAccessToken(ctx, req, code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, userID, err := p.getUserInfo(ctx, req, oidcAccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CallbackResult{Identifier: id, SuggestedUserID: userID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *baseOIDCIdentityProvider) getOIDCAccessToken(ctx context.Context, req *IdentityProviderRequest, code string) (string, error) {
|
||||||
|
u, err := p.AccessTokenURL.Execute(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
body := url.Values{
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"code": []string{code},
|
||||||
|
"redirect_uri": []string{req.CallbackURL},
|
||||||
|
"client_id": []string{req.System.OIDC.ClientID},
|
||||||
|
}
|
||||||
|
|
||||||
|
hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), strings.NewReader(body.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
hreq.Header.Set("Accept", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
hresp, err := http.DefaultClient.Do(hreq)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer hresp.Body.Close()
|
||||||
|
|
||||||
|
ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type"))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if ctype != "application/json" {
|
||||||
|
return "", fmt.Errorf("expected URL encoded response, got content type %q", ctype)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
|
||||||
|
Error string `json:"error"`
|
||||||
|
ErrorDescription string `json:"error_description"`
|
||||||
|
ErrorURI string `json:"error_uri"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Error != "" {
|
||||||
|
desc := resp.ErrorDescription
|
||||||
|
if desc == "" {
|
||||||
|
desc = resp.Error
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.ToLower(resp.TokenType) != "bearer" {
|
||||||
|
return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.AccessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *baseOIDCIdentityProvider) getUserInfo(ctx context.Context, req *IdentityProviderRequest, oidcAccessToken string) (*userutil.ThirdPartyIdentifier, string, error) {
|
||||||
|
u, err := p.UserInfoURL.Execute(map[string]interface{}{
|
||||||
|
"Config": req.System,
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
hreq.Header.Set("Authorization", "token "+oidcAccessToken)
|
||||||
|
hreq.Header.Set("Accept", p.UserInfoAccept)
|
||||||
|
|
||||||
|
hresp, err := http.DefaultClient.Do(hreq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
defer hresp.Body.Close()
|
||||||
|
|
||||||
|
ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var email string
|
||||||
|
var suggestedUserID string
|
||||||
|
switch ctype {
|
||||||
|
case "application/json":
|
||||||
|
body, err := ioutil.ReadAll(hresp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
emailRes := gjson.GetBytes(body, p.UserInfoEmailPath)
|
||||||
|
if !emailRes.Exists() {
|
||||||
|
return nil, "", fmt.Errorf("no email in user info response body")
|
||||||
|
}
|
||||||
|
email = emailRes.String()
|
||||||
|
|
||||||
|
// This is optional.
|
||||||
|
userIDRes := gjson.GetBytes(body, p.UserInfoSuggestedUserIDPath)
|
||||||
|
suggestedUserID = userIDRes.String()
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, "", fmt.Errorf("got unknown content type %q for user info", ctype)
|
||||||
|
}
|
||||||
|
|
||||||
|
if email == "" {
|
||||||
|
return nil, "", fmt.Errorf("no email address in user info")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &userutil.ThirdPartyIdentifier{Medium: "email", Address: email}, suggestedUserID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type urlTemplate struct {
|
||||||
|
base *template.Template
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseURLTemplate(s string) (*urlTemplate, error) {
|
||||||
|
t, err := template.New("").Parse(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &urlTemplate{base: t}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseURLTemplate(s string) *urlTemplate {
|
||||||
|
t, err := parseURLTemplate(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *urlTemplate) Execute(params interface{}, defaultQuery url.Values) (*url.URL, error) {
|
||||||
|
var sb strings.Builder
|
||||||
|
err := t.base.Execute(&sb, params)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(sb.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if defaultQuery != nil {
|
||||||
|
q := u.Query()
|
||||||
|
for k, vs := range defaultQuery {
|
||||||
|
if q.Get(k) == "" {
|
||||||
|
q[k] = vs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
57
clientapi/auth/sso/sso.go
Normal file
57
clientapi/auth/sso/sso.go
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IdentityProvider interface {
|
||||||
|
DefaultBrand() string
|
||||||
|
|
||||||
|
AuthorizationURL(context.Context, *IdentityProviderRequest) (string, error)
|
||||||
|
ProcessCallback(context.Context, *IdentityProviderRequest, url.Values) (*CallbackResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type IdentityProviderRequest struct {
|
||||||
|
System *config.IdentityProvider
|
||||||
|
CallbackURL string
|
||||||
|
DendriteNonce string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CallbackResult struct {
|
||||||
|
RedirectURL string
|
||||||
|
Identifier *userutil.ThirdPartyIdentifier
|
||||||
|
SuggestedUserID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type IdentityProviderType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeGitHub IdentityProviderType = config.SSOBrandGitHub
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetIdentityProvider(t IdentityProviderType) IdentityProvider {
|
||||||
|
switch t {
|
||||||
|
case TypeGitHub:
|
||||||
|
return GitHubIdentityProvider
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -19,6 +19,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/sso"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
|
@ -35,20 +37,54 @@ type loginResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type flows struct {
|
type flows struct {
|
||||||
Flows []flow `json:"flows"`
|
Flows []stage `json:"flows"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type flow struct {
|
type stage struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
IdentityProviders []identityProvider `json:"identity_providers,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func passwordLogin() flows {
|
type identityProvider struct {
|
||||||
f := flows{}
|
ID string `json:"id"`
|
||||||
s := flow{
|
Name string `json:"name"`
|
||||||
Type: "m.login.password",
|
Brand string `json:"brand,omitempty"`
|
||||||
|
Icon string `json:"icon,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func passwordLogin() []stage {
|
||||||
|
return []stage{
|
||||||
|
{Type: authtypes.LoginTypePassword},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ssoLogin(cfg *config.ClientAPI) []stage {
|
||||||
|
var idps []identityProvider
|
||||||
|
for _, idp := range cfg.Login.SSO.Providers {
|
||||||
|
brand := idp.Brand
|
||||||
|
if brand == "" {
|
||||||
|
typ := idp.Type
|
||||||
|
if typ == "" {
|
||||||
|
typ = idp.ID
|
||||||
|
}
|
||||||
|
idpType := sso.GetIdentityProvider(sso.IdentityProviderType(typ))
|
||||||
|
if idpType != nil {
|
||||||
|
brand = idpType.DefaultBrand()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
idps = append(idps, identityProvider{
|
||||||
|
ID: idp.ID,
|
||||||
|
Name: idp.Name,
|
||||||
|
Brand: brand,
|
||||||
|
Icon: idp.Icon,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return []stage{
|
||||||
|
{
|
||||||
|
Type: authtypes.LoginTypeSSO,
|
||||||
|
IdentityProviders: idps,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
f.Flows = append(f.Flows, s)
|
|
||||||
return f
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login implements GET and POST /login
|
// Login implements GET and POST /login
|
||||||
|
|
@ -57,10 +93,13 @@ func Login(
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
// TODO: support other forms of login other than password, depending on config options
|
allFlows := passwordLogin()
|
||||||
|
if cfg.Login.SSO.Enabled {
|
||||||
|
allFlows = append(allFlows, ssoLogin(cfg)...)
|
||||||
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: passwordLogin(),
|
JSON: flows{Flows: allFlows},
|
||||||
}
|
}
|
||||||
} else if req.Method == http.MethodPost {
|
} else if req.Method == http.MethodPost {
|
||||||
login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, userAPI, cfg)
|
login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, userAPI, cfg)
|
||||||
|
|
@ -72,6 +111,7 @@ func Login(
|
||||||
cleanup(req.Context(), &authErr2)
|
cleanup(req.Context(), &authErr2)
|
||||||
return authErr2
|
return authErr2
|
||||||
}
|
}
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusMethodNotAllowed,
|
Code: http.StatusMethodNotAllowed,
|
||||||
JSON: jsonerror.NotFound("Bad method"),
|
JSON: jsonerror.NotFound("Bad method"),
|
||||||
|
|
|
||||||
|
|
@ -563,6 +563,25 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/login/sso/callback",
|
||||||
|
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
||||||
|
return SSOCallback(req, userAPI, cfg)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/login/sso/redirect",
|
||||||
|
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
||||||
|
return SSORedirect(req, "", cfg)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v3mux.Handle("/login/sso/redirect/{idpID}",
|
||||||
|
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
||||||
|
vars := mux.Vars(req)
|
||||||
|
return SSORedirect(req, vars["idpID"], cfg)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/auth/{authType}/fallback/web",
|
v3mux.Handle("/auth/{authType}/fallback/web",
|
||||||
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
|
|
|
||||||
283
clientapi/routing/sso.go
Normal file
283
clientapi/routing/sso.go
Normal file
|
|
@ -0,0 +1,283 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/sso"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSORedirect implements /login/sso/redirect
|
||||||
|
// https://spec.matrix.org/v1.2/client-server-api/#redirecting-to-the-authentication-server
|
||||||
|
func SSORedirect(
|
||||||
|
req *http.Request,
|
||||||
|
idpID string,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
if !cfg.Login.SSO.Enabled {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotImplemented,
|
||||||
|
JSON: jsonerror.NotFound("authentication method disabled"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL := req.URL.Query().Get("redirectUrl")
|
||||||
|
if redirectURL == "" {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.MissingArgument("redirectUrl parameter missing"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err := url.Parse(redirectURL)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if idpID == "" {
|
||||||
|
// Check configuration if the client didn't provide an ID.
|
||||||
|
idpID = cfg.Login.SSO.DefaultProviderID
|
||||||
|
}
|
||||||
|
if idpID == "" && len(cfg.Login.SSO.Providers) > 0 {
|
||||||
|
// Fall back to the first provider. If there are no providers, getProvider("") will fail.
|
||||||
|
idpID = cfg.Login.SSO.Providers[0].ID
|
||||||
|
}
|
||||||
|
idpCfg, idpType := getProvider(cfg, idpID)
|
||||||
|
if idpType == nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("unknown identity provider"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
idpReq := &sso.IdentityProviderRequest{
|
||||||
|
System: idpCfg,
|
||||||
|
CallbackURL: req.URL.ResolveReference(&url.URL{Path: "../callback", RawQuery: url.Values{"provider": []string{idpID}}.Encode()}).String(),
|
||||||
|
DendriteNonce: formatNonce(redirectURL),
|
||||||
|
}
|
||||||
|
u, err := idpType.AuthorizationURL(req.Context(), idpReq)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := util.RedirectResponse(u)
|
||||||
|
resp.Headers["Set-Cookie"] = (&http.Cookie{
|
||||||
|
Name: "oidc_nonce",
|
||||||
|
Value: idpReq.DendriteNonce,
|
||||||
|
Expires: time.Now().Add(10 * time.Minute),
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
}).String()
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSOCallback implements /login/sso/callback.
|
||||||
|
// https://spec.matrix.org/v1.2/client-server-api/#handling-the-callback-from-the-authentication-server
|
||||||
|
func SSOCallback(
|
||||||
|
req *http.Request,
|
||||||
|
userAPI userAPIForSSO,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
ctx := req.Context()
|
||||||
|
|
||||||
|
query := req.URL.Query()
|
||||||
|
idpID := query.Get("provider")
|
||||||
|
if idpID == "" {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.MissingArgument("provider parameter missing"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
idpCfg, idpType := getProvider(cfg, idpID)
|
||||||
|
if idpType == nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("unknown identity provider"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, err := req.Cookie("oidc_nonce")
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.MissingArgument("no nonce cookie: " + err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
finalRedirectURL, err := parseNonce(nonce.Value)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
idpReq := &sso.IdentityProviderRequest{
|
||||||
|
System: idpCfg,
|
||||||
|
CallbackURL: (&url.URL{
|
||||||
|
Scheme: req.URL.Scheme,
|
||||||
|
Host: req.URL.Host,
|
||||||
|
Path: req.URL.Path,
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"provider": []string{idpID},
|
||||||
|
}.Encode(),
|
||||||
|
}).String(),
|
||||||
|
DendriteNonce: nonce.Value,
|
||||||
|
}
|
||||||
|
result, err := idpType.ProcessCallback(ctx, idpReq, query)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Identifier == nil {
|
||||||
|
// Not authenticated yet.
|
||||||
|
return util.RedirectResponse(result.RedirectURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := verifyThirdPartyUserIdentifier(ctx, userAPI, result.Identifier, cfg.Matrix.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier.String()).Error("failed to find user")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.Forbidden("ID not associated with a local account"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if id == nil {
|
||||||
|
// The user doesn't exist.
|
||||||
|
// TODO: let the user select a localpart and register an account.
|
||||||
|
util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier.String()).Error("failed to find user")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotImplemented,
|
||||||
|
JSON: jsonerror.Forbidden("SSO registration not implemented"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := createLoginToken(ctx, userAPI, id)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Errorf("PerformLoginTokenCreation failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
rquery := finalRedirectURL.Query()
|
||||||
|
rquery.Set("loginToken", token.Token)
|
||||||
|
resp := util.RedirectResponse(finalRedirectURL.ResolveReference(&url.URL{RawQuery: rquery.Encode()}).String())
|
||||||
|
resp.Headers["Set-Cookie"] = (&http.Cookie{
|
||||||
|
Name: "oidc_nonce",
|
||||||
|
Value: "",
|
||||||
|
MaxAge: -1,
|
||||||
|
Secure: true,
|
||||||
|
}).String()
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
type userAPIForSSO interface {
|
||||||
|
uapi.LoginTokenInternalAPI
|
||||||
|
|
||||||
|
QueryLocalpartForThreePID(ctx context.Context, req *uapi.QueryLocalpartForThreePIDRequest, res *uapi.QueryLocalpartForThreePIDResponse) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProvider looks up the given provider in the
|
||||||
|
// configuration. Returns nil if it wasn't found or was of unknown
|
||||||
|
// type.
|
||||||
|
func getProvider(cfg *config.ClientAPI, id string) (*config.IdentityProvider, sso.IdentityProvider) {
|
||||||
|
for _, idp := range cfg.Login.SSO.Providers {
|
||||||
|
if idp.ID == id {
|
||||||
|
switch sso.IdentityProviderType(id) {
|
||||||
|
case sso.TypeGitHub:
|
||||||
|
return &idp, sso.GitHubIdentityProvider
|
||||||
|
default:
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatNonce creates a random nonce that also contains the URL.
|
||||||
|
func formatNonce(redirectURL string) string {
|
||||||
|
return util.RandomString(16) + "." + base64.RawURLEncoding.EncodeToString([]byte(redirectURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseNonce extracts the embedded URL from the nonce. The nonce
|
||||||
|
// should have been validated to be the original before calling this
|
||||||
|
// function. The URL is not integrity protected.
|
||||||
|
func parseNonce(s string) (redirectURL *url.URL, _ error) {
|
||||||
|
if s == "" {
|
||||||
|
return nil, jsonerror.MissingArgument("empty OIDC nonce cookie")
|
||||||
|
}
|
||||||
|
|
||||||
|
ss := strings.Split(s, ".")
|
||||||
|
if len(ss) < 2 {
|
||||||
|
return nil, jsonerror.InvalidArgumentValue("malformed OIDC nonce cookie")
|
||||||
|
}
|
||||||
|
|
||||||
|
urlbs, err := base64.RawURLEncoding.DecodeString(ss[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie")
|
||||||
|
}
|
||||||
|
u, err := url.Parse(string(urlbs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyThirdPartyUserIdentifier resolves a ThirdPartyIdentifier to a
|
||||||
|
// UserIdentifier using the User API. Returns nil if there is no
|
||||||
|
// associated user.
|
||||||
|
func verifyThirdPartyUserIdentifier(ctx context.Context, userAPI userAPIForSSO, id *userutil.ThirdPartyIdentifier, serverName gomatrixserverlib.ServerName) (*userutil.UserIdentifier, error) {
|
||||||
|
req := &uapi.QueryLocalpartForThreePIDRequest{
|
||||||
|
ThreePID: id.Address,
|
||||||
|
Medium: string(id.Medium),
|
||||||
|
}
|
||||||
|
var res uapi.QueryLocalpartForThreePIDResponse
|
||||||
|
if err := userAPI.QueryLocalpartForThreePID(ctx, req, &res); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if res.Localpart == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &userutil.UserIdentifier{UserID: userutil.MakeUserID(res.Localpart, serverName)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createLoginToken(ctx context.Context, userAPI userAPIForSSO, id *userutil.UserIdentifier) (*uapi.LoginTokenMetadata, error) {
|
||||||
|
req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: id.UserID}}
|
||||||
|
var resp uapi.PerformLoginTokenCreationResponse
|
||||||
|
if err := userAPI.PerformLoginTokenCreation(ctx, &req, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &resp.Metadata, nil
|
||||||
|
}
|
||||||
153
clientapi/userutil/identifier.go
Normal file
153
clientapi/userutil/identifier.go
Normal file
|
|
@ -0,0 +1,153 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package userutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// An Identifier identifies a user. There are many kinds, and this is
|
||||||
|
// the common interface for them.
|
||||||
|
//
|
||||||
|
// If you need to handle an identifier as JSON, use the AnyIdentifier wrapper.
|
||||||
|
// Passing around identifiers in code, the raw Identifier is enough.
|
||||||
|
//
|
||||||
|
// See https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
|
||||||
|
type Identifier interface {
|
||||||
|
// IdentifierType returns the identifier type, like "m.id.user".
|
||||||
|
IdentifierType() IdentifierType
|
||||||
|
|
||||||
|
// String returns a debug-output string representation. The format
|
||||||
|
// is not specified.
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// A UserIdentifier contains an MXID. It may be only the local part.
|
||||||
|
type UserIdentifier struct {
|
||||||
|
UserID string `json:"user"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *UserIdentifier) IdentifierType() IdentifierType { return IdentifierUser }
|
||||||
|
func (i *UserIdentifier) String() string { return i.UserID }
|
||||||
|
|
||||||
|
// A ThirdPartyIdentifier references an identifier in another system.
|
||||||
|
type ThirdPartyIdentifier struct {
|
||||||
|
// Medium is normally MediumEmail.
|
||||||
|
Medium Medium `json:"medium"`
|
||||||
|
|
||||||
|
// Address is the medium-specific identifier.
|
||||||
|
Address string `json:"address"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ThirdPartyIdentifier) IdentifierType() IdentifierType { return IdentifierThirdParty }
|
||||||
|
func (i *ThirdPartyIdentifier) String() string { return string(i.Medium) + ":" + i.Address }
|
||||||
|
|
||||||
|
// A PhoneIdentifier references a phone number.
|
||||||
|
type PhoneIdentifier struct {
|
||||||
|
// Country is a ISO-3166-1 alpha-2 country code.
|
||||||
|
Country string `json:"country"`
|
||||||
|
|
||||||
|
// PhoneNumber is a country-specific phone number, as it would be dialled from.
|
||||||
|
PhoneNumber string `json:"phone"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *PhoneIdentifier) IdentifierType() IdentifierType { return IdentifierPhone }
|
||||||
|
func (i *PhoneIdentifier) String() string { return i.Country + ":" + i.PhoneNumber }
|
||||||
|
|
||||||
|
// UnknownIdentifier is the catch-all for identifiers this code doesn't know about.
|
||||||
|
// It simply stores raw JSON.
|
||||||
|
type UnknownIdentifier struct {
|
||||||
|
json.RawMessage
|
||||||
|
Type IdentifierType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *UnknownIdentifier) IdentifierType() IdentifierType { return i.Type }
|
||||||
|
func (i *UnknownIdentifier) String() string { return "unknown/" + string(i.Type) }
|
||||||
|
|
||||||
|
// AnyIdentifier is a wrapper that allows marshalling and unmarshalling the various
|
||||||
|
// types of identifiers to/from JSON. Always use this in data types that will be
|
||||||
|
// used in JSON manipulation.
|
||||||
|
type AnyIdentifier struct {
|
||||||
|
Identifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i AnyIdentifier) MarshalJSON() ([]byte, error) {
|
||||||
|
v := struct {
|
||||||
|
*UserIdentifier
|
||||||
|
*ThirdPartyIdentifier
|
||||||
|
*PhoneIdentifier
|
||||||
|
Type IdentifierType `json:"type"`
|
||||||
|
}{
|
||||||
|
Type: i.Identifier.IdentifierType(),
|
||||||
|
}
|
||||||
|
switch iid := i.Identifier.(type) {
|
||||||
|
case *UserIdentifier:
|
||||||
|
v.UserIdentifier = iid
|
||||||
|
case *ThirdPartyIdentifier:
|
||||||
|
v.ThirdPartyIdentifier = iid
|
||||||
|
case *PhoneIdentifier:
|
||||||
|
v.PhoneIdentifier = iid
|
||||||
|
case *UnknownIdentifier:
|
||||||
|
return iid.RawMessage, nil
|
||||||
|
}
|
||||||
|
return json.Marshal(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *AnyIdentifier) UnmarshalJSON(bs []byte) error {
|
||||||
|
var hdr struct {
|
||||||
|
Type IdentifierType `json:"type"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(bs, &hdr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch hdr.Type {
|
||||||
|
case IdentifierUser:
|
||||||
|
var ui UserIdentifier
|
||||||
|
i.Identifier = &ui
|
||||||
|
return json.Unmarshal(bs, &ui)
|
||||||
|
case IdentifierThirdParty:
|
||||||
|
var tpi ThirdPartyIdentifier
|
||||||
|
i.Identifier = &tpi
|
||||||
|
return json.Unmarshal(bs, &tpi)
|
||||||
|
case IdentifierPhone:
|
||||||
|
var pi PhoneIdentifier
|
||||||
|
i.Identifier = &pi
|
||||||
|
return json.Unmarshal(bs, &pi)
|
||||||
|
case "":
|
||||||
|
return errors.New("missing identifier type")
|
||||||
|
default:
|
||||||
|
i.Identifier = &UnknownIdentifier{RawMessage: json.RawMessage(bytes.TrimSpace(bs)), Type: hdr.Type}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IdentifierType describes the type of identifier.
|
||||||
|
type IdentifierType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
IdentifierUser IdentifierType = "m.id.user"
|
||||||
|
IdentifierThirdParty IdentifierType = "m.id.thirdparty"
|
||||||
|
IdentifierPhone IdentifierType = "m.id.phone"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Medium describes the interpretation of a third-party identifier.
|
||||||
|
type Medium string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MediumEmail signifies that the address is an email address.
|
||||||
|
MediumEmail Medium = "email"
|
||||||
|
)
|
||||||
75
clientapi/userutil/identifier_test.go
Normal file
75
clientapi/userutil/identifier_test.go
Normal file
|
|
@ -0,0 +1,75 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package userutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAnyIdentifierJSON(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
JSON string
|
||||||
|
Want Identifier
|
||||||
|
}{
|
||||||
|
{Name: "empty", JSON: `{}`},
|
||||||
|
{Name: "user", JSON: `{"type":"m.id.user","user":"auser"}`, Want: &UserIdentifier{UserID: "auser"}},
|
||||||
|
{Name: "thirdparty", JSON: `{"type":"m.id.thirdparty","medium":"email","address":"auser@example.com"}`, Want: &ThirdPartyIdentifier{Medium: "email", Address: "auser@example.com"}},
|
||||||
|
{Name: "phone", JSON: `{"type":"m.id.phone","country":"GB","phone":"123456789"}`, Want: &PhoneIdentifier{Country: "GB", PhoneNumber: "123456789"}},
|
||||||
|
// This test is a little fragile since it compares the output of json.Marshal.
|
||||||
|
{Name: "unknown", JSON: `{"type":"other"}`, Want: &UnknownIdentifier{Type: "other", RawMessage: json.RawMessage(`{"type":"other"}`)}},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run("Unmarshal/"+tst.Name, func(t *testing.T) {
|
||||||
|
var got AnyIdentifier
|
||||||
|
if err := json.Unmarshal([]byte(tst.JSON), &got); err != nil {
|
||||||
|
if tst.Want == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got.Identifier, tst.Want) {
|
||||||
|
t.Errorf("got %+v, want %+v", got.Identifier, tst.Want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if tst.Want == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Run("Marshal/"+tst.Name, func(t *testing.T) {
|
||||||
|
id := AnyIdentifier{Identifier: tst.Want}
|
||||||
|
bs, err := json.Marshal(id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("Marshalled JSON: %q", string(bs))
|
||||||
|
|
||||||
|
var got AnyIdentifier
|
||||||
|
if err := json.Unmarshal(bs, &got); err != nil {
|
||||||
|
if tst.Want == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got.Identifier, tst.Want) {
|
||||||
|
t.Errorf("got %+v, want %+v", got.Identifier, tst.Want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -42,6 +42,8 @@ type ClientAPI struct {
|
||||||
// was successful
|
// was successful
|
||||||
RecaptchaSiteVerifyAPI string `yaml:"recaptcha_siteverify_api"`
|
RecaptchaSiteVerifyAPI string `yaml:"recaptcha_siteverify_api"`
|
||||||
|
|
||||||
|
Login Login `yaml:"login"`
|
||||||
|
|
||||||
// TURN options
|
// TURN options
|
||||||
TURN TURN `yaml:"turn"`
|
TURN TURN `yaml:"turn"`
|
||||||
|
|
||||||
|
|
@ -64,9 +66,11 @@ func (c *ClientAPI) Defaults(generate bool) {
|
||||||
c.RegistrationDisabled = true
|
c.RegistrationDisabled = true
|
||||||
c.OpenRegistrationWithoutVerificationEnabled = false
|
c.OpenRegistrationWithoutVerificationEnabled = false
|
||||||
c.RateLimiting.Defaults()
|
c.RateLimiting.Defaults()
|
||||||
|
c.Login.SSO.Enabled = false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
|
c.Login.Verify(configErrs)
|
||||||
c.TURN.Verify(configErrs)
|
c.TURN.Verify(configErrs)
|
||||||
c.RateLimiting.Verify(configErrs)
|
c.RateLimiting.Verify(configErrs)
|
||||||
if c.RecaptchaEnabled {
|
if c.RecaptchaEnabled {
|
||||||
|
|
@ -95,6 +99,125 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
checkURL(configErrs, "client_api.external_api.listen", string(c.ExternalAPI.Listen))
|
checkURL(configErrs, "client_api.external_api.listen", string(c.ExternalAPI.Listen))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Login struct {
|
||||||
|
SSO SSO `yaml:"sso"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Login) Verify(configErrs *ConfigErrors) {
|
||||||
|
l.SSO.Verify(configErrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
type SSO struct {
|
||||||
|
// Enabled determines whether SSO should be allowed.
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
|
||||||
|
// Providers list the identity providers this server is capable of confirming an
|
||||||
|
// identity with.
|
||||||
|
Providers []IdentityProvider `yaml:"providers"`
|
||||||
|
|
||||||
|
// DefaultProviderID is the provider to use when the client doesn't indicate one.
|
||||||
|
// This is legacy support. If empty, the first provider listed is used.
|
||||||
|
DefaultProviderID string `yaml:"default_provider"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sso *SSO) Verify(configErrs *ConfigErrors) {
|
||||||
|
var foundDefaultProvider bool
|
||||||
|
seenPIDs := make(map[string]bool, len(sso.Providers))
|
||||||
|
for _, p := range sso.Providers {
|
||||||
|
p.Verify(configErrs)
|
||||||
|
if p.ID == sso.DefaultProviderID {
|
||||||
|
foundDefaultProvider = true
|
||||||
|
}
|
||||||
|
if seenPIDs[p.ID] {
|
||||||
|
configErrs.Add(fmt.Sprintf("duplicate identity provider for config key %q: %s", "client_api.sso.providers", p.ID))
|
||||||
|
}
|
||||||
|
seenPIDs[p.ID] = true
|
||||||
|
}
|
||||||
|
if sso.DefaultProviderID != "" && !foundDefaultProvider {
|
||||||
|
configErrs.Add(fmt.Sprintf("identity provider ID not found for config key %q: %s", "client_api.sso.default_provider", sso.DefaultProviderID))
|
||||||
|
}
|
||||||
|
|
||||||
|
if sso.Enabled {
|
||||||
|
if len(sso.Providers) == 0 {
|
||||||
|
configErrs.Add(fmt.Sprintf("empty list for config key %q", "client_api.sso.providers"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md.
|
||||||
|
type IdentityProvider struct {
|
||||||
|
// ID is the unique identifier of this IdP. We use the brand identifiers as provider
|
||||||
|
// identifiers for simplicity.
|
||||||
|
ID string `yaml:"id"`
|
||||||
|
|
||||||
|
// Name is a human-friendly name of the provider.
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
|
||||||
|
// Brand is a hint on how to display the IdP to the user. If this is empty, a default
|
||||||
|
// based on the type is used.
|
||||||
|
Brand string `yaml:"brand"`
|
||||||
|
|
||||||
|
// Icon is an MXC URI describing how to display the IdP to the user. Prefer using `brand`.
|
||||||
|
Icon string `yaml:"icon"`
|
||||||
|
|
||||||
|
// Type describes how this provider is implemented. It must match "github". If this is
|
||||||
|
// empty, the ID is used, which means there is a weak expectation that ID is also a
|
||||||
|
// valid type, unless you have a complicated setup.
|
||||||
|
Type string `yaml:"type"`
|
||||||
|
|
||||||
|
// OIDC contains settings for providers based on OpenID Connect (OAuth 2).
|
||||||
|
OIDC struct {
|
||||||
|
ClientID string `yaml:"client_id"`
|
||||||
|
ClientSecret string `yaml:"client_secret"`
|
||||||
|
} `yaml:"oidc"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) {
|
||||||
|
checkNotEmpty(configErrs, "client_api.sso.providers.id", idp.ID)
|
||||||
|
if !checkIdentityProviderBrand(idp.ID) {
|
||||||
|
configErrs.Add(fmt.Sprintf("unrecognized ID config key %q: %s", "client_api.sso.providers", idp.ID))
|
||||||
|
}
|
||||||
|
checkNotEmpty(configErrs, "client_api.sso.providers.name", idp.Name)
|
||||||
|
if idp.Brand != "" && !checkIdentityProviderBrand(idp.Brand) {
|
||||||
|
configErrs.Add(fmt.Sprintf("unrecognized brand in identity provider %q for config key %q: %s", idp.ID, "client_api.sso.providers", idp.Brand))
|
||||||
|
}
|
||||||
|
if idp.Icon != "" {
|
||||||
|
checkURL(configErrs, "client_api.sso.providers.icon", idp.Icon)
|
||||||
|
}
|
||||||
|
typ := idp.Type
|
||||||
|
if idp.Type == "" {
|
||||||
|
typ = idp.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
switch typ {
|
||||||
|
case "github":
|
||||||
|
checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_id", idp.OIDC.ClientID)
|
||||||
|
checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_secret", idp.OIDC.ClientSecret)
|
||||||
|
|
||||||
|
default:
|
||||||
|
configErrs.Add(fmt.Sprintf("unrecognized type in identity provider %q for config key %q: %s", idp.ID, "client_api.sso.providers", typ))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md.
|
||||||
|
func checkIdentityProviderBrand(s string) bool {
|
||||||
|
switch s {
|
||||||
|
case SSOBrandApple, SSOBrandFacebook, SSOBrandGitHub, SSOBrandGitLab, SSOBrandGoogle, SSOBrandTwitter:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
SSOBrandApple = "apple"
|
||||||
|
SSOBrandFacebook = "facebook"
|
||||||
|
SSOBrandGitHub = "github"
|
||||||
|
SSOBrandGitLab = "gitlab"
|
||||||
|
SSOBrandGoogle = "google"
|
||||||
|
SSOBrandTwitter = "twitter"
|
||||||
|
)
|
||||||
|
|
||||||
type TURN struct {
|
type TURN struct {
|
||||||
// TODO Guest Support
|
// TODO Guest Support
|
||||||
// Whether or not guests can request TURN credentials
|
// Whether or not guests can request TURN credentials
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue