dendrite/clientapi/routing/sso_test.go
Tommie Gannert 210ab1eef6 Add SSO tests.
Renames cookie oidc_nonce to sso_nonce, since it's defined in a file
that doesn't know about OIDC specifically.
2022-06-08 09:14:11 +02:00

532 lines
16 KiB
Go

package routing
import (
"context"
"encoding/base64"
"errors"
"net/http"
"net/url"
"regexp"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/matrix-org/dendrite/clientapi/auth/sso"
"github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
// TestSSORedirect checks the successful paths of SSORedirect: for a request
// carrying a valid redirectUrl it expects an HTTP 302 (http.StatusFound)
// response whose Location header matches the URL produced by the fake
// authenticator and whose Set-Cookie header sets the sso_nonce cookie.
func TestSSORedirect(t *testing.T) {
	tsts := []struct {
		Name   string
		Req    http.Request
		IDPID  string
		Auth   fakeSSOAuthenticator
		Config config.SSO

		// Regular expressions the Location and Set-Cookie response
		// headers must match.
		WantLocationRE  string
		WantSetCookieRE string
	}{
		{
			// No provider ID given: the callback URL and providerID
			// parameters are expected to carry an empty provider.
			Name: "redirectDefault",
			Req: http.Request{
				Host: "matrix.example.com",
				URL: &url.URL{
					Path: "/_matrix/v4/login/sso/redirect",
					RawQuery: url.Values{
						"redirectUrl": []string{"http://example.com/continue"},
					}.Encode(),
				},
			},
			WantLocationRE:  `http://auth.example.com/authorize\?callbackURL=http%3A%2F%2Fmatrix.example.com%2F_matrix%2Fv4%2Flogin%2Fsso%2Fcallback%3Fprovider%3D&nonce=.+&providerID=`,
			WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
		},
		{
			// An explicit provider ID must show up both in the callback
			// URL and in the providerID parameter.
			Name: "redirectExplicitProvider",
			Req: http.Request{
				Host: "matrix.example.com",
				URL: &url.URL{
					Path: "/_matrix/v4/login/sso/redirect",
					RawQuery: url.Values{
						"redirectUrl": []string{"http://example.com/continue"},
					}.Encode(),
				},
			},
			IDPID:           "someprovider",
			WantLocationRE:  `http://auth.example.com/authorize\?callbackURL=http.*%3Fprovider%3Dsomeprovider&nonce=.+&providerID=someprovider`,
			WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
		},
	}
	for _, tst := range tsts {
		// The subtests run sequentially (no t.Parallel), so handing out
		// pointers into the loop variable is safe here.
		t.Run(tst.Name, func(t *testing.T) {
			got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config)

			if want := http.StatusFound; got.Code != want {
				t.Errorf("SSORedirect Code: got %v, want %v", got.Code, want)
			}

			if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
				// A compile error here concerns the Location pattern, so
				// name that field rather than WantSetCookieRE.
				t.Fatalf("WantLocationRE failed: %v", err)
			} else if !m {
				t.Errorf("SSORedirect Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
			}

			if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
				t.Fatalf("WantSetCookieRE failed: %v", err)
			} else if !m {
				t.Errorf("SSORedirect Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
			}
		})
	}
}
// TestSSORedirectError feeds SSORedirect requests whose redirectUrl query
// parameter is absent or not acceptable and checks the returned status code.
func TestSSORedirectError(t *testing.T) {
	// redirectReq assembles a redirect-endpoint request with the given
	// query parameters.
	redirectReq := func(query url.Values) http.Request {
		target := &url.URL{
			Path:     "/_matrix/v4/login/sso/redirect",
			RawQuery: query.Encode(),
		}
		return http.Request{Host: "matrix.example.com", URL: target}
	}

	type testCase struct {
		Name   string
		Req    http.Request
		IDPID  string
		Auth   fakeSSOAuthenticator
		Config config.SSO

		WantCode int
	}
	cases := []testCase{
		{
			// No redirectUrl parameter at all.
			Name:     "missingRedirectUrl",
			Req:      redirectReq(url.Values{}),
			WantCode: http.StatusBadRequest,
		},
		{
			// redirectUrl is a bare path rather than an absolute URL.
			Name:     "invalidRedirectUrl",
			Req:      redirectReq(url.Values{"redirectUrl": []string{"/continue"}}),
			WantCode: http.StatusBadRequest,
		},
	}

	for _, tc := range cases {
		t.Run(tc.Name, func(t *testing.T) {
			resp := SSORedirect(&tc.Req, tc.IDPID, &tc.Auth, &tc.Config)
			if resp.Code == tc.WantCode {
				return
			}
			t.Errorf("SSORedirect Code: got %v, want %v", resp.Code, tc.WantCode)
		})
	}
}
// TestSSOCallback checks the successful paths of SSOCallback. Every case
// expects an HTTP 302 (http.StatusFound) response and additionally verifies
// the Location and Set-Cookie headers as well as the exact sequence of
// requests the handler issued against the fake user API.
func TestSSOCallback(t *testing.T) {
	nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))

	// newReq builds a fresh callback request for provider "aprovider" that
	// carries the sso_nonce cookie. All cases use the same request; a new
	// value is built per case so no URL or header map is shared.
	newReq := func() http.Request {
		return http.Request{
			Host: "matrix.example.com",
			URL: &url.URL{
				Path: "/_matrix/v4/login/sso/callback",
				RawQuery: url.Values{
					"provider": []string{"aprovider"},
				}.Encode(),
			},
			Header: http.Header{
				"Cookie": []string{(&http.Cookie{
					Name:  "sso_nonce",
					Value: nonce,
				}).String()},
			},
		}
	}

	tsts := []struct {
		Name    string
		Req     http.Request
		UserAPI fakeUserAPIForSSO
		Auth    fakeSSOAuthenticator
		Config  config.SSO

		// Regular expressions the Location and Set-Cookie response
		// headers must match.
		WantLocationRE  string
		WantSetCookieRE string

		// Requests the fake user API is expected to have recorded.
		WantAccountCreation    []*uapi.PerformAccountCreationRequest
		WantLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
		WantSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
		WantQueryLocalpart     []*uapi.QueryLocalpartForSSORequest
	}{
		{
			// The fake user API already knows a localpart for the
			// identifier: only a login token is expected to be created.
			Name: "logIn",
			Req:  newReq(),
			UserAPI: fakeUserAPIForSSO{
				localpart: "alocalpart",
			},
			Auth: fakeSSOAuthenticator{
				callbackResult: sso.CallbackResult{
					Identifier: &sso.UserIdentifier{
						Namespace: "anamespace",
						Issuer:    "anissuer",
						Subject:   "asubject",
					},
				},
			},
			WantLocationRE:  `http://matrix.example.com/continue\?loginToken=atoken`,
			WantSetCookieRE: "sso_nonce=;",

			WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@alocalpart:aservername"}}},
			WantQueryLocalpart:     []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
		},
		{
			// No known localpart, but the provider suggests a user ID:
			// an account with that localpart is expected to be created
			// and associated with the identifier.
			Name: "registerSuggested",
			Req:  newReq(),
			Auth: fakeSSOAuthenticator{
				callbackResult: sso.CallbackResult{
					Identifier: &sso.UserIdentifier{
						Namespace: "anamespace",
						Issuer:    "anissuer",
						Subject:   "asubject",
					},
					SuggestedUserID: "asuggestedid",
				},
			},
			WantLocationRE:  `http://matrix.example.com/continue\?loginToken=atoken`,
			WantSetCookieRE: "sso_nonce=;",

			WantAccountCreation:    []*uapi.PerformAccountCreationRequest{{Localpart: "asuggestedid", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
			WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@asuggestedid:aservername"}}},
			WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "asuggestedid"}},
			WantQueryLocalpart:     []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
		},
		{
			// No known localpart and no suggestion: the expected
			// localpart is the numeric ID 12345 that the fake user API
			// reports from QueryNumericLocalpart.
			Name: "registerNumeric",
			Req:  newReq(),
			Auth: fakeSSOAuthenticator{
				callbackResult: sso.CallbackResult{
					Identifier: &sso.UserIdentifier{
						Namespace: "anamespace",
						Issuer:    "anissuer",
						Subject:   "asubject",
					},
				},
			},
			WantLocationRE:  `http://matrix.example.com/continue\?loginToken=atoken`,
			WantSetCookieRE: "sso_nonce=;",

			WantAccountCreation:    []*uapi.PerformAccountCreationRequest{{Localpart: "12345", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
			WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@12345:aservername"}}},
			WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "12345"}},
			WantQueryLocalpart:     []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
		},
		{
			// The callback result carries no identifier, only a
			// redirect URL: the response is expected to redirect there,
			// set no cookie and leave the user API untouched.
			Name: "noIdentifierRedirectURL",
			Req:  newReq(),
			Auth: fakeSSOAuthenticator{
				callbackResult: sso.CallbackResult{
					RedirectURL: "http://auth.example.com/notdone",
				},
			},
			WantLocationRE:  `http://auth.example.com/notdone`,
			WantSetCookieRE: "^$",
		},
	}
	for _, tst := range tsts {
		// The subtests run sequentially (no t.Parallel), so handing out
		// pointers into the loop variable is safe here.
		t.Run(tst.Name, func(t *testing.T) {
			got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername")

			if want := http.StatusFound; got.Code != want {
				t.Log(got)
				t.Errorf("SSOCallback Code: got %v, want %v", got.Code, want)
			}

			if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
				// A compile error here concerns the Location pattern, so
				// name that field rather than WantSetCookieRE.
				t.Fatalf("WantLocationRE failed: %v", err)
			} else if !m {
				t.Errorf("SSOCallback Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
			}

			if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
				t.Fatalf("WantSetCookieRE failed: %v", err)
			} else if !m {
				t.Errorf("SSOCallback Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
			}

			// cmp.Diff(want, got) marks want-only lines with "-" and
			// got-only lines with "+".
			if diff := cmp.Diff(tst.WantAccountCreation, tst.UserAPI.gotAccountCreation); diff != "" {
				t.Errorf("PerformAccountCreation: +got -want:\n%s", diff)
			}
			if diff := cmp.Diff(tst.WantLoginTokenCreation, tst.UserAPI.gotLoginTokenCreation); diff != "" {
				t.Errorf("PerformLoginTokenCreation: +got -want:\n%s", diff)
			}
			if diff := cmp.Diff(tst.WantSaveSSOAssociation, tst.UserAPI.gotSaveSSOAssociation); diff != "" {
				t.Errorf("PerformSaveSSOAssociation: +got -want:\n%s", diff)
			}
			if diff := cmp.Diff(tst.WantQueryLocalpart, tst.UserAPI.gotQueryLocalpart); diff != "" {
				t.Errorf("QueryLocalpartForSSO: +got -want:\n%s", diff)
			}
		})
	}
}
// TestSSOCallbackError drives SSOCallback through its failure paths, either
// with a defective request or with a fake dependency that returns an error,
// and checks the status code of the response.
func TestSSOCallbackError(t *testing.T) {
	const callbackPath = "/_matrix/v4/login/sso/callback"

	validNonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))
	providerQuery := url.Values{
		"provider": []string{"aprovider"},
	}.Encode()

	// nonceCookie returns request headers holding an sso_nonce cookie with
	// the given value.
	nonceCookie := func(value string) http.Header {
		cookie := http.Cookie{Name: "sso_nonce", Value: value}
		return http.Header{"Cookie": []string{cookie.String()}}
	}

	// validReq and validAuth are the well-formed inputs reused by every
	// case that injects a failure into a dependency instead of the request.
	validReq := http.Request{
		Host:   "matrix.example.com",
		URL:    &url.URL{Path: callbackPath, RawQuery: providerQuery},
		Header: nonceCookie(validNonce),
	}
	validAuth := fakeSSOAuthenticator{
		callbackResult: sso.CallbackResult{
			Identifier: &sso.UserIdentifier{
				Namespace: "anamespace",
				Issuer:    "anissuer",
				Subject:   "asubject",
			},
		},
	}
	errMocked := errors.New("mocked error")

	type testCase struct {
		Name    string
		Req     http.Request
		UserAPI fakeUserAPIForSSO
		Auth    fakeSSOAuthenticator
		Config  config.SSO

		WantCode int
	}
	cases := []testCase{
		// Defective requests.
		{
			Name: "missingProvider",
			Req: http.Request{
				Host:   "matrix.example.com",
				URL:    &url.URL{Path: callbackPath},
				Header: nonceCookie(validNonce),
			},
			WantCode: http.StatusBadRequest,
		},
		{
			Name: "missingCookie",
			Req: http.Request{
				Host: "matrix.example.com",
				URL:  &url.URL{Path: callbackPath, RawQuery: providerQuery},
			},
			WantCode: http.StatusBadRequest,
		},
		{
			Name: "malformedCookie",
			Req: http.Request{
				Host:   "matrix.example.com",
				URL:    &url.URL{Path: callbackPath, RawQuery: providerQuery},
				Header: nonceCookie("badvalue"),
			},
			WantCode: http.StatusBadRequest,
		},

		// Failing dependencies.
		{
			Name:     "failedProcessCallback",
			Req:      validReq,
			Auth:     fakeSSOAuthenticator{callbackErr: errMocked},
			WantCode: http.StatusInternalServerError,
		},
		{
			Name:     "failedQueryLocalpartForSSO",
			Req:      validReq,
			UserAPI:  fakeUserAPIForSSO{localpartErr: errMocked},
			Auth:     validAuth,
			WantCode: http.StatusUnauthorized,
		},
		{
			Name:     "failedQueryNumericLocalpart",
			Req:      validReq,
			UserAPI:  fakeUserAPIForSSO{numericLocalpartErr: errMocked},
			Auth:     validAuth,
			WantCode: http.StatusInternalServerError,
		},
		{
			Name:     "failedAccountCreation",
			Req:      validReq,
			UserAPI:  fakeUserAPIForSSO{accountCreationErr: errMocked},
			Auth:     validAuth,
			WantCode: http.StatusInternalServerError,
		},
		{
			Name:     "failedSaveSSOAssociation",
			Req:      validReq,
			UserAPI:  fakeUserAPIForSSO{saveSSOAssociationErr: errMocked},
			Auth:     validAuth,
			WantCode: http.StatusInternalServerError,
		},
		{
			Name: "failedPerformLoginTokenCreation",
			Req:  validReq,
			UserAPI: fakeUserAPIForSSO{
				localpart:             "alocalpart",
				tokenTokenCreationErr: errMocked,
			},
			Auth:     validAuth,
			WantCode: http.StatusInternalServerError,
		},
	}

	for _, tc := range cases {
		t.Run(tc.Name, func(t *testing.T) {
			resp := SSOCallback(&tc.Req, &tc.UserAPI, &tc.Auth, &tc.Config, "aservername")
			if resp.Code == tc.WantCode {
				return
			}
			t.Log(resp)
			t.Errorf("SSOCallback Code: got %v, want %v", resp.Code, tc.WantCode)
		})
	}
}
// fakeSSOAuthenticator is a test double for the SSO authenticator handed to
// SSORedirect and SSOCallback. Its zero value is usable.
type fakeSSOAuthenticator struct {
	// callbackResult is the value ProcessCallback returns a pointer to.
	callbackResult sso.CallbackResult

	// callbackErr is the error ProcessCallback returns; nil means success.
	callbackErr error
}
// AuthorizationURL returns a fixed http://auth.example.com/authorize URL
// whose query string echoes the callbackURL, nonce and providerID arguments,
// so tests can check what the handler passed in. It never fails and ignores
// ctx.
func (auth *fakeSSOAuthenticator) AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) {
	query := make(url.Values, 3)
	query.Set("callbackURL", callbackURL)
	query.Set("nonce", nonce)
	query.Set("providerID", providerID)

	target := url.URL{
		Scheme:   "http",
		Host:     "auth.example.com",
		Path:     "/authorize",
		RawQuery: query.Encode(), // Encode emits the keys in sorted order.
	}
	return target.String(), nil
}
// ProcessCallback ignores all of its arguments and hands back the canned
// outcome configured on the fake: a pointer to its callbackResult field
// together with callbackErr.
func (auth *fakeSSOAuthenticator) ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error) {
	result := &auth.callbackResult
	return result, auth.callbackErr
}
// fakeUserAPIForSSO is a test double for the user API consumed by the SSO
// handlers. It answers with canned values, can be told to fail individual
// calls, and records the requests it receives so tests can inspect them.
type fakeUserAPIForSSO struct {
	// NOTE(review): userAPIForSSO is declared elsewhere in the package;
	// presumably it is the interface the handlers accept and is embedded so
	// that only the methods exercised by the tests need overriding - confirm.
	userAPIForSSO

	// localpart is the value QueryLocalpartForSSO reports.
	localpart string

	// Errors returned by the corresponding methods; nil means success.
	accountCreationErr    error
	tokenTokenCreationErr error // returned by PerformLoginTokenCreation
	saveSSOAssociationErr error
	localpartErr          error // returned by QueryLocalpartForSSO
	numericLocalpartErr   error // returned by QueryNumericLocalpart

	// Requests recorded by the corresponding methods, in call order.
	gotAccountCreation    []*uapi.PerformAccountCreationRequest
	gotLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
	gotSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
	gotQueryLocalpart     []*uapi.QueryLocalpartForSSORequest
}
// PerformAccountCreation records req for later inspection and returns the
// configured accountCreationErr. The response res is left untouched.
func (userAPI *fakeUserAPIForSSO) PerformAccountCreation(ctx context.Context, req *uapi.PerformAccountCreationRequest, res *uapi.PerformAccountCreationResponse) error {
	recorded := append(userAPI.gotAccountCreation, req)
	userAPI.gotAccountCreation = recorded

	return userAPI.accountCreationErr
}
// PerformLoginTokenCreation records req, overwrites res.Metadata with
// metadata whose token is "atoken", and returns the configured
// tokenTokenCreationErr. The metadata is filled in even when an error is
// returned.
func (userAPI *fakeUserAPIForSSO) PerformLoginTokenCreation(ctx context.Context, req *uapi.PerformLoginTokenCreationRequest, res *uapi.PerformLoginTokenCreationResponse) error {
	recorded := append(userAPI.gotLoginTokenCreation, req)
	userAPI.gotLoginTokenCreation = recorded

	metadata := uapi.LoginTokenMetadata{Token: "atoken"}
	res.Metadata = metadata

	return userAPI.tokenTokenCreationErr
}
// PerformSaveSSOAssociation records req for later inspection and returns the
// configured saveSSOAssociationErr. The response res is left untouched.
func (userAPI *fakeUserAPIForSSO) PerformSaveSSOAssociation(ctx context.Context, req *uapi.PerformSaveSSOAssociationRequest, res *struct{}) error {
	recorded := append(userAPI.gotSaveSSOAssociation, req)
	userAPI.gotSaveSSOAssociation = recorded

	return userAPI.saveSSOAssociationErr
}
// QueryLocalpartForSSO records req, copies the configured localpart into
// res.Localpart, and returns the configured localpartErr. The localpart is
// written even when an error is returned.
func (userAPI *fakeUserAPIForSSO) QueryLocalpartForSSO(ctx context.Context, req *uapi.QueryLocalpartForSSORequest, res *uapi.QueryLocalpartForSSOResponse) error {
	recorded := append(userAPI.gotQueryLocalpart, req)
	userAPI.gotQueryLocalpart = recorded

	res.Localpart = userAPI.localpart

	return userAPI.localpartErr
}
// QueryNumericLocalpart always reports the ID 12345 in res and returns the
// configured numericLocalpartErr. The ID is written even when an error is
// returned.
func (userAPI *fakeUserAPIForSSO) QueryNumericLocalpart(ctx context.Context, res *uapi.QueryNumericLocalpartResponse) error {
	failure := userAPI.numericLocalpartErr
	res.ID = 12345

	return failure
}