From 210ab1eef61693cf9e0e7fece4f5e132acb5cd5f Mon Sep 17 00:00:00 2001 From: Tommie Gannert Date: Wed, 8 Jun 2022 09:14:11 +0200 Subject: [PATCH] Add SSO tests. Renames cookie oidc_nonce to sso_nonce, since it's defined in a file that doesn't know about OIDC specifically. --- clientapi/auth/login_test.go | 28 ++ clientapi/auth/sso/oauth2_test.go | 226 +++++++++++++ clientapi/auth/sso/oidc.go | 16 +- clientapi/auth/sso/oidc_test.go | 118 +++++++ clientapi/auth/sso/sso.go | 10 +- clientapi/auth/sso/sso_test.go | 76 +++++ clientapi/routing/routing.go | 2 +- clientapi/routing/sso.go | 37 ++- clientapi/routing/sso_test.go | 531 ++++++++++++++++++++++++++++++ setup/config/config_test.go | 18 +- userapi/storage/storage_test.go | 35 ++ 11 files changed, 1079 insertions(+), 18 deletions(-) create mode 100644 clientapi/auth/sso/oauth2_test.go create mode 100644 clientapi/auth/sso/oidc_test.go create mode 100644 clientapi/auth/sso/sso_test.go create mode 100644 clientapi/routing/sso_test.go diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index cb57e9552..ef4038a0e 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -101,6 +101,34 @@ func TestLoginFromJSONReader(t *testing.T) { } } +func TestLoginFromJSONReaderTokenDisabled(t *testing.T) { + ctx := context.Background() + + var userAPI fakeUserInternalAPI + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + ServerName: serverName, + }, + Login: config.Login{ + SSO: config.SSO{ + Enabled: false, + }, + }, + } + _, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(`{ + "type": "m.login.token", + "token": "atoken", + "device_id": "adevice" + }`), &userAPI, &userAPI, cfg) + wantCode := "M_INVALID_ARGUMENT_VALUE" + if err == nil { + cleanup(ctx, nil) + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", err, wantCode) + } else if merr, ok := err.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != wantCode { + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", err, wantCode) + } +} + func TestBadLoginFromJSONReader(t *testing.T) { ctx := context.Background() diff --git a/clientapi/auth/sso/oauth2_test.go b/clientapi/auth/sso/oauth2_test.go new file mode 100644 index 000000000..f541a5236 --- /dev/null +++ b/clientapi/auth/sso/oauth2_test.go @@ -0,0 +1,226 @@ +package sso + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestOAuth2IdentityProviderAuthorizationURL(t *testing.T) { + ctx := context.Background() + + idp := &oauth2IdentityProvider{ + cfg: &config.IdentityProvider{ + OAuth2: config.OAuth2{ + ClientID: "aclientid", + }, + }, + hc: http.DefaultClient, + + authorizationURL: "https://oauth2.example.com/authorize", + } + + got, err := idp.AuthorizationURL(ctx, "https://matrix.example.com/continue", "anonce") + if err != nil { + t.Fatalf("AuthorizationURL failed: %v", err) + } + + if want := "https://oauth2.example.com/authorize?client_id=aclientid&redirect_uri=https%3A%2F%2Fmatrix.example.com%2Fcontinue&response_type=code&scope=&state=anonce"; got != want { + t.Errorf("AuthorizationURL: got %q, want %q", got, want) + } +} + +func TestOAuth2IdentityProviderProcessCallback(t *testing.T) { + ctx := context.Background() + + const callbackURL = "https://matrix.example.com/continue" + + tsts := []struct { + Name string + Query url.Values + + Want *CallbackResult + WantTokenReq url.Values + }{ + { + Name: "gotEverything", + Query: url.Values{ + "code": []string{"acode"}, + "state": []string{"anonce"}, + }, + + Want: &CallbackResult{ + Identifier: &UserIdentifier{ + Namespace: uapi.SSOIDNamespace, + Issuer: "anid", + Subject: "asub", + }, + DisplayName: "aname", + SuggestedUserID: "auser", + }, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`)) + }) + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_user":"auser"}`)) + }) + + s := httptest.NewServer(mux) + defer s.Close() + + idp := &oauth2IdentityProvider{ + cfg: &config.IdentityProvider{ + ID: "anid", + OAuth2: config.OAuth2{ + ClientID: "aclientid", + ClientSecret: "aclientsecret", + }, + }, + hc: s.Client(), + + accessTokenURL: s.URL + "/token", + userInfoURL: s.URL + "/userinfo", + + subPath: "sub", + displayNamePath: "name", + suggestedUserIDPath: "preferred_user", + } + + got, err := idp.ProcessCallback(ctx, callbackURL, "anonce", tst.Query) + if err != nil { + t.Fatalf("ProcessCallback failed: %v", err) + } + + if !reflect.DeepEqual(got, tst.Want) { + t.Errorf("ProcessCallback: got %+v, want %+v", got, tst.Want) + } + }) + } +} + +func TestOAuth2IdentityProviderGetAccessToken(t *testing.T) { + ctx := context.Background() + + const callbackURL = "https://matrix.example.com/continue" + + mux := http.NewServeMux() + var gotReq url.Values + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + gotReq = r.Form + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`)) + }) + + s := httptest.NewServer(mux) + defer s.Close() + + idp := &oauth2IdentityProvider{ + cfg: &config.IdentityProvider{ + ID: "anid", + OAuth2: config.OAuth2{ + ClientID: "aclientid", + ClientSecret: "aclientsecret", + }, + }, + hc: s.Client(), + + accessTokenURL: s.URL + "/token", + } + + got, err := idp.getAccessToken(ctx, callbackURL, "acode") + if err != nil { + t.Fatalf("getAccessToken failed: %v", err) + } + + if want := "atoken"; got != want { + t.Errorf("getAccessToken: got %q, want %q", got, want) + } + + wantReq := url.Values{ + "client_id": []string{"aclientid"}, + "client_secret": []string{"aclientsecret"}, + "code": []string{"acode"}, + "grant_type": []string{"authorization_code"}, + "redirect_uri": []string{callbackURL}, + } + if !reflect.DeepEqual(gotReq, wantReq) { + t.Errorf("getAccessToken request: got %+v, want %+v", gotReq, wantReq) + } +} + +func TestOAuth2IdentityProviderGetUserInfo(t *testing.T) { + ctx := context.Background() + + mux := http.NewServeMux() + var gotHeader http.Header + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + gotHeader = r.Header + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_user":"auser"}`)) + }) + + s := httptest.NewServer(mux) + defer s.Close() + + idp := &oauth2IdentityProvider{ + cfg: &config.IdentityProvider{ + ID: "anid", + OAuth2: config.OAuth2{ + ClientID: "aclientid", + ClientSecret: "aclientsecret", + }, + }, + hc: s.Client(), + + userInfoURL: s.URL + "/userinfo", + + responseMimeType: "application/json", + subPath: "sub", + displayNamePath: "name", + suggestedUserIDPath: "preferred_user", + } + + gotSub, gotName, gotSuggestedUser, err := idp.getUserInfo(ctx, "atoken") + if err != nil { + t.Fatalf("getUserInfo failed: %v", err) + } + + if want := "asub"; gotSub != want { + t.Errorf("getUserInfo subject: got %q, want %q", gotSub, want) + } + if want := "aname"; gotName != want { + t.Errorf("getUserInfo displayName: got %q, want %q", gotName, want) + } + if want := "auser"; gotSuggestedUser != want { + t.Errorf("getUserInfo suggestedUser: got %q, want %q", gotSuggestedUser, want) + } + + gotHeader.Del("Accept-Encoding") + gotHeader.Del("User-Agent") + wantHeader := http.Header{ + "Accept": []string{"application/json"}, + "Authorization": []string{"Bearer atoken"}, + } + if !reflect.DeepEqual(gotHeader, wantHeader) { + t.Errorf("getUserInfo header: got %+v, want %+v", gotHeader, wantHeader) + } +} diff --git a/clientapi/auth/sso/oidc.go b/clientapi/auth/sso/oidc.go index d1e28a736..7d00e457f 100644 --- a/clientapi/auth/sso/oidc.go +++ b/clientapi/auth/sso/oidc.go @@ -27,6 +27,18 @@ import ( uapi "github.com/matrix-org/dendrite/userapi/api" ) +// oidcDiscoveryMaxStaleness indicates how stale the Discovery +// information is allowed to be. This will very rarely change, so +// we're just making sure even a Dendrite that isn't restarting often +// is picking this up eventually. +const oidcDiscoveryMaxStaleness = 24 * time.Hour + +// An oidcIdentityProvider wraps OAuth2 with OpenID Connect Discovery. +// +// The SSO identifier is the "sub." A suggested UserID is grabbed from +// "preferred_username", though this isn't commonly provided. +// +// See https://openid.net/specs/openid-connect-core-1_0.html and https://openid.net/specs/openid-connect-discovery-1_0.html. type oidcIdentityProvider struct { *oauth2IdentityProvider @@ -44,7 +56,7 @@ func newOIDCIdentityProvider(cfg *config.IdentityProvider, hc *http.Client) *oid scopes: []string{"openid", "profile", "email"}, responseMimeType: "application/json", subPath: "sub", - emailPath: "email", + emailPath: "email", // TODO: should this require email_verified? displayNamePath: "name", suggestedUserIDPath: "preferred_username", }, @@ -92,7 +104,7 @@ func (p *oidcIdentityProvider) get(ctx context.Context) (*oauth2IdentityProvider return nil, nil, err } - p.exp = now.Add(24 * time.Hour) + p.exp = now.Add(oidcDiscoveryMaxStaleness) newProvider := *p.oauth2IdentityProvider newProvider.authorizationURL = disc.AuthorizationEndpoint newProvider.accessTokenURL = disc.TokenEndpoint diff --git a/clientapi/auth/sso/oidc_test.go b/clientapi/auth/sso/oidc_test.go new file mode 100644 index 000000000..21205e80c --- /dev/null +++ b/clientapi/auth/sso/oidc_test.go @@ -0,0 +1,118 @@ +package sso + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestOIDCIdentityProviderAuthorizationURL(t *testing.T) { + ctx := context.Background() + + mux := http.NewServeMux() + mux.HandleFunc("/discovery", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"authorization_endpoint":"http://oidc.example.com/authorize","token_endpoint":"http://oidc.example.com/token","userinfo_endpoint":"http://oidc.example.com/userinfo","issuer":"http://oidc.example.com/"}`)) + }) + + s := httptest.NewServer(mux) + defer s.Close() + + idp := newOIDCIdentityProvider(&config.IdentityProvider{ + OAuth2: config.OAuth2{ + ClientID: "aclientid", + }, + OIDC: config.OIDC{ + DiscoveryURL: s.URL + "/discovery", + }, + }, s.Client()) + + got, err := idp.AuthorizationURL(ctx, "https://matrix.example.com/continue", "anonce") + if err != nil { + t.Fatalf("AuthorizationURL failed: %v", err) + } + + if want := "http://oidc.example.com/authorize?client_id=aclientid&redirect_uri=https%3A%2F%2Fmatrix.example.com%2Fcontinue&response_type=code&scope=openid+profile+email&state=anonce"; got != want { + t.Errorf("AuthorizationURL: got %q, want %q", got, want) + } +} + +func TestOIDCIdentityProviderProcessCallback(t *testing.T) { + ctx := context.Background() + + const callbackURL = "https://matrix.example.com/continue" + + tsts := []struct { + Name string + Query url.Values + + Want *CallbackResult + WantTokenReq url.Values + }{ + { + Name: "gotEverything", + Query: url.Values{ + "code": []string{"acode"}, + "state": []string{"anonce"}, + }, + + Want: &CallbackResult{ + Identifier: &UserIdentifier{ + Namespace: uapi.OIDCNamespace, + Issuer: "http://oidc.example.com/", + Subject: "asub", + }, + DisplayName: "aname", + SuggestedUserID: "auser", + }, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + mux := http.NewServeMux() + var sURL string + mux.HandleFunc("/discovery", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(fmt.Sprintf(`{"authorization_endpoint":"%s/authorize","token_endpoint":"%s/token","userinfo_endpoint":"%s/userinfo","issuer":"http://oidc.example.com/"}`, + sURL, sURL, sURL))) + }) + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`)) + }) + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_username":"auser"}`)) + }) + + s := httptest.NewServer(mux) + defer s.Close() + + sURL = s.URL + idp := newOIDCIdentityProvider(&config.IdentityProvider{ + OAuth2: config.OAuth2{ + ClientID: "aclientid", + }, + OIDC: config.OIDC{ + DiscoveryURL: sURL + "/discovery", + }, + }, s.Client()) + + got, err := idp.ProcessCallback(ctx, callbackURL, "anonce", tst.Query) + if err != nil { + t.Fatalf("ProcessCallback failed: %v", err) + } + + if !reflect.DeepEqual(got, tst.Want) { + t.Errorf("ProcessCallback: got %+v, want %+v", got, tst.Want) + } + }) + } +} diff --git a/clientapi/auth/sso/sso.go b/clientapi/auth/sso/sso.go index 862da6dc9..90c613b37 100644 --- a/clientapi/auth/sso/sso.go +++ b/clientapi/auth/sso/sso.go @@ -25,13 +25,19 @@ import ( uapi "github.com/matrix-org/dendrite/userapi/api" ) +// maxHTTPTimeout is an upper bound on an HTTP request to an SSO +// backend. The individual request context deadlines are also honored. +const maxHTTPTimeout = 10 * time.Second + +// An Authenticator keeps a set of identity providers and dispatches +// calls to one of them, based on configured ID. type Authenticator struct { providers map[string]identityProvider } -func NewAuthenticator(ctx context.Context, cfg *config.SSO) (*Authenticator, error) { +func NewAuthenticator(cfg *config.SSO) (*Authenticator, error) { hc := &http.Client{ - Timeout: 10 * time.Second, + Timeout: maxHTTPTimeout, Transport: &http.Transport{ DisableKeepAlives: true, Proxy: http.ProxyFromEnvironment, diff --git a/clientapi/auth/sso/sso_test.go b/clientapi/auth/sso/sso_test.go new file mode 100644 index 000000000..663e07721 --- /dev/null +++ b/clientapi/auth/sso/sso_test.go @@ -0,0 +1,76 @@ +package sso + +import ( + "context" + "net/url" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/setup/config" +) + +func TestNewAuthenticator(t *testing.T) { + _, err := NewAuthenticator(&config.SSO{ + Providers: []config.IdentityProvider{ + { + Type: config.SSOTypeGitHub, + OAuth2: config.OAuth2{ + ClientID: "aclientid", + }, + }, + { + Type: config.SSOTypeOIDC, + OAuth2: config.OAuth2{ + ClientID: "aclientid", + }, + OIDC: config.OIDC{ + DiscoveryURL: "http://oidc.example.com/discovery", + }, + }, + }, + }) + if err != nil { + t.Fatalf("NewAuthenticator failed: %v", err) + } +} + +func TestAuthenticator(t *testing.T) { + ctx := context.Background() + + var idp fakeIdentityProvider + a := Authenticator{ + providers: map[string]identityProvider{ + "fake": &idp, + }, + } + + t.Run("authorizationURL", func(t *testing.T) { + got, err := a.AuthorizationURL(ctx, "fake", "http://matrix.example.com/continue", "anonce") + if err != nil { + t.Fatalf("AuthorizationURL failed: %v", err) + } + if want := "aurl"; got != want { + t.Errorf("AuthorizationURL: got %q, want %q", got, want) + } + }) + + t.Run("processCallback", func(t *testing.T) { + got, err := a.ProcessCallback(ctx, "fake", "http://matrix.example.com/continue", "anonce", url.Values{}) + if err != nil { + t.Fatalf("ProcessCallback failed: %v", err) + } + if want := (&CallbackResult{DisplayName: "aname"}); !reflect.DeepEqual(got, want) { + t.Errorf("ProcessCallback: got %+v, want %+v", got, want) + } + }) +} + +type fakeIdentityProvider struct{} + +func (idp *fakeIdentityProvider) AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) { + return "aurl", nil +} + +func (idp *fakeIdentityProvider) ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) { + return &CallbackResult{DisplayName: "aname"}, nil +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 21bec3307..82b6ef976 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -73,7 +73,7 @@ func Setup( var ssoAuthenticator *sso.Authenticator if cfg.Login.SSO.Enabled { var err error - ssoAuthenticator, err = sso.NewAuthenticator(ctx, &cfg.Login.SSO) + ssoAuthenticator, err = sso.NewAuthenticator(&cfg.Login.SSO) if err != nil { logrus.WithError(err).Fatal("failed to create SSO authenticator") } diff --git a/clientapi/routing/sso.go b/clientapi/routing/sso.go index f2114733e..154b7e93d 100644 --- a/clientapi/routing/sso.go +++ b/clientapi/routing/sso.go @@ -39,7 +39,7 @@ import ( func SSORedirect( req *http.Request, idpID string, - auth *sso.Authenticator, + auth ssoAuthenticator, cfg *config.SSO, ) util.JSONResponse { ctx := req.Context() @@ -58,12 +58,16 @@ func SSORedirect( JSON: jsonerror.MissingArgument("redirectUrl parameter missing"), } } - _, err := url.Parse(redirectURL) - if err != nil { + if ru, err := url.Parse(redirectURL); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()), } + } else if ru.Scheme == "" || ru.Host == "" || ru.Path == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + redirectURL), + } } callbackURL, err := buildCallbackURLFromOther(cfg, req, "/login/sso/redirect") @@ -92,7 +96,7 @@ func SSORedirect( resp := util.RedirectResponse(u) cookie := &http.Cookie{ - Name: "oidc_nonce", + Name: "sso_nonce", Value: nonce, Path: path.Dir(callbackURL.Path), Expires: time.Now().Add(10 * time.Minute), @@ -113,7 +117,6 @@ func SSORedirect( func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath string) (*url.URL, error) { u := &url.URL{ Scheme: "https", - User: req.URL.User, Host: req.Host, Path: req.URL.Path, } @@ -141,7 +144,7 @@ func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath func SSOCallback( req *http.Request, userAPI userAPIForSSO, - auth *sso.Authenticator, + auth ssoAuthenticator, cfg *config.SSO, serverName gomatrixserverlib.ServerName, ) util.JSONResponse { @@ -163,7 +166,7 @@ func SSOCallback( } } - nonce, err := req.Cookie("oidc_nonce") + nonce, err := req.Cookie("sso_nonce") if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -246,7 +249,7 @@ func SSOCallback( rquery.Set("loginToken", token.Token) resp := util.RedirectResponse(finalRedirectURL.ResolveReference(&url.URL{RawQuery: rquery.Encode()}).String()) resp.Headers["Set-Cookie"] = (&http.Cookie{ - Name: "oidc_nonce", + Name: "sso_nonce", Value: "", MaxAge: -1, Secure: true, @@ -254,6 +257,11 @@ func SSOCallback( return resp } +type ssoAuthenticator interface { + AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) + ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error) +} + type userAPIForSSO interface { uapi.LoginTokenInternalAPI @@ -273,21 +281,21 @@ func formatNonce(redirectURL string) string { // function. The URL is not integrity protected. func parseNonce(s string) (redirectURL *url.URL, _ error) { if s == "" { - return nil, jsonerror.MissingArgument("empty OIDC nonce cookie") + return nil, jsonerror.MissingArgument("empty SSO nonce cookie") } ss := strings.Split(s, ".") if len(ss) < 2 { - return nil, jsonerror.InvalidArgumentValue("malformed OIDC nonce cookie") + return nil, jsonerror.InvalidArgumentValue("malformed SSO nonce cookie") } urlbs, err := base64.RawURLEncoding.DecodeString(ss[1]) if err != nil { - return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie") + return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in SSO nonce cookie") } u, err := url.Parse(string(urlbs)) if err != nil { - return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie: " + err.Error()) + return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in SSO nonce cookie: " + err.Error()) } return u, nil @@ -309,6 +317,9 @@ func verifySSOUserIdentifier(ctx context.Context, userAPI userAPIForSSO, id *sso return res.Localpart, nil } +// registerSSOAccount creates an account and associates the SSO +// identifier with it. Note that SSO login account creation doesn't +// use the standard registration API, but happens ad-hoc. func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.UserIdentifier, localpart string) (bool, util.JSONResponse) { var accRes uapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ @@ -347,6 +358,8 @@ func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.U return true, util.JSONResponse{} } +// createLoginToken produces a new login token, valid for the given +// user. func createLoginToken(ctx context.Context, userAPI userAPIForSSO, userID string) (*uapi.LoginTokenMetadata, error) { req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: userID}} var resp uapi.PerformLoginTokenCreationResponse diff --git a/clientapi/routing/sso_test.go b/clientapi/routing/sso_test.go new file mode 100644 index 000000000..a26604b13 --- /dev/null +++ b/clientapi/routing/sso_test.go @@ -0,0 +1,531 @@ +package routing + +import ( + "context" + "encoding/base64" + "errors" + "net/http" + "net/url" + "regexp" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/matrix-org/dendrite/clientapi/auth/sso" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestSSORedirect(t *testing.T) { + tsts := []struct { + Name string + Req http.Request + IDPID string + Auth fakeSSOAuthenticator + Config config.SSO + + WantLocationRE string + WantSetCookieRE string + }{ + { + Name: "redirectDefault", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/redirect", + RawQuery: url.Values{ + "redirectUrl": []string{"http://example.com/continue"}, + }.Encode(), + }, + }, + WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http%3A%2F%2Fmatrix.example.com%2F_matrix%2Fv4%2Flogin%2Fsso%2Fcallback%3Fprovider%3D&nonce=.+&providerID=`, + WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso", + }, + { + Name: "redirectExplicitProvider", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/redirect", + RawQuery: url.Values{ + "redirectUrl": []string{"http://example.com/continue"}, + }.Encode(), + }, + }, + IDPID: "someprovider", + WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http.*%3Fprovider%3Dsomeprovider&nonce=.+&providerID=someprovider`, + WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso", + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config) + + if want := http.StatusFound; got.Code != want { + t.Errorf("SSORedirect Code: got %v, want %v", got.Code, want) + } + + if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil { + t.Fatalf("WantSetCookieRE failed: %v", err) + } else if !m { + t.Errorf("SSORedirect Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE) + } + + if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil { + t.Fatalf("WantSetCookieRE failed: %v", err) + } else if !m { + t.Errorf("SSORedirect Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE) + } + }) + } +} + +func TestSSORedirectError(t *testing.T) { + tsts := []struct { + Name string + Req http.Request + IDPID string + Auth fakeSSOAuthenticator + Config config.SSO + + WantCode int + }{ + { + Name: "missingRedirectUrl", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/redirect", + RawQuery: url.Values{}.Encode(), + }, + }, + WantCode: http.StatusBadRequest, + }, + { + Name: "invalidRedirectUrl", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/redirect", + RawQuery: url.Values{ + "redirectUrl": []string{"/continue"}, + }.Encode(), + }, + }, + WantCode: http.StatusBadRequest, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config) + + if got.Code != tst.WantCode { + t.Errorf("SSORedirect Code: got %v, want %v", got.Code, tst.WantCode) + } + }) + } +} + +func TestSSOCallback(t *testing.T) { + nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue")) + + tsts := []struct { + Name string + Req http.Request + UserAPI fakeUserAPIForSSO + Auth fakeSSOAuthenticator + Config config.SSO + + WantLocationRE string + WantSetCookieRE string + + WantAccountCreation []*uapi.PerformAccountCreationRequest + WantLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest + WantSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest + WantQueryLocalpart []*uapi.QueryLocalpartForSSORequest + }{ + { + Name: "logIn", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/callback", + RawQuery: url.Values{ + "provider": []string{"aprovider"}, + }.Encode(), + }, + Header: http.Header{ + "Cookie": []string{(&http.Cookie{ + Name: "sso_nonce", + Value: nonce, + }).String()}, + }, + }, + UserAPI: fakeUserAPIForSSO{ + localpart: "alocalpart", + }, + Auth: fakeSSOAuthenticator{ + callbackResult: sso.CallbackResult{ + Identifier: &sso.UserIdentifier{ + Namespace: "anamespace", + Issuer: "anissuer", + Subject: "asubject", + }, + }, + }, + WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`, + WantSetCookieRE: "sso_nonce=;", + + WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@alocalpart:aservername"}}}, + WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}}, + }, + { + Name: "registerSuggested", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/callback", + RawQuery: url.Values{ + "provider": []string{"aprovider"}, + }.Encode(), + }, + Header: http.Header{ + "Cookie": []string{(&http.Cookie{ + Name: "sso_nonce", + Value: nonce, + }).String()}, + }, + }, + Auth: fakeSSOAuthenticator{ + callbackResult: sso.CallbackResult{ + Identifier: &sso.UserIdentifier{ + Namespace: "anamespace", + Issuer: "anissuer", + Subject: "asubject", + }, + SuggestedUserID: "asuggestedid", + }, + }, + WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`, + WantSetCookieRE: "sso_nonce=;", + + WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "asuggestedid", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}}, + WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@asuggestedid:aservername"}}}, + WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "asuggestedid"}}, + WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}}, + }, + { + Name: "registerNumeric", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/callback", + RawQuery: url.Values{ + "provider": []string{"aprovider"}, + }.Encode(), + }, + Header: http.Header{ + "Cookie": []string{(&http.Cookie{ + Name: "sso_nonce", + Value: nonce, + }).String()}, + }, + }, + Auth: fakeSSOAuthenticator{ + callbackResult: sso.CallbackResult{ + Identifier: &sso.UserIdentifier{ + Namespace: "anamespace", + Issuer: "anissuer", + Subject: "asubject", + }, + }, + }, + WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`, + WantSetCookieRE: "sso_nonce=;", + + WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "12345", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}}, + WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@12345:aservername"}}}, + WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "12345"}}, + WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}}, + }, + { + Name: "noIdentifierRedirectURL", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/callback", + RawQuery: url.Values{ + "provider": []string{"aprovider"}, + }.Encode(), + }, + Header: http.Header{ + "Cookie": []string{(&http.Cookie{ + Name: "sso_nonce", + Value: nonce, + }).String()}, + }, + }, + Auth: fakeSSOAuthenticator{ + callbackResult: sso.CallbackResult{ + RedirectURL: "http://auth.example.com/notdone", + }, + }, + WantLocationRE: `http://auth.example.com/notdone`, + WantSetCookieRE: "^$", + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername") + + if want := http.StatusFound; got.Code != want { + t.Log(got) + t.Errorf("SSOCallback Code: got %v, want %v", got.Code, want) + } + + if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil { + t.Fatalf("WantSetCookieRE failed: %v", err) + } else if !m { + t.Errorf("SSOCallback Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE) + } + + if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil { + t.Fatalf("WantSetCookieRE failed: %v", err) + } else if !m { + t.Errorf("SSOCallback Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE) + } + + if diff := cmp.Diff(tst.WantAccountCreation, tst.UserAPI.gotAccountCreation); diff != "" { + t.Errorf("PerformAccountCreation: +got -want:\n%s", diff) + } + if diff := cmp.Diff(tst.WantLoginTokenCreation, tst.UserAPI.gotLoginTokenCreation); diff != "" { + t.Errorf("PerformLoginTokenCreation: +got -want:\n%s", diff) + } + if diff := cmp.Diff(tst.WantSaveSSOAssociation, tst.UserAPI.gotSaveSSOAssociation); diff != "" { + t.Errorf("PerformSaveSSOAssociation: +got -want:\n%s", diff) + } + if diff := cmp.Diff(tst.WantQueryLocalpart, tst.UserAPI.gotQueryLocalpart); diff != "" { + t.Errorf("QueryLocalpartForSSO: +got -want:\n%s", diff) + } + }) + } +} + +func TestSSOCallbackError(t *testing.T) { + nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue")) + goodReq := http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/callback", + RawQuery: url.Values{ + "provider": []string{"aprovider"}, + }.Encode(), + }, + Header: http.Header{ + "Cookie": []string{(&http.Cookie{ + Name: "sso_nonce", + Value: nonce, + }).String()}, + }, + } + goodAuth := fakeSSOAuthenticator{ + callbackResult: sso.CallbackResult{ + Identifier: &sso.UserIdentifier{ + Namespace: "anamespace", + Issuer: "anissuer", + Subject: "asubject", + }, + }, + } + errMocked := errors.New("mocked error") + + tsts := []struct { + Name string + Req http.Request + UserAPI fakeUserAPIForSSO + Auth fakeSSOAuthenticator + Config config.SSO + + WantCode int + }{ + { + Name: "missingProvider", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/callback", + }, + Header: http.Header{ + "Cookie": []string{(&http.Cookie{ + Name: "sso_nonce", + Value: nonce, + }).String()}, + }, + }, + WantCode: http.StatusBadRequest, + }, + { + Name: "missingCookie", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/callback", + RawQuery: url.Values{ + "provider": []string{"aprovider"}, + }.Encode(), + }, + }, + WantCode: http.StatusBadRequest, + }, + { + Name: "malformedCookie", + Req: http.Request{ + Host: "matrix.example.com", + URL: &url.URL{ + Path: "/_matrix/v4/login/sso/callback", + RawQuery: url.Values{ + "provider": []string{"aprovider"}, + }.Encode(), + }, + Header: http.Header{ + "Cookie": []string{(&http.Cookie{ + Name: "sso_nonce", + Value: "badvalue", + }).String()}, + }, + }, + WantCode: http.StatusBadRequest, + }, + { + Name: "failedProcessCallback", + Req: goodReq, + Auth: fakeSSOAuthenticator{ + callbackErr: errMocked, + }, + WantCode: http.StatusInternalServerError, + }, + { + Name: "failedQueryLocalpartForSSO", + Req: goodReq, + UserAPI: fakeUserAPIForSSO{ + localpartErr: errMocked, + }, + Auth: goodAuth, + WantCode: http.StatusUnauthorized, + }, + { + Name: "failedQueryNumericLocalpart", + Req: goodReq, + UserAPI: fakeUserAPIForSSO{ + numericLocalpartErr: errMocked, + }, + Auth: goodAuth, + WantCode: http.StatusInternalServerError, + }, + { + Name: "failedAccountCreation", + Req: goodReq, + UserAPI: fakeUserAPIForSSO{ + accountCreationErr: errMocked, + }, + Auth: goodAuth, + WantCode: http.StatusInternalServerError, + }, + { + Name: "failedSaveSSOAssociation", + Req: goodReq, + UserAPI: fakeUserAPIForSSO{ + saveSSOAssociationErr: errMocked, + }, + Auth: goodAuth, + WantCode: http.StatusInternalServerError, + }, + { + Name: "failedPerformLoginTokenCreation", + Req: goodReq, + UserAPI: fakeUserAPIForSSO{ + localpart: "alocalpart", + tokenTokenCreationErr: errMocked, + }, + Auth: goodAuth, + WantCode: http.StatusInternalServerError, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername") + + if got.Code != tst.WantCode { + t.Log(got) + t.Errorf("SSOCallback Code: got %v, want %v", got.Code, tst.WantCode) + } + }) + } +} + +type fakeSSOAuthenticator struct { + callbackResult sso.CallbackResult + callbackErr error +} + +func (auth *fakeSSOAuthenticator) AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) { + return (&url.URL{ + Scheme: "http", + Host: "auth.example.com", + Path: "/authorize", + }).ResolveReference(&url.URL{ + RawQuery: url.Values{ + "callbackURL": []string{callbackURL}, + "nonce": []string{nonce}, + "providerID": []string{providerID}, + }.Encode(), + }).String(), nil +} + +func (auth *fakeSSOAuthenticator) ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error) { + return &auth.callbackResult, auth.callbackErr +} + +type fakeUserAPIForSSO struct { + userAPIForSSO + + accountCreationErr error + tokenTokenCreationErr error + saveSSOAssociationErr error + localpart string + localpartErr error + numericLocalpartErr error + + gotAccountCreation []*uapi.PerformAccountCreationRequest + gotLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest + gotSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest + gotQueryLocalpart []*uapi.QueryLocalpartForSSORequest +} + +func (userAPI *fakeUserAPIForSSO) PerformAccountCreation(ctx context.Context, req *uapi.PerformAccountCreationRequest, res *uapi.PerformAccountCreationResponse) error { + userAPI.gotAccountCreation = append(userAPI.gotAccountCreation, req) + return userAPI.accountCreationErr +} + +func (userAPI *fakeUserAPIForSSO) PerformLoginTokenCreation(ctx context.Context, req *uapi.PerformLoginTokenCreationRequest, res *uapi.PerformLoginTokenCreationResponse) error { + userAPI.gotLoginTokenCreation = append(userAPI.gotLoginTokenCreation, req) + res.Metadata = uapi.LoginTokenMetadata{ + Token: "atoken", + } + return userAPI.tokenTokenCreationErr +} + +func (userAPI *fakeUserAPIForSSO) PerformSaveSSOAssociation(ctx context.Context, req *uapi.PerformSaveSSOAssociationRequest, res *struct{}) error { + userAPI.gotSaveSSOAssociation = append(userAPI.gotSaveSSOAssociation, req) + return userAPI.saveSSOAssociationErr +} + +func (userAPI *fakeUserAPIForSSO) QueryLocalpartForSSO(ctx context.Context, req *uapi.QueryLocalpartForSSORequest, res *uapi.QueryLocalpartForSSOResponse) error { + userAPI.gotQueryLocalpart = append(userAPI.gotQueryLocalpart, req) + res.Localpart = userAPI.localpart + return userAPI.localpartErr +} + +func (userAPI *fakeUserAPIForSSO) QueryNumericLocalpart(ctx context.Context, res *uapi.QueryNumericLocalpartResponse) error { + res.ID = 12345 + return userAPI.numericLocalpartErr +} diff --git a/setup/config/config_test.go b/setup/config/config_test.go index cbc57ad18..565e3c729 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -62,7 +62,7 @@ global: local_part: "_server" display_name: "Server alerts" avatar: "" - room_name: "Server Alerts" + room_name: "Server Alerts" app_service_api: internal_api: listen: http://localhost:7777 @@ -86,6 +86,22 @@ client_api: recaptcha_private_key: "" recaptcha_bypass_secret: "" recaptcha_siteverify_api: "" + login: + sso: + enabled: true + callback_url: http://example.com:8071/_matrix/v3/login/sso/callback + default_provider: github + providers: + - brand: github + - id: custom + name: "Custom Provider" + icon: "mxc://example.com/abc123" + type: oidc + oauth2: + client_id: aclientid + client_secret: aclientsecret + oidc: + discovery_url: http://auth.example.com/.well-known/openid-configuration turn: turn_user_lifetime: "" turn_uris: [] diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 5bee880d3..1c381ffee 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -429,6 +429,41 @@ func Test_Pusher(t *testing.T) { }) } +func Test_SSO(t *testing.T) { + alice := test.NewUser(t) + aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + t.Log("Create SSO association") + + ns := util.RandomString(8) + issuer := util.RandomString(8) + subject := util.RandomString(8) + err := db.SaveSSOAssociation(ctx, ns, issuer, subject, aliceLocalpart) + assert.NoError(t, err, "unable to save SSO association") + + t.Log("Retrieve localpart for association") + + gotLocalpart, err := db.GetLocalpartForSSO(ctx, ns, issuer, subject) + assert.Equal(t, aliceLocalpart, gotLocalpart) + + t.Log("Remove SSO association") + + err = db.RemoveSSOAssociation(ctx, ns, issuer, subject) + assert.NoError(t, err, "unexpected error") + + t.Log("Verify the SSO association was removed") + + gotLocalpart, err = db.GetLocalpartForSSO(ctx, ns, issuer, subject) + assert.NoError(t, err, "unable to get localpart for SSO subject") + assert.Equal(t, "", gotLocalpart) + }) +} + func Test_ThreePID(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)