Add SSO tests.

Renames cookie oidc_nonce to sso_nonce, since it's defined in a file
that doesn't know about OIDC specifically.
This commit is contained in:
Tommie Gannert 2022-06-08 09:14:11 +02:00
parent b8844fb1e2
commit 210ab1eef6
11 changed files with 1079 additions and 18 deletions

View file

@ -101,6 +101,34 @@ func TestLoginFromJSONReader(t *testing.T) {
}
}
func TestLoginFromJSONReaderTokenDisabled(t *testing.T) {
ctx := context.Background()
var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{
Matrix: &config.Global{
ServerName: serverName,
},
Login: config.Login{
SSO: config.SSO{
Enabled: false,
},
},
}
_, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(`{
"type": "m.login.token",
"token": "atoken",
"device_id": "adevice"
}`), &userAPI, &userAPI, cfg)
wantCode := "M_INVALID_ARGUMENT_VALUE"
if err == nil {
cleanup(ctx, nil)
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", err, wantCode)
} else if merr, ok := err.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != wantCode {
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", err, wantCode)
}
}
func TestBadLoginFromJSONReader(t *testing.T) {
ctx := context.Background()

View file

@ -0,0 +1,226 @@
package sso
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"
"github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
func TestOAuth2IdentityProviderAuthorizationURL(t *testing.T) {
ctx := context.Background()
idp := &oauth2IdentityProvider{
cfg: &config.IdentityProvider{
OAuth2: config.OAuth2{
ClientID: "aclientid",
},
},
hc: http.DefaultClient,
authorizationURL: "https://oauth2.example.com/authorize",
}
got, err := idp.AuthorizationURL(ctx, "https://matrix.example.com/continue", "anonce")
if err != nil {
t.Fatalf("AuthorizationURL failed: %v", err)
}
if want := "https://oauth2.example.com/authorize?client_id=aclientid&redirect_uri=https%3A%2F%2Fmatrix.example.com%2Fcontinue&response_type=code&scope=&state=anonce"; got != want {
t.Errorf("AuthorizationURL: got %q, want %q", got, want)
}
}
func TestOAuth2IdentityProviderProcessCallback(t *testing.T) {
ctx := context.Background()
const callbackURL = "https://matrix.example.com/continue"
tsts := []struct {
Name string
Query url.Values
Want *CallbackResult
WantTokenReq url.Values
}{
{
Name: "gotEverything",
Query: url.Values{
"code": []string{"acode"},
"state": []string{"anonce"},
},
Want: &CallbackResult{
Identifier: &UserIdentifier{
Namespace: uapi.SSOIDNamespace,
Issuer: "anid",
Subject: "asub",
},
DisplayName: "aname",
SuggestedUserID: "auser",
},
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`))
})
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_user":"auser"}`))
})
s := httptest.NewServer(mux)
defer s.Close()
idp := &oauth2IdentityProvider{
cfg: &config.IdentityProvider{
ID: "anid",
OAuth2: config.OAuth2{
ClientID: "aclientid",
ClientSecret: "aclientsecret",
},
},
hc: s.Client(),
accessTokenURL: s.URL + "/token",
userInfoURL: s.URL + "/userinfo",
subPath: "sub",
displayNamePath: "name",
suggestedUserIDPath: "preferred_user",
}
got, err := idp.ProcessCallback(ctx, callbackURL, "anonce", tst.Query)
if err != nil {
t.Fatalf("ProcessCallback failed: %v", err)
}
if !reflect.DeepEqual(got, tst.Want) {
t.Errorf("ProcessCallback: got %+v, want %+v", got, tst.Want)
}
})
}
}
func TestOAuth2IdentityProviderGetAccessToken(t *testing.T) {
ctx := context.Background()
const callbackURL = "https://matrix.example.com/continue"
mux := http.NewServeMux()
var gotReq url.Values
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
gotReq = r.Form
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`))
})
s := httptest.NewServer(mux)
defer s.Close()
idp := &oauth2IdentityProvider{
cfg: &config.IdentityProvider{
ID: "anid",
OAuth2: config.OAuth2{
ClientID: "aclientid",
ClientSecret: "aclientsecret",
},
},
hc: s.Client(),
accessTokenURL: s.URL + "/token",
}
got, err := idp.getAccessToken(ctx, callbackURL, "acode")
if err != nil {
t.Fatalf("getAccessToken failed: %v", err)
}
if want := "atoken"; got != want {
t.Errorf("getAccessToken: got %q, want %q", got, want)
}
wantReq := url.Values{
"client_id": []string{"aclientid"},
"client_secret": []string{"aclientsecret"},
"code": []string{"acode"},
"grant_type": []string{"authorization_code"},
"redirect_uri": []string{callbackURL},
}
if !reflect.DeepEqual(gotReq, wantReq) {
t.Errorf("getAccessToken request: got %+v, want %+v", gotReq, wantReq)
}
}
func TestOAuth2IdentityProviderGetUserInfo(t *testing.T) {
ctx := context.Background()
mux := http.NewServeMux()
var gotHeader http.Header
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
gotHeader = r.Header
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_user":"auser"}`))
})
s := httptest.NewServer(mux)
defer s.Close()
idp := &oauth2IdentityProvider{
cfg: &config.IdentityProvider{
ID: "anid",
OAuth2: config.OAuth2{
ClientID: "aclientid",
ClientSecret: "aclientsecret",
},
},
hc: s.Client(),
userInfoURL: s.URL + "/userinfo",
responseMimeType: "application/json",
subPath: "sub",
displayNamePath: "name",
suggestedUserIDPath: "preferred_user",
}
gotSub, gotName, gotSuggestedUser, err := idp.getUserInfo(ctx, "atoken")
if err != nil {
t.Fatalf("getUserInfo failed: %v", err)
}
if want := "asub"; gotSub != want {
t.Errorf("getUserInfo subject: got %q, want %q", gotSub, want)
}
if want := "aname"; gotName != want {
t.Errorf("getUserInfo displayName: got %q, want %q", gotName, want)
}
if want := "auser"; gotSuggestedUser != want {
t.Errorf("getUserInfo suggestedUser: got %q, want %q", gotSuggestedUser, want)
}
gotHeader.Del("Accept-Encoding")
gotHeader.Del("User-Agent")
wantHeader := http.Header{
"Accept": []string{"application/json"},
"Authorization": []string{"Bearer atoken"},
}
if !reflect.DeepEqual(gotHeader, wantHeader) {
t.Errorf("getUserInfo header: got %+v, want %+v", gotHeader, wantHeader)
}
}

View file

@ -27,6 +27,18 @@ import (
uapi "github.com/matrix-org/dendrite/userapi/api"
)
// oidcDiscoveryMaxStaleness indicates how stale the Discovery
// information is allowed to be. This will very rarely change, so
// we're just making sure even a Dendrite that isn't restarting often
// is picking this up eventually.
const oidcDiscoveryMaxStaleness = 24 * time.Hour
// An oidcIdentityProvider wraps OAuth2 with OpenID Connect Discovery.
//
// The SSO identifier is the "sub." A suggested UserID is grabbed from
// "preferred_username", though this isn't commonly provided.
//
// See https://openid.net/specs/openid-connect-core-1_0.html and https://openid.net/specs/openid-connect-discovery-1_0.html.
type oidcIdentityProvider struct {
*oauth2IdentityProvider
@ -44,7 +56,7 @@ func newOIDCIdentityProvider(cfg *config.IdentityProvider, hc *http.Client) *oid
scopes: []string{"openid", "profile", "email"},
responseMimeType: "application/json",
subPath: "sub",
emailPath: "email",
emailPath: "email", // TODO: should this require email_verified?
displayNamePath: "name",
suggestedUserIDPath: "preferred_username",
},
@ -92,7 +104,7 @@ func (p *oidcIdentityProvider) get(ctx context.Context) (*oauth2IdentityProvider
return nil, nil, err
}
p.exp = now.Add(24 * time.Hour)
p.exp = now.Add(oidcDiscoveryMaxStaleness)
newProvider := *p.oauth2IdentityProvider
newProvider.authorizationURL = disc.AuthorizationEndpoint
newProvider.accessTokenURL = disc.TokenEndpoint

View file

@ -0,0 +1,118 @@
package sso
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"
"github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
func TestOIDCIdentityProviderAuthorizationURL(t *testing.T) {
ctx := context.Background()
mux := http.NewServeMux()
mux.HandleFunc("/discovery", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"authorization_endpoint":"http://oidc.example.com/authorize","token_endpoint":"http://oidc.example.com/token","userinfo_endpoint":"http://oidc.example.com/userinfo","issuer":"http://oidc.example.com/"}`))
})
s := httptest.NewServer(mux)
defer s.Close()
idp := newOIDCIdentityProvider(&config.IdentityProvider{
OAuth2: config.OAuth2{
ClientID: "aclientid",
},
OIDC: config.OIDC{
DiscoveryURL: s.URL + "/discovery",
},
}, s.Client())
got, err := idp.AuthorizationURL(ctx, "https://matrix.example.com/continue", "anonce")
if err != nil {
t.Fatalf("AuthorizationURL failed: %v", err)
}
if want := "http://oidc.example.com/authorize?client_id=aclientid&redirect_uri=https%3A%2F%2Fmatrix.example.com%2Fcontinue&response_type=code&scope=openid+profile+email&state=anonce"; got != want {
t.Errorf("AuthorizationURL: got %q, want %q", got, want)
}
}
func TestOIDCIdentityProviderProcessCallback(t *testing.T) {
ctx := context.Background()
const callbackURL = "https://matrix.example.com/continue"
tsts := []struct {
Name string
Query url.Values
Want *CallbackResult
WantTokenReq url.Values
}{
{
Name: "gotEverything",
Query: url.Values{
"code": []string{"acode"},
"state": []string{"anonce"},
},
Want: &CallbackResult{
Identifier: &UserIdentifier{
Namespace: uapi.OIDCNamespace,
Issuer: "http://oidc.example.com/",
Subject: "asub",
},
DisplayName: "aname",
SuggestedUserID: "auser",
},
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
mux := http.NewServeMux()
var sURL string
mux.HandleFunc("/discovery", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{"authorization_endpoint":"%s/authorize","token_endpoint":"%s/token","userinfo_endpoint":"%s/userinfo","issuer":"http://oidc.example.com/"}`,
sURL, sURL, sURL)))
})
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`))
})
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_username":"auser"}`))
})
s := httptest.NewServer(mux)
defer s.Close()
sURL = s.URL
idp := newOIDCIdentityProvider(&config.IdentityProvider{
OAuth2: config.OAuth2{
ClientID: "aclientid",
},
OIDC: config.OIDC{
DiscoveryURL: sURL + "/discovery",
},
}, s.Client())
got, err := idp.ProcessCallback(ctx, callbackURL, "anonce", tst.Query)
if err != nil {
t.Fatalf("ProcessCallback failed: %v", err)
}
if !reflect.DeepEqual(got, tst.Want) {
t.Errorf("ProcessCallback: got %+v, want %+v", got, tst.Want)
}
})
}
}

View file

@ -25,13 +25,19 @@ import (
uapi "github.com/matrix-org/dendrite/userapi/api"
)
// maxHTTPTimeout is an upper bound on an HTTP request to an SSO
// backend. The individual request context deadlines are also honored.
const maxHTTPTimeout = 10 * time.Second
// An Authenticator keeps a set of identity providers and dispatches
// calls to one of them, based on configured ID.
type Authenticator struct {
providers map[string]identityProvider
}
func NewAuthenticator(ctx context.Context, cfg *config.SSO) (*Authenticator, error) {
func NewAuthenticator(cfg *config.SSO) (*Authenticator, error) {
hc := &http.Client{
Timeout: 10 * time.Second,
Timeout: maxHTTPTimeout,
Transport: &http.Transport{
DisableKeepAlives: true,
Proxy: http.ProxyFromEnvironment,

View file

@ -0,0 +1,76 @@
package sso
import (
"context"
"net/url"
"reflect"
"testing"
"github.com/matrix-org/dendrite/setup/config"
)
func TestNewAuthenticator(t *testing.T) {
_, err := NewAuthenticator(&config.SSO{
Providers: []config.IdentityProvider{
{
Type: config.SSOTypeGitHub,
OAuth2: config.OAuth2{
ClientID: "aclientid",
},
},
{
Type: config.SSOTypeOIDC,
OAuth2: config.OAuth2{
ClientID: "aclientid",
},
OIDC: config.OIDC{
DiscoveryURL: "http://oidc.example.com/discovery",
},
},
},
})
if err != nil {
t.Fatalf("NewAuthenticator failed: %v", err)
}
}
func TestAuthenticator(t *testing.T) {
ctx := context.Background()
var idp fakeIdentityProvider
a := Authenticator{
providers: map[string]identityProvider{
"fake": &idp,
},
}
t.Run("authorizationURL", func(t *testing.T) {
got, err := a.AuthorizationURL(ctx, "fake", "http://matrix.example.com/continue", "anonce")
if err != nil {
t.Fatalf("AuthorizationURL failed: %v", err)
}
if want := "aurl"; got != want {
t.Errorf("AuthorizationURL: got %q, want %q", got, want)
}
})
t.Run("processCallback", func(t *testing.T) {
got, err := a.ProcessCallback(ctx, "fake", "http://matrix.example.com/continue", "anonce", url.Values{})
if err != nil {
t.Fatalf("ProcessCallback failed: %v", err)
}
if want := (&CallbackResult{DisplayName: "aname"}); !reflect.DeepEqual(got, want) {
t.Errorf("ProcessCallback: got %+v, want %+v", got, want)
}
})
}
type fakeIdentityProvider struct{}
func (idp *fakeIdentityProvider) AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) {
return "aurl", nil
}
func (idp *fakeIdentityProvider) ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) {
return &CallbackResult{DisplayName: "aname"}, nil
}

View file

@ -73,7 +73,7 @@ func Setup(
var ssoAuthenticator *sso.Authenticator
if cfg.Login.SSO.Enabled {
var err error
ssoAuthenticator, err = sso.NewAuthenticator(ctx, &cfg.Login.SSO)
ssoAuthenticator, err = sso.NewAuthenticator(&cfg.Login.SSO)
if err != nil {
logrus.WithError(err).Fatal("failed to create SSO authenticator")
}

View file

@ -39,7 +39,7 @@ import (
func SSORedirect(
req *http.Request,
idpID string,
auth *sso.Authenticator,
auth ssoAuthenticator,
cfg *config.SSO,
) util.JSONResponse {
ctx := req.Context()
@ -58,12 +58,16 @@ func SSORedirect(
JSON: jsonerror.MissingArgument("redirectUrl parameter missing"),
}
}
_, err := url.Parse(redirectURL)
if err != nil {
if ru, err := url.Parse(redirectURL); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()),
}
} else if ru.Scheme == "" || ru.Host == "" || ru.Path == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + redirectURL),
}
}
callbackURL, err := buildCallbackURLFromOther(cfg, req, "/login/sso/redirect")
@ -92,7 +96,7 @@ func SSORedirect(
resp := util.RedirectResponse(u)
cookie := &http.Cookie{
Name: "oidc_nonce",
Name: "sso_nonce",
Value: nonce,
Path: path.Dir(callbackURL.Path),
Expires: time.Now().Add(10 * time.Minute),
@ -113,7 +117,6 @@ func SSORedirect(
func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath string) (*url.URL, error) {
u := &url.URL{
Scheme: "https",
User: req.URL.User,
Host: req.Host,
Path: req.URL.Path,
}
@ -141,7 +144,7 @@ func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath
func SSOCallback(
req *http.Request,
userAPI userAPIForSSO,
auth *sso.Authenticator,
auth ssoAuthenticator,
cfg *config.SSO,
serverName gomatrixserverlib.ServerName,
) util.JSONResponse {
@ -163,7 +166,7 @@ func SSOCallback(
}
}
nonce, err := req.Cookie("oidc_nonce")
nonce, err := req.Cookie("sso_nonce")
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
@ -246,7 +249,7 @@ func SSOCallback(
rquery.Set("loginToken", token.Token)
resp := util.RedirectResponse(finalRedirectURL.ResolveReference(&url.URL{RawQuery: rquery.Encode()}).String())
resp.Headers["Set-Cookie"] = (&http.Cookie{
Name: "oidc_nonce",
Name: "sso_nonce",
Value: "",
MaxAge: -1,
Secure: true,
@ -254,6 +257,11 @@ func SSOCallback(
return resp
}
type ssoAuthenticator interface {
AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error)
ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error)
}
type userAPIForSSO interface {
uapi.LoginTokenInternalAPI
@ -273,21 +281,21 @@ func formatNonce(redirectURL string) string {
// function. The URL is not integrity protected.
func parseNonce(s string) (redirectURL *url.URL, _ error) {
if s == "" {
return nil, jsonerror.MissingArgument("empty OIDC nonce cookie")
return nil, jsonerror.MissingArgument("empty SSO nonce cookie")
}
ss := strings.Split(s, ".")
if len(ss) < 2 {
return nil, jsonerror.InvalidArgumentValue("malformed OIDC nonce cookie")
return nil, jsonerror.InvalidArgumentValue("malformed SSO nonce cookie")
}
urlbs, err := base64.RawURLEncoding.DecodeString(ss[1])
if err != nil {
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie")
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in SSO nonce cookie")
}
u, err := url.Parse(string(urlbs))
if err != nil {
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie: " + err.Error())
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in SSO nonce cookie: " + err.Error())
}
return u, nil
@ -309,6 +317,9 @@ func verifySSOUserIdentifier(ctx context.Context, userAPI userAPIForSSO, id *sso
return res.Localpart, nil
}
// registerSSOAccount creates an account and associates the SSO
// identifier with it. Note that SSO login account creation doesn't
// use the standard registration API, but happens ad-hoc.
func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.UserIdentifier, localpart string) (bool, util.JSONResponse) {
var accRes uapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
@ -347,6 +358,8 @@ func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.U
return true, util.JSONResponse{}
}
// createLoginToken produces a new login token, valid for the given
// user.
func createLoginToken(ctx context.Context, userAPI userAPIForSSO, userID string) (*uapi.LoginTokenMetadata, error) {
req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: userID}}
var resp uapi.PerformLoginTokenCreationResponse

View file

@ -0,0 +1,531 @@
package routing
import (
"context"
"encoding/base64"
"errors"
"net/http"
"net/url"
"regexp"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/matrix-org/dendrite/clientapi/auth/sso"
"github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
func TestSSORedirect(t *testing.T) {
tsts := []struct {
Name string
Req http.Request
IDPID string
Auth fakeSSOAuthenticator
Config config.SSO
WantLocationRE string
WantSetCookieRE string
}{
{
Name: "redirectDefault",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/redirect",
RawQuery: url.Values{
"redirectUrl": []string{"http://example.com/continue"},
}.Encode(),
},
},
WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http%3A%2F%2Fmatrix.example.com%2F_matrix%2Fv4%2Flogin%2Fsso%2Fcallback%3Fprovider%3D&nonce=.+&providerID=`,
WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
},
{
Name: "redirectExplicitProvider",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/redirect",
RawQuery: url.Values{
"redirectUrl": []string{"http://example.com/continue"},
}.Encode(),
},
},
IDPID: "someprovider",
WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http.*%3Fprovider%3Dsomeprovider&nonce=.+&providerID=someprovider`,
WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config)
if want := http.StatusFound; got.Code != want {
t.Errorf("SSORedirect Code: got %v, want %v", got.Code, want)
}
if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
t.Fatalf("WantSetCookieRE failed: %v", err)
} else if !m {
t.Errorf("SSORedirect Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
}
if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
t.Fatalf("WantSetCookieRE failed: %v", err)
} else if !m {
t.Errorf("SSORedirect Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
}
})
}
}
func TestSSORedirectError(t *testing.T) {
tsts := []struct {
Name string
Req http.Request
IDPID string
Auth fakeSSOAuthenticator
Config config.SSO
WantCode int
}{
{
Name: "missingRedirectUrl",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/redirect",
RawQuery: url.Values{}.Encode(),
},
},
WantCode: http.StatusBadRequest,
},
{
Name: "invalidRedirectUrl",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/redirect",
RawQuery: url.Values{
"redirectUrl": []string{"/continue"},
}.Encode(),
},
},
WantCode: http.StatusBadRequest,
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config)
if got.Code != tst.WantCode {
t.Errorf("SSORedirect Code: got %v, want %v", got.Code, tst.WantCode)
}
})
}
}
func TestSSOCallback(t *testing.T) {
nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))
tsts := []struct {
Name string
Req http.Request
UserAPI fakeUserAPIForSSO
Auth fakeSSOAuthenticator
Config config.SSO
WantLocationRE string
WantSetCookieRE string
WantAccountCreation []*uapi.PerformAccountCreationRequest
WantLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
WantSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
WantQueryLocalpart []*uapi.QueryLocalpartForSSORequest
}{
{
Name: "logIn",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/callback",
RawQuery: url.Values{
"provider": []string{"aprovider"},
}.Encode(),
},
Header: http.Header{
"Cookie": []string{(&http.Cookie{
Name: "sso_nonce",
Value: nonce,
}).String()},
},
},
UserAPI: fakeUserAPIForSSO{
localpart: "alocalpart",
},
Auth: fakeSSOAuthenticator{
callbackResult: sso.CallbackResult{
Identifier: &sso.UserIdentifier{
Namespace: "anamespace",
Issuer: "anissuer",
Subject: "asubject",
},
},
},
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
WantSetCookieRE: "sso_nonce=;",
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@alocalpart:aservername"}}},
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
},
{
Name: "registerSuggested",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/callback",
RawQuery: url.Values{
"provider": []string{"aprovider"},
}.Encode(),
},
Header: http.Header{
"Cookie": []string{(&http.Cookie{
Name: "sso_nonce",
Value: nonce,
}).String()},
},
},
Auth: fakeSSOAuthenticator{
callbackResult: sso.CallbackResult{
Identifier: &sso.UserIdentifier{
Namespace: "anamespace",
Issuer: "anissuer",
Subject: "asubject",
},
SuggestedUserID: "asuggestedid",
},
},
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
WantSetCookieRE: "sso_nonce=;",
WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "asuggestedid", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@asuggestedid:aservername"}}},
WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "asuggestedid"}},
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
},
{
Name: "registerNumeric",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/callback",
RawQuery: url.Values{
"provider": []string{"aprovider"},
}.Encode(),
},
Header: http.Header{
"Cookie": []string{(&http.Cookie{
Name: "sso_nonce",
Value: nonce,
}).String()},
},
},
Auth: fakeSSOAuthenticator{
callbackResult: sso.CallbackResult{
Identifier: &sso.UserIdentifier{
Namespace: "anamespace",
Issuer: "anissuer",
Subject: "asubject",
},
},
},
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
WantSetCookieRE: "sso_nonce=;",
WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "12345", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@12345:aservername"}}},
WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "12345"}},
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
},
{
Name: "noIdentifierRedirectURL",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/callback",
RawQuery: url.Values{
"provider": []string{"aprovider"},
}.Encode(),
},
Header: http.Header{
"Cookie": []string{(&http.Cookie{
Name: "sso_nonce",
Value: nonce,
}).String()},
},
},
Auth: fakeSSOAuthenticator{
callbackResult: sso.CallbackResult{
RedirectURL: "http://auth.example.com/notdone",
},
},
WantLocationRE: `http://auth.example.com/notdone`,
WantSetCookieRE: "^$",
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername")
if want := http.StatusFound; got.Code != want {
t.Log(got)
t.Errorf("SSOCallback Code: got %v, want %v", got.Code, want)
}
if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
t.Fatalf("WantSetCookieRE failed: %v", err)
} else if !m {
t.Errorf("SSOCallback Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
}
if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
t.Fatalf("WantSetCookieRE failed: %v", err)
} else if !m {
t.Errorf("SSOCallback Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
}
if diff := cmp.Diff(tst.WantAccountCreation, tst.UserAPI.gotAccountCreation); diff != "" {
t.Errorf("PerformAccountCreation: +got -want:\n%s", diff)
}
if diff := cmp.Diff(tst.WantLoginTokenCreation, tst.UserAPI.gotLoginTokenCreation); diff != "" {
t.Errorf("PerformLoginTokenCreation: +got -want:\n%s", diff)
}
if diff := cmp.Diff(tst.WantSaveSSOAssociation, tst.UserAPI.gotSaveSSOAssociation); diff != "" {
t.Errorf("PerformSaveSSOAssociation: +got -want:\n%s", diff)
}
if diff := cmp.Diff(tst.WantQueryLocalpart, tst.UserAPI.gotQueryLocalpart); diff != "" {
t.Errorf("QueryLocalpartForSSO: +got -want:\n%s", diff)
}
})
}
}
func TestSSOCallbackError(t *testing.T) {
nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))
goodReq := http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/callback",
RawQuery: url.Values{
"provider": []string{"aprovider"},
}.Encode(),
},
Header: http.Header{
"Cookie": []string{(&http.Cookie{
Name: "sso_nonce",
Value: nonce,
}).String()},
},
}
goodAuth := fakeSSOAuthenticator{
callbackResult: sso.CallbackResult{
Identifier: &sso.UserIdentifier{
Namespace: "anamespace",
Issuer: "anissuer",
Subject: "asubject",
},
},
}
errMocked := errors.New("mocked error")
tsts := []struct {
Name string
Req http.Request
UserAPI fakeUserAPIForSSO
Auth fakeSSOAuthenticator
Config config.SSO
WantCode int
}{
{
Name: "missingProvider",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/callback",
},
Header: http.Header{
"Cookie": []string{(&http.Cookie{
Name: "sso_nonce",
Value: nonce,
}).String()},
},
},
WantCode: http.StatusBadRequest,
},
{
Name: "missingCookie",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/callback",
RawQuery: url.Values{
"provider": []string{"aprovider"},
}.Encode(),
},
},
WantCode: http.StatusBadRequest,
},
{
Name: "malformedCookie",
Req: http.Request{
Host: "matrix.example.com",
URL: &url.URL{
Path: "/_matrix/v4/login/sso/callback",
RawQuery: url.Values{
"provider": []string{"aprovider"},
}.Encode(),
},
Header: http.Header{
"Cookie": []string{(&http.Cookie{
Name: "sso_nonce",
Value: "badvalue",
}).String()},
},
},
WantCode: http.StatusBadRequest,
},
{
Name: "failedProcessCallback",
Req: goodReq,
Auth: fakeSSOAuthenticator{
callbackErr: errMocked,
},
WantCode: http.StatusInternalServerError,
},
{
Name: "failedQueryLocalpartForSSO",
Req: goodReq,
UserAPI: fakeUserAPIForSSO{
localpartErr: errMocked,
},
Auth: goodAuth,
WantCode: http.StatusUnauthorized,
},
{
Name: "failedQueryNumericLocalpart",
Req: goodReq,
UserAPI: fakeUserAPIForSSO{
numericLocalpartErr: errMocked,
},
Auth: goodAuth,
WantCode: http.StatusInternalServerError,
},
{
Name: "failedAccountCreation",
Req: goodReq,
UserAPI: fakeUserAPIForSSO{
accountCreationErr: errMocked,
},
Auth: goodAuth,
WantCode: http.StatusInternalServerError,
},
{
Name: "failedSaveSSOAssociation",
Req: goodReq,
UserAPI: fakeUserAPIForSSO{
saveSSOAssociationErr: errMocked,
},
Auth: goodAuth,
WantCode: http.StatusInternalServerError,
},
{
Name: "failedPerformLoginTokenCreation",
Req: goodReq,
UserAPI: fakeUserAPIForSSO{
localpart: "alocalpart",
tokenTokenCreationErr: errMocked,
},
Auth: goodAuth,
WantCode: http.StatusInternalServerError,
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername")
if got.Code != tst.WantCode {
t.Log(got)
t.Errorf("SSOCallback Code: got %v, want %v", got.Code, tst.WantCode)
}
})
}
}
type fakeSSOAuthenticator struct {
callbackResult sso.CallbackResult
callbackErr error
}
func (auth *fakeSSOAuthenticator) AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) {
return (&url.URL{
Scheme: "http",
Host: "auth.example.com",
Path: "/authorize",
}).ResolveReference(&url.URL{
RawQuery: url.Values{
"callbackURL": []string{callbackURL},
"nonce": []string{nonce},
"providerID": []string{providerID},
}.Encode(),
}).String(), nil
}
func (auth *fakeSSOAuthenticator) ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error) {
return &auth.callbackResult, auth.callbackErr
}
type fakeUserAPIForSSO struct {
userAPIForSSO
accountCreationErr error
tokenTokenCreationErr error
saveSSOAssociationErr error
localpart string
localpartErr error
numericLocalpartErr error
gotAccountCreation []*uapi.PerformAccountCreationRequest
gotLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
gotSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
gotQueryLocalpart []*uapi.QueryLocalpartForSSORequest
}
func (userAPI *fakeUserAPIForSSO) PerformAccountCreation(ctx context.Context, req *uapi.PerformAccountCreationRequest, res *uapi.PerformAccountCreationResponse) error {
userAPI.gotAccountCreation = append(userAPI.gotAccountCreation, req)
return userAPI.accountCreationErr
}
func (userAPI *fakeUserAPIForSSO) PerformLoginTokenCreation(ctx context.Context, req *uapi.PerformLoginTokenCreationRequest, res *uapi.PerformLoginTokenCreationResponse) error {
userAPI.gotLoginTokenCreation = append(userAPI.gotLoginTokenCreation, req)
res.Metadata = uapi.LoginTokenMetadata{
Token: "atoken",
}
return userAPI.tokenTokenCreationErr
}
func (userAPI *fakeUserAPIForSSO) PerformSaveSSOAssociation(ctx context.Context, req *uapi.PerformSaveSSOAssociationRequest, res *struct{}) error {
userAPI.gotSaveSSOAssociation = append(userAPI.gotSaveSSOAssociation, req)
return userAPI.saveSSOAssociationErr
}
func (userAPI *fakeUserAPIForSSO) QueryLocalpartForSSO(ctx context.Context, req *uapi.QueryLocalpartForSSORequest, res *uapi.QueryLocalpartForSSOResponse) error {
userAPI.gotQueryLocalpart = append(userAPI.gotQueryLocalpart, req)
res.Localpart = userAPI.localpart
return userAPI.localpartErr
}
func (userAPI *fakeUserAPIForSSO) QueryNumericLocalpart(ctx context.Context, res *uapi.QueryNumericLocalpartResponse) error {
res.ID = 12345
return userAPI.numericLocalpartErr
}

View file

@ -62,7 +62,7 @@ global:
local_part: "_server"
display_name: "Server alerts"
avatar: ""
room_name: "Server Alerts"
room_name: "Server Alerts"
app_service_api:
internal_api:
listen: http://localhost:7777
@ -86,6 +86,22 @@ client_api:
recaptcha_private_key: ""
recaptcha_bypass_secret: ""
recaptcha_siteverify_api: ""
login:
sso:
enabled: true
callback_url: http://example.com:8071/_matrix/v3/login/sso/callback
default_provider: github
providers:
- brand: github
- id: custom
name: "Custom Provider"
icon: "mxc://example.com/abc123"
type: oidc
oauth2:
client_id: aclientid
client_secret: aclientsecret
oidc:
discovery_url: http://auth.example.com/.well-known/openid-configuration
turn:
turn_user_lifetime: ""
turn_uris: []

View file

@ -429,6 +429,41 @@ func Test_Pusher(t *testing.T) {
})
}
func Test_SSO(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
t.Log("Create SSO association")
ns := util.RandomString(8)
issuer := util.RandomString(8)
subject := util.RandomString(8)
err := db.SaveSSOAssociation(ctx, ns, issuer, subject, aliceLocalpart)
assert.NoError(t, err, "unable to save SSO association")
t.Log("Retrieve localpart for association")
gotLocalpart, err := db.GetLocalpartForSSO(ctx, ns, issuer, subject)
assert.Equal(t, aliceLocalpart, gotLocalpart)
t.Log("Remove SSO association")
err = db.RemoveSSOAssociation(ctx, ns, issuer, subject)
assert.NoError(t, err, "unexpected error")
t.Log("Verify the SSO association was removed")
gotLocalpart, err = db.GetLocalpartForSSO(ctx, ns, issuer, subject)
assert.NoError(t, err, "unable to get localpart for SSO subject")
assert.Equal(t, "", gotLocalpart)
})
}
func Test_ThreePID(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)