mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-09 07:03:10 -06:00
Renames cookie oidc_nonce to sso_nonce, since it's defined in a file that doesn't know about OIDC specifically.
532 lines
16 KiB
Go
532 lines
16 KiB
Go
package routing
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"testing"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/matrix-org/dendrite/clientapi/auth/sso"
|
|
"github.com/matrix-org/dendrite/setup/config"
|
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
|
)
|
|
|
|
func TestSSORedirect(t *testing.T) {
|
|
tsts := []struct {
|
|
Name string
|
|
Req http.Request
|
|
IDPID string
|
|
Auth fakeSSOAuthenticator
|
|
Config config.SSO
|
|
|
|
WantLocationRE string
|
|
WantSetCookieRE string
|
|
}{
|
|
{
|
|
Name: "redirectDefault",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/redirect",
|
|
RawQuery: url.Values{
|
|
"redirectUrl": []string{"http://example.com/continue"},
|
|
}.Encode(),
|
|
},
|
|
},
|
|
WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http%3A%2F%2Fmatrix.example.com%2F_matrix%2Fv4%2Flogin%2Fsso%2Fcallback%3Fprovider%3D&nonce=.+&providerID=`,
|
|
WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
|
|
},
|
|
{
|
|
Name: "redirectExplicitProvider",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/redirect",
|
|
RawQuery: url.Values{
|
|
"redirectUrl": []string{"http://example.com/continue"},
|
|
}.Encode(),
|
|
},
|
|
},
|
|
IDPID: "someprovider",
|
|
WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http.*%3Fprovider%3Dsomeprovider&nonce=.+&providerID=someprovider`,
|
|
WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
|
|
},
|
|
}
|
|
for _, tst := range tsts {
|
|
t.Run(tst.Name, func(t *testing.T) {
|
|
got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config)
|
|
|
|
if want := http.StatusFound; got.Code != want {
|
|
t.Errorf("SSORedirect Code: got %v, want %v", got.Code, want)
|
|
}
|
|
|
|
if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
|
|
t.Fatalf("WantSetCookieRE failed: %v", err)
|
|
} else if !m {
|
|
t.Errorf("SSORedirect Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
|
|
}
|
|
|
|
if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
|
|
t.Fatalf("WantSetCookieRE failed: %v", err)
|
|
} else if !m {
|
|
t.Errorf("SSORedirect Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSSORedirectError(t *testing.T) {
|
|
tsts := []struct {
|
|
Name string
|
|
Req http.Request
|
|
IDPID string
|
|
Auth fakeSSOAuthenticator
|
|
Config config.SSO
|
|
|
|
WantCode int
|
|
}{
|
|
{
|
|
Name: "missingRedirectUrl",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/redirect",
|
|
RawQuery: url.Values{}.Encode(),
|
|
},
|
|
},
|
|
WantCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
Name: "invalidRedirectUrl",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/redirect",
|
|
RawQuery: url.Values{
|
|
"redirectUrl": []string{"/continue"},
|
|
}.Encode(),
|
|
},
|
|
},
|
|
WantCode: http.StatusBadRequest,
|
|
},
|
|
}
|
|
for _, tst := range tsts {
|
|
t.Run(tst.Name, func(t *testing.T) {
|
|
got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config)
|
|
|
|
if got.Code != tst.WantCode {
|
|
t.Errorf("SSORedirect Code: got %v, want %v", got.Code, tst.WantCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSSOCallback(t *testing.T) {
|
|
nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))
|
|
|
|
tsts := []struct {
|
|
Name string
|
|
Req http.Request
|
|
UserAPI fakeUserAPIForSSO
|
|
Auth fakeSSOAuthenticator
|
|
Config config.SSO
|
|
|
|
WantLocationRE string
|
|
WantSetCookieRE string
|
|
|
|
WantAccountCreation []*uapi.PerformAccountCreationRequest
|
|
WantLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
|
|
WantSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
|
|
WantQueryLocalpart []*uapi.QueryLocalpartForSSORequest
|
|
}{
|
|
{
|
|
Name: "logIn",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/callback",
|
|
RawQuery: url.Values{
|
|
"provider": []string{"aprovider"},
|
|
}.Encode(),
|
|
},
|
|
Header: http.Header{
|
|
"Cookie": []string{(&http.Cookie{
|
|
Name: "sso_nonce",
|
|
Value: nonce,
|
|
}).String()},
|
|
},
|
|
},
|
|
UserAPI: fakeUserAPIForSSO{
|
|
localpart: "alocalpart",
|
|
},
|
|
Auth: fakeSSOAuthenticator{
|
|
callbackResult: sso.CallbackResult{
|
|
Identifier: &sso.UserIdentifier{
|
|
Namespace: "anamespace",
|
|
Issuer: "anissuer",
|
|
Subject: "asubject",
|
|
},
|
|
},
|
|
},
|
|
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
|
|
WantSetCookieRE: "sso_nonce=;",
|
|
|
|
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@alocalpart:aservername"}}},
|
|
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
|
|
},
|
|
{
|
|
Name: "registerSuggested",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/callback",
|
|
RawQuery: url.Values{
|
|
"provider": []string{"aprovider"},
|
|
}.Encode(),
|
|
},
|
|
Header: http.Header{
|
|
"Cookie": []string{(&http.Cookie{
|
|
Name: "sso_nonce",
|
|
Value: nonce,
|
|
}).String()},
|
|
},
|
|
},
|
|
Auth: fakeSSOAuthenticator{
|
|
callbackResult: sso.CallbackResult{
|
|
Identifier: &sso.UserIdentifier{
|
|
Namespace: "anamespace",
|
|
Issuer: "anissuer",
|
|
Subject: "asubject",
|
|
},
|
|
SuggestedUserID: "asuggestedid",
|
|
},
|
|
},
|
|
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
|
|
WantSetCookieRE: "sso_nonce=;",
|
|
|
|
WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "asuggestedid", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
|
|
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@asuggestedid:aservername"}}},
|
|
WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "asuggestedid"}},
|
|
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
|
|
},
|
|
{
|
|
Name: "registerNumeric",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/callback",
|
|
RawQuery: url.Values{
|
|
"provider": []string{"aprovider"},
|
|
}.Encode(),
|
|
},
|
|
Header: http.Header{
|
|
"Cookie": []string{(&http.Cookie{
|
|
Name: "sso_nonce",
|
|
Value: nonce,
|
|
}).String()},
|
|
},
|
|
},
|
|
Auth: fakeSSOAuthenticator{
|
|
callbackResult: sso.CallbackResult{
|
|
Identifier: &sso.UserIdentifier{
|
|
Namespace: "anamespace",
|
|
Issuer: "anissuer",
|
|
Subject: "asubject",
|
|
},
|
|
},
|
|
},
|
|
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
|
|
WantSetCookieRE: "sso_nonce=;",
|
|
|
|
WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "12345", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
|
|
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@12345:aservername"}}},
|
|
WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "12345"}},
|
|
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
|
|
},
|
|
{
|
|
Name: "noIdentifierRedirectURL",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/callback",
|
|
RawQuery: url.Values{
|
|
"provider": []string{"aprovider"},
|
|
}.Encode(),
|
|
},
|
|
Header: http.Header{
|
|
"Cookie": []string{(&http.Cookie{
|
|
Name: "sso_nonce",
|
|
Value: nonce,
|
|
}).String()},
|
|
},
|
|
},
|
|
Auth: fakeSSOAuthenticator{
|
|
callbackResult: sso.CallbackResult{
|
|
RedirectURL: "http://auth.example.com/notdone",
|
|
},
|
|
},
|
|
WantLocationRE: `http://auth.example.com/notdone`,
|
|
WantSetCookieRE: "^$",
|
|
},
|
|
}
|
|
for _, tst := range tsts {
|
|
t.Run(tst.Name, func(t *testing.T) {
|
|
got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername")
|
|
|
|
if want := http.StatusFound; got.Code != want {
|
|
t.Log(got)
|
|
t.Errorf("SSOCallback Code: got %v, want %v", got.Code, want)
|
|
}
|
|
|
|
if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
|
|
t.Fatalf("WantSetCookieRE failed: %v", err)
|
|
} else if !m {
|
|
t.Errorf("SSOCallback Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
|
|
}
|
|
|
|
if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
|
|
t.Fatalf("WantSetCookieRE failed: %v", err)
|
|
} else if !m {
|
|
t.Errorf("SSOCallback Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
|
|
}
|
|
|
|
if diff := cmp.Diff(tst.WantAccountCreation, tst.UserAPI.gotAccountCreation); diff != "" {
|
|
t.Errorf("PerformAccountCreation: +got -want:\n%s", diff)
|
|
}
|
|
if diff := cmp.Diff(tst.WantLoginTokenCreation, tst.UserAPI.gotLoginTokenCreation); diff != "" {
|
|
t.Errorf("PerformLoginTokenCreation: +got -want:\n%s", diff)
|
|
}
|
|
if diff := cmp.Diff(tst.WantSaveSSOAssociation, tst.UserAPI.gotSaveSSOAssociation); diff != "" {
|
|
t.Errorf("PerformSaveSSOAssociation: +got -want:\n%s", diff)
|
|
}
|
|
if diff := cmp.Diff(tst.WantQueryLocalpart, tst.UserAPI.gotQueryLocalpart); diff != "" {
|
|
t.Errorf("QueryLocalpartForSSO: +got -want:\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSSOCallbackError(t *testing.T) {
|
|
nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))
|
|
goodReq := http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/callback",
|
|
RawQuery: url.Values{
|
|
"provider": []string{"aprovider"},
|
|
}.Encode(),
|
|
},
|
|
Header: http.Header{
|
|
"Cookie": []string{(&http.Cookie{
|
|
Name: "sso_nonce",
|
|
Value: nonce,
|
|
}).String()},
|
|
},
|
|
}
|
|
goodAuth := fakeSSOAuthenticator{
|
|
callbackResult: sso.CallbackResult{
|
|
Identifier: &sso.UserIdentifier{
|
|
Namespace: "anamespace",
|
|
Issuer: "anissuer",
|
|
Subject: "asubject",
|
|
},
|
|
},
|
|
}
|
|
errMocked := errors.New("mocked error")
|
|
|
|
tsts := []struct {
|
|
Name string
|
|
Req http.Request
|
|
UserAPI fakeUserAPIForSSO
|
|
Auth fakeSSOAuthenticator
|
|
Config config.SSO
|
|
|
|
WantCode int
|
|
}{
|
|
{
|
|
Name: "missingProvider",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/callback",
|
|
},
|
|
Header: http.Header{
|
|
"Cookie": []string{(&http.Cookie{
|
|
Name: "sso_nonce",
|
|
Value: nonce,
|
|
}).String()},
|
|
},
|
|
},
|
|
WantCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
Name: "missingCookie",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/callback",
|
|
RawQuery: url.Values{
|
|
"provider": []string{"aprovider"},
|
|
}.Encode(),
|
|
},
|
|
},
|
|
WantCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
Name: "malformedCookie",
|
|
Req: http.Request{
|
|
Host: "matrix.example.com",
|
|
URL: &url.URL{
|
|
Path: "/_matrix/v4/login/sso/callback",
|
|
RawQuery: url.Values{
|
|
"provider": []string{"aprovider"},
|
|
}.Encode(),
|
|
},
|
|
Header: http.Header{
|
|
"Cookie": []string{(&http.Cookie{
|
|
Name: "sso_nonce",
|
|
Value: "badvalue",
|
|
}).String()},
|
|
},
|
|
},
|
|
WantCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
Name: "failedProcessCallback",
|
|
Req: goodReq,
|
|
Auth: fakeSSOAuthenticator{
|
|
callbackErr: errMocked,
|
|
},
|
|
WantCode: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
Name: "failedQueryLocalpartForSSO",
|
|
Req: goodReq,
|
|
UserAPI: fakeUserAPIForSSO{
|
|
localpartErr: errMocked,
|
|
},
|
|
Auth: goodAuth,
|
|
WantCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
Name: "failedQueryNumericLocalpart",
|
|
Req: goodReq,
|
|
UserAPI: fakeUserAPIForSSO{
|
|
numericLocalpartErr: errMocked,
|
|
},
|
|
Auth: goodAuth,
|
|
WantCode: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
Name: "failedAccountCreation",
|
|
Req: goodReq,
|
|
UserAPI: fakeUserAPIForSSO{
|
|
accountCreationErr: errMocked,
|
|
},
|
|
Auth: goodAuth,
|
|
WantCode: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
Name: "failedSaveSSOAssociation",
|
|
Req: goodReq,
|
|
UserAPI: fakeUserAPIForSSO{
|
|
saveSSOAssociationErr: errMocked,
|
|
},
|
|
Auth: goodAuth,
|
|
WantCode: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
Name: "failedPerformLoginTokenCreation",
|
|
Req: goodReq,
|
|
UserAPI: fakeUserAPIForSSO{
|
|
localpart: "alocalpart",
|
|
tokenTokenCreationErr: errMocked,
|
|
},
|
|
Auth: goodAuth,
|
|
WantCode: http.StatusInternalServerError,
|
|
},
|
|
}
|
|
for _, tst := range tsts {
|
|
t.Run(tst.Name, func(t *testing.T) {
|
|
got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername")
|
|
|
|
if got.Code != tst.WantCode {
|
|
t.Log(got)
|
|
t.Errorf("SSOCallback Code: got %v, want %v", got.Code, tst.WantCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type fakeSSOAuthenticator struct {
|
|
callbackResult sso.CallbackResult
|
|
callbackErr error
|
|
}
|
|
|
|
func (auth *fakeSSOAuthenticator) AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) {
|
|
return (&url.URL{
|
|
Scheme: "http",
|
|
Host: "auth.example.com",
|
|
Path: "/authorize",
|
|
}).ResolveReference(&url.URL{
|
|
RawQuery: url.Values{
|
|
"callbackURL": []string{callbackURL},
|
|
"nonce": []string{nonce},
|
|
"providerID": []string{providerID},
|
|
}.Encode(),
|
|
}).String(), nil
|
|
}
|
|
|
|
func (auth *fakeSSOAuthenticator) ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error) {
|
|
return &auth.callbackResult, auth.callbackErr
|
|
}
|
|
|
|
type fakeUserAPIForSSO struct {
|
|
userAPIForSSO
|
|
|
|
accountCreationErr error
|
|
tokenTokenCreationErr error
|
|
saveSSOAssociationErr error
|
|
localpart string
|
|
localpartErr error
|
|
numericLocalpartErr error
|
|
|
|
gotAccountCreation []*uapi.PerformAccountCreationRequest
|
|
gotLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
|
|
gotSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
|
|
gotQueryLocalpart []*uapi.QueryLocalpartForSSORequest
|
|
}
|
|
|
|
func (userAPI *fakeUserAPIForSSO) PerformAccountCreation(ctx context.Context, req *uapi.PerformAccountCreationRequest, res *uapi.PerformAccountCreationResponse) error {
|
|
userAPI.gotAccountCreation = append(userAPI.gotAccountCreation, req)
|
|
return userAPI.accountCreationErr
|
|
}
|
|
|
|
func (userAPI *fakeUserAPIForSSO) PerformLoginTokenCreation(ctx context.Context, req *uapi.PerformLoginTokenCreationRequest, res *uapi.PerformLoginTokenCreationResponse) error {
|
|
userAPI.gotLoginTokenCreation = append(userAPI.gotLoginTokenCreation, req)
|
|
res.Metadata = uapi.LoginTokenMetadata{
|
|
Token: "atoken",
|
|
}
|
|
return userAPI.tokenTokenCreationErr
|
|
}
|
|
|
|
func (userAPI *fakeUserAPIForSSO) PerformSaveSSOAssociation(ctx context.Context, req *uapi.PerformSaveSSOAssociationRequest, res *struct{}) error {
|
|
userAPI.gotSaveSSOAssociation = append(userAPI.gotSaveSSOAssociation, req)
|
|
return userAPI.saveSSOAssociationErr
|
|
}
|
|
|
|
func (userAPI *fakeUserAPIForSSO) QueryLocalpartForSSO(ctx context.Context, req *uapi.QueryLocalpartForSSORequest, res *uapi.QueryLocalpartForSSOResponse) error {
|
|
userAPI.gotQueryLocalpart = append(userAPI.gotQueryLocalpart, req)
|
|
res.Localpart = userAPI.localpart
|
|
return userAPI.localpartErr
|
|
}
|
|
|
|
func (userAPI *fakeUserAPIForSSO) QueryNumericLocalpart(ctx context.Context, res *uapi.QueryNumericLocalpartResponse) error {
|
|
res.ID = 12345
|
|
return userAPI.numericLocalpartErr
|
|
}
|