mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-07 06:03:09 -06:00
Add SSO tests.
Renames cookie oidc_nonce to sso_nonce, since it's defined in a file that doesn't know about OIDC specifically.
This commit is contained in:
parent
b8844fb1e2
commit
210ab1eef6
|
|
@ -101,6 +101,34 @@ func TestLoginFromJSONReader(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestLoginFromJSONReaderTokenDisabled(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
var userAPI fakeUserInternalAPI
|
||||
cfg := &config.ClientAPI{
|
||||
Matrix: &config.Global{
|
||||
ServerName: serverName,
|
||||
},
|
||||
Login: config.Login{
|
||||
SSO: config.SSO{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
_, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(`{
|
||||
"type": "m.login.token",
|
||||
"token": "atoken",
|
||||
"device_id": "adevice"
|
||||
}`), &userAPI, &userAPI, cfg)
|
||||
wantCode := "M_INVALID_ARGUMENT_VALUE"
|
||||
if err == nil {
|
||||
cleanup(ctx, nil)
|
||||
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", err, wantCode)
|
||||
} else if merr, ok := err.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != wantCode {
|
||||
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", err, wantCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadLoginFromJSONReader(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
|
|
|||
226
clientapi/auth/sso/oauth2_test.go
Normal file
226
clientapi/auth/sso/oauth2_test.go
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
package sso
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
func TestOAuth2IdentityProviderAuthorizationURL(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
idp := &oauth2IdentityProvider{
|
||||
cfg: &config.IdentityProvider{
|
||||
OAuth2: config.OAuth2{
|
||||
ClientID: "aclientid",
|
||||
},
|
||||
},
|
||||
hc: http.DefaultClient,
|
||||
|
||||
authorizationURL: "https://oauth2.example.com/authorize",
|
||||
}
|
||||
|
||||
got, err := idp.AuthorizationURL(ctx, "https://matrix.example.com/continue", "anonce")
|
||||
if err != nil {
|
||||
t.Fatalf("AuthorizationURL failed: %v", err)
|
||||
}
|
||||
|
||||
if want := "https://oauth2.example.com/authorize?client_id=aclientid&redirect_uri=https%3A%2F%2Fmatrix.example.com%2Fcontinue&response_type=code&scope=&state=anonce"; got != want {
|
||||
t.Errorf("AuthorizationURL: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2IdentityProviderProcessCallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
const callbackURL = "https://matrix.example.com/continue"
|
||||
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Query url.Values
|
||||
|
||||
Want *CallbackResult
|
||||
WantTokenReq url.Values
|
||||
}{
|
||||
{
|
||||
Name: "gotEverything",
|
||||
Query: url.Values{
|
||||
"code": []string{"acode"},
|
||||
"state": []string{"anonce"},
|
||||
},
|
||||
|
||||
Want: &CallbackResult{
|
||||
Identifier: &UserIdentifier{
|
||||
Namespace: uapi.SSOIDNamespace,
|
||||
Issuer: "anid",
|
||||
Subject: "asub",
|
||||
},
|
||||
DisplayName: "aname",
|
||||
SuggestedUserID: "auser",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`))
|
||||
})
|
||||
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_user":"auser"}`))
|
||||
})
|
||||
|
||||
s := httptest.NewServer(mux)
|
||||
defer s.Close()
|
||||
|
||||
idp := &oauth2IdentityProvider{
|
||||
cfg: &config.IdentityProvider{
|
||||
ID: "anid",
|
||||
OAuth2: config.OAuth2{
|
||||
ClientID: "aclientid",
|
||||
ClientSecret: "aclientsecret",
|
||||
},
|
||||
},
|
||||
hc: s.Client(),
|
||||
|
||||
accessTokenURL: s.URL + "/token",
|
||||
userInfoURL: s.URL + "/userinfo",
|
||||
|
||||
subPath: "sub",
|
||||
displayNamePath: "name",
|
||||
suggestedUserIDPath: "preferred_user",
|
||||
}
|
||||
|
||||
got, err := idp.ProcessCallback(ctx, callbackURL, "anonce", tst.Query)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessCallback failed: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tst.Want) {
|
||||
t.Errorf("ProcessCallback: got %+v, want %+v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2IdentityProviderGetAccessToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
const callbackURL = "https://matrix.example.com/continue"
|
||||
|
||||
mux := http.NewServeMux()
|
||||
var gotReq url.Values
|
||||
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
gotReq = r.Form
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`))
|
||||
})
|
||||
|
||||
s := httptest.NewServer(mux)
|
||||
defer s.Close()
|
||||
|
||||
idp := &oauth2IdentityProvider{
|
||||
cfg: &config.IdentityProvider{
|
||||
ID: "anid",
|
||||
OAuth2: config.OAuth2{
|
||||
ClientID: "aclientid",
|
||||
ClientSecret: "aclientsecret",
|
||||
},
|
||||
},
|
||||
hc: s.Client(),
|
||||
|
||||
accessTokenURL: s.URL + "/token",
|
||||
}
|
||||
|
||||
got, err := idp.getAccessToken(ctx, callbackURL, "acode")
|
||||
if err != nil {
|
||||
t.Fatalf("getAccessToken failed: %v", err)
|
||||
}
|
||||
|
||||
if want := "atoken"; got != want {
|
||||
t.Errorf("getAccessToken: got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
wantReq := url.Values{
|
||||
"client_id": []string{"aclientid"},
|
||||
"client_secret": []string{"aclientsecret"},
|
||||
"code": []string{"acode"},
|
||||
"grant_type": []string{"authorization_code"},
|
||||
"redirect_uri": []string{callbackURL},
|
||||
}
|
||||
if !reflect.DeepEqual(gotReq, wantReq) {
|
||||
t.Errorf("getAccessToken request: got %+v, want %+v", gotReq, wantReq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2IdentityProviderGetUserInfo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
var gotHeader http.Header
|
||||
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeader = r.Header
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_user":"auser"}`))
|
||||
})
|
||||
|
||||
s := httptest.NewServer(mux)
|
||||
defer s.Close()
|
||||
|
||||
idp := &oauth2IdentityProvider{
|
||||
cfg: &config.IdentityProvider{
|
||||
ID: "anid",
|
||||
OAuth2: config.OAuth2{
|
||||
ClientID: "aclientid",
|
||||
ClientSecret: "aclientsecret",
|
||||
},
|
||||
},
|
||||
hc: s.Client(),
|
||||
|
||||
userInfoURL: s.URL + "/userinfo",
|
||||
|
||||
responseMimeType: "application/json",
|
||||
subPath: "sub",
|
||||
displayNamePath: "name",
|
||||
suggestedUserIDPath: "preferred_user",
|
||||
}
|
||||
|
||||
gotSub, gotName, gotSuggestedUser, err := idp.getUserInfo(ctx, "atoken")
|
||||
if err != nil {
|
||||
t.Fatalf("getUserInfo failed: %v", err)
|
||||
}
|
||||
|
||||
if want := "asub"; gotSub != want {
|
||||
t.Errorf("getUserInfo subject: got %q, want %q", gotSub, want)
|
||||
}
|
||||
if want := "aname"; gotName != want {
|
||||
t.Errorf("getUserInfo displayName: got %q, want %q", gotName, want)
|
||||
}
|
||||
if want := "auser"; gotSuggestedUser != want {
|
||||
t.Errorf("getUserInfo suggestedUser: got %q, want %q", gotSuggestedUser, want)
|
||||
}
|
||||
|
||||
gotHeader.Del("Accept-Encoding")
|
||||
gotHeader.Del("User-Agent")
|
||||
wantHeader := http.Header{
|
||||
"Accept": []string{"application/json"},
|
||||
"Authorization": []string{"Bearer atoken"},
|
||||
}
|
||||
if !reflect.DeepEqual(gotHeader, wantHeader) {
|
||||
t.Errorf("getUserInfo header: got %+v, want %+v", gotHeader, wantHeader)
|
||||
}
|
||||
}
|
||||
|
|
@ -27,6 +27,18 @@ import (
|
|||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
// oidcDiscoveryMaxStaleness indicates how stale the Discovery
|
||||
// information is allowed to be. This will very rarely change, so
|
||||
// we're just making sure even a Dendrite that isn't restarting often
|
||||
// is picking this up eventually.
|
||||
const oidcDiscoveryMaxStaleness = 24 * time.Hour
|
||||
|
||||
// An oidcIdentityProvider wraps OAuth2 with OpenID Connect Discovery.
|
||||
//
|
||||
// The SSO identifier is the "sub." A suggested UserID is grabbed from
|
||||
// "preferred_username", though this isn't commonly provided.
|
||||
//
|
||||
// See https://openid.net/specs/openid-connect-core-1_0.html and https://openid.net/specs/openid-connect-discovery-1_0.html.
|
||||
type oidcIdentityProvider struct {
|
||||
*oauth2IdentityProvider
|
||||
|
||||
|
|
@ -44,7 +56,7 @@ func newOIDCIdentityProvider(cfg *config.IdentityProvider, hc *http.Client) *oid
|
|||
scopes: []string{"openid", "profile", "email"},
|
||||
responseMimeType: "application/json",
|
||||
subPath: "sub",
|
||||
emailPath: "email",
|
||||
emailPath: "email", // TODO: should this require email_verified?
|
||||
displayNamePath: "name",
|
||||
suggestedUserIDPath: "preferred_username",
|
||||
},
|
||||
|
|
@ -92,7 +104,7 @@ func (p *oidcIdentityProvider) get(ctx context.Context) (*oauth2IdentityProvider
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
p.exp = now.Add(24 * time.Hour)
|
||||
p.exp = now.Add(oidcDiscoveryMaxStaleness)
|
||||
newProvider := *p.oauth2IdentityProvider
|
||||
newProvider.authorizationURL = disc.AuthorizationEndpoint
|
||||
newProvider.accessTokenURL = disc.TokenEndpoint
|
||||
|
|
|
|||
118
clientapi/auth/sso/oidc_test.go
Normal file
118
clientapi/auth/sso/oidc_test.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package sso
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
func TestOIDCIdentityProviderAuthorizationURL(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/discovery", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"authorization_endpoint":"http://oidc.example.com/authorize","token_endpoint":"http://oidc.example.com/token","userinfo_endpoint":"http://oidc.example.com/userinfo","issuer":"http://oidc.example.com/"}`))
|
||||
})
|
||||
|
||||
s := httptest.NewServer(mux)
|
||||
defer s.Close()
|
||||
|
||||
idp := newOIDCIdentityProvider(&config.IdentityProvider{
|
||||
OAuth2: config.OAuth2{
|
||||
ClientID: "aclientid",
|
||||
},
|
||||
OIDC: config.OIDC{
|
||||
DiscoveryURL: s.URL + "/discovery",
|
||||
},
|
||||
}, s.Client())
|
||||
|
||||
got, err := idp.AuthorizationURL(ctx, "https://matrix.example.com/continue", "anonce")
|
||||
if err != nil {
|
||||
t.Fatalf("AuthorizationURL failed: %v", err)
|
||||
}
|
||||
|
||||
if want := "http://oidc.example.com/authorize?client_id=aclientid&redirect_uri=https%3A%2F%2Fmatrix.example.com%2Fcontinue&response_type=code&scope=openid+profile+email&state=anonce"; got != want {
|
||||
t.Errorf("AuthorizationURL: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCIdentityProviderProcessCallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
const callbackURL = "https://matrix.example.com/continue"
|
||||
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Query url.Values
|
||||
|
||||
Want *CallbackResult
|
||||
WantTokenReq url.Values
|
||||
}{
|
||||
{
|
||||
Name: "gotEverything",
|
||||
Query: url.Values{
|
||||
"code": []string{"acode"},
|
||||
"state": []string{"anonce"},
|
||||
},
|
||||
|
||||
Want: &CallbackResult{
|
||||
Identifier: &UserIdentifier{
|
||||
Namespace: uapi.OIDCNamespace,
|
||||
Issuer: "http://oidc.example.com/",
|
||||
Subject: "asub",
|
||||
},
|
||||
DisplayName: "aname",
|
||||
SuggestedUserID: "auser",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
var sURL string
|
||||
mux.HandleFunc("/discovery", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(fmt.Sprintf(`{"authorization_endpoint":"%s/authorize","token_endpoint":"%s/token","userinfo_endpoint":"%s/userinfo","issuer":"http://oidc.example.com/"}`,
|
||||
sURL, sURL, sURL)))
|
||||
})
|
||||
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`))
|
||||
})
|
||||
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_username":"auser"}`))
|
||||
})
|
||||
|
||||
s := httptest.NewServer(mux)
|
||||
defer s.Close()
|
||||
|
||||
sURL = s.URL
|
||||
idp := newOIDCIdentityProvider(&config.IdentityProvider{
|
||||
OAuth2: config.OAuth2{
|
||||
ClientID: "aclientid",
|
||||
},
|
||||
OIDC: config.OIDC{
|
||||
DiscoveryURL: sURL + "/discovery",
|
||||
},
|
||||
}, s.Client())
|
||||
|
||||
got, err := idp.ProcessCallback(ctx, callbackURL, "anonce", tst.Query)
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessCallback failed: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tst.Want) {
|
||||
t.Errorf("ProcessCallback: got %+v, want %+v", got, tst.Want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -25,13 +25,19 @@ import (
|
|||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
// maxHTTPTimeout is an upper bound on an HTTP request to an SSO
|
||||
// backend. The individual request context deadlines are also honored.
|
||||
const maxHTTPTimeout = 10 * time.Second
|
||||
|
||||
// An Authenticator keeps a set of identity providers and dispatches
|
||||
// calls to one of them, based on configured ID.
|
||||
type Authenticator struct {
|
||||
providers map[string]identityProvider
|
||||
}
|
||||
|
||||
func NewAuthenticator(ctx context.Context, cfg *config.SSO) (*Authenticator, error) {
|
||||
func NewAuthenticator(cfg *config.SSO) (*Authenticator, error) {
|
||||
hc := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: maxHTTPTimeout,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
|
|
|
|||
76
clientapi/auth/sso/sso_test.go
Normal file
76
clientapi/auth/sso/sso_test.go
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package sso
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
func TestNewAuthenticator(t *testing.T) {
|
||||
_, err := NewAuthenticator(&config.SSO{
|
||||
Providers: []config.IdentityProvider{
|
||||
{
|
||||
Type: config.SSOTypeGitHub,
|
||||
OAuth2: config.OAuth2{
|
||||
ClientID: "aclientid",
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: config.SSOTypeOIDC,
|
||||
OAuth2: config.OAuth2{
|
||||
ClientID: "aclientid",
|
||||
},
|
||||
OIDC: config.OIDC{
|
||||
DiscoveryURL: "http://oidc.example.com/discovery",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthenticator failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticator(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
var idp fakeIdentityProvider
|
||||
a := Authenticator{
|
||||
providers: map[string]identityProvider{
|
||||
"fake": &idp,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("authorizationURL", func(t *testing.T) {
|
||||
got, err := a.AuthorizationURL(ctx, "fake", "http://matrix.example.com/continue", "anonce")
|
||||
if err != nil {
|
||||
t.Fatalf("AuthorizationURL failed: %v", err)
|
||||
}
|
||||
if want := "aurl"; got != want {
|
||||
t.Errorf("AuthorizationURL: got %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("processCallback", func(t *testing.T) {
|
||||
got, err := a.ProcessCallback(ctx, "fake", "http://matrix.example.com/continue", "anonce", url.Values{})
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessCallback failed: %v", err)
|
||||
}
|
||||
if want := (&CallbackResult{DisplayName: "aname"}); !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("ProcessCallback: got %+v, want %+v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type fakeIdentityProvider struct{}
|
||||
|
||||
func (idp *fakeIdentityProvider) AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) {
|
||||
return "aurl", nil
|
||||
}
|
||||
|
||||
func (idp *fakeIdentityProvider) ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) {
|
||||
return &CallbackResult{DisplayName: "aname"}, nil
|
||||
}
|
||||
|
|
@ -73,7 +73,7 @@ func Setup(
|
|||
var ssoAuthenticator *sso.Authenticator
|
||||
if cfg.Login.SSO.Enabled {
|
||||
var err error
|
||||
ssoAuthenticator, err = sso.NewAuthenticator(ctx, &cfg.Login.SSO)
|
||||
ssoAuthenticator, err = sso.NewAuthenticator(&cfg.Login.SSO)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Fatal("failed to create SSO authenticator")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ import (
|
|||
func SSORedirect(
|
||||
req *http.Request,
|
||||
idpID string,
|
||||
auth *sso.Authenticator,
|
||||
auth ssoAuthenticator,
|
||||
cfg *config.SSO,
|
||||
) util.JSONResponse {
|
||||
ctx := req.Context()
|
||||
|
|
@ -58,12 +58,16 @@ func SSORedirect(
|
|||
JSON: jsonerror.MissingArgument("redirectUrl parameter missing"),
|
||||
}
|
||||
}
|
||||
_, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
if ru, err := url.Parse(redirectURL); err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()),
|
||||
}
|
||||
} else if ru.Scheme == "" || ru.Host == "" || ru.Path == "" {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + redirectURL),
|
||||
}
|
||||
}
|
||||
|
||||
callbackURL, err := buildCallbackURLFromOther(cfg, req, "/login/sso/redirect")
|
||||
|
|
@ -92,7 +96,7 @@ func SSORedirect(
|
|||
|
||||
resp := util.RedirectResponse(u)
|
||||
cookie := &http.Cookie{
|
||||
Name: "oidc_nonce",
|
||||
Name: "sso_nonce",
|
||||
Value: nonce,
|
||||
Path: path.Dir(callbackURL.Path),
|
||||
Expires: time.Now().Add(10 * time.Minute),
|
||||
|
|
@ -113,7 +117,6 @@ func SSORedirect(
|
|||
func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath string) (*url.URL, error) {
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
User: req.URL.User,
|
||||
Host: req.Host,
|
||||
Path: req.URL.Path,
|
||||
}
|
||||
|
|
@ -141,7 +144,7 @@ func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath
|
|||
func SSOCallback(
|
||||
req *http.Request,
|
||||
userAPI userAPIForSSO,
|
||||
auth *sso.Authenticator,
|
||||
auth ssoAuthenticator,
|
||||
cfg *config.SSO,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -163,7 +166,7 @@ func SSOCallback(
|
|||
}
|
||||
}
|
||||
|
||||
nonce, err := req.Cookie("oidc_nonce")
|
||||
nonce, err := req.Cookie("sso_nonce")
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
@ -246,7 +249,7 @@ func SSOCallback(
|
|||
rquery.Set("loginToken", token.Token)
|
||||
resp := util.RedirectResponse(finalRedirectURL.ResolveReference(&url.URL{RawQuery: rquery.Encode()}).String())
|
||||
resp.Headers["Set-Cookie"] = (&http.Cookie{
|
||||
Name: "oidc_nonce",
|
||||
Name: "sso_nonce",
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
Secure: true,
|
||||
|
|
@ -254,6 +257,11 @@ func SSOCallback(
|
|||
return resp
|
||||
}
|
||||
|
||||
type ssoAuthenticator interface {
|
||||
AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error)
|
||||
ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error)
|
||||
}
|
||||
|
||||
type userAPIForSSO interface {
|
||||
uapi.LoginTokenInternalAPI
|
||||
|
||||
|
|
@ -273,21 +281,21 @@ func formatNonce(redirectURL string) string {
|
|||
// function. The URL is not integrity protected.
|
||||
func parseNonce(s string) (redirectURL *url.URL, _ error) {
|
||||
if s == "" {
|
||||
return nil, jsonerror.MissingArgument("empty OIDC nonce cookie")
|
||||
return nil, jsonerror.MissingArgument("empty SSO nonce cookie")
|
||||
}
|
||||
|
||||
ss := strings.Split(s, ".")
|
||||
if len(ss) < 2 {
|
||||
return nil, jsonerror.InvalidArgumentValue("malformed OIDC nonce cookie")
|
||||
return nil, jsonerror.InvalidArgumentValue("malformed SSO nonce cookie")
|
||||
}
|
||||
|
||||
urlbs, err := base64.RawURLEncoding.DecodeString(ss[1])
|
||||
if err != nil {
|
||||
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie")
|
||||
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in SSO nonce cookie")
|
||||
}
|
||||
u, err := url.Parse(string(urlbs))
|
||||
if err != nil {
|
||||
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie: " + err.Error())
|
||||
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in SSO nonce cookie: " + err.Error())
|
||||
}
|
||||
|
||||
return u, nil
|
||||
|
|
@ -309,6 +317,9 @@ func verifySSOUserIdentifier(ctx context.Context, userAPI userAPIForSSO, id *sso
|
|||
return res.Localpart, nil
|
||||
}
|
||||
|
||||
// registerSSOAccount creates an account and associates the SSO
|
||||
// identifier with it. Note that SSO login account creation doesn't
|
||||
// use the standard registration API, but happens ad-hoc.
|
||||
func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.UserIdentifier, localpart string) (bool, util.JSONResponse) {
|
||||
var accRes uapi.PerformAccountCreationResponse
|
||||
err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||
|
|
@ -347,6 +358,8 @@ func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.U
|
|||
return true, util.JSONResponse{}
|
||||
}
|
||||
|
||||
// createLoginToken produces a new login token, valid for the given
|
||||
// user.
|
||||
func createLoginToken(ctx context.Context, userAPI userAPIForSSO, userID string) (*uapi.LoginTokenMetadata, error) {
|
||||
req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: userID}}
|
||||
var resp uapi.PerformLoginTokenCreationResponse
|
||||
|
|
|
|||
531
clientapi/routing/sso_test.go
Normal file
531
clientapi/routing/sso_test.go
Normal file
|
|
@ -0,0 +1,531 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/sso"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
func TestSSORedirect(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Req http.Request
|
||||
IDPID string
|
||||
Auth fakeSSOAuthenticator
|
||||
Config config.SSO
|
||||
|
||||
WantLocationRE string
|
||||
WantSetCookieRE string
|
||||
}{
|
||||
{
|
||||
Name: "redirectDefault",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/redirect",
|
||||
RawQuery: url.Values{
|
||||
"redirectUrl": []string{"http://example.com/continue"},
|
||||
}.Encode(),
|
||||
},
|
||||
},
|
||||
WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http%3A%2F%2Fmatrix.example.com%2F_matrix%2Fv4%2Flogin%2Fsso%2Fcallback%3Fprovider%3D&nonce=.+&providerID=`,
|
||||
WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
|
||||
},
|
||||
{
|
||||
Name: "redirectExplicitProvider",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/redirect",
|
||||
RawQuery: url.Values{
|
||||
"redirectUrl": []string{"http://example.com/continue"},
|
||||
}.Encode(),
|
||||
},
|
||||
},
|
||||
IDPID: "someprovider",
|
||||
WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http.*%3Fprovider%3Dsomeprovider&nonce=.+&providerID=someprovider`,
|
||||
WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
|
||||
},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config)
|
||||
|
||||
if want := http.StatusFound; got.Code != want {
|
||||
t.Errorf("SSORedirect Code: got %v, want %v", got.Code, want)
|
||||
}
|
||||
|
||||
if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
|
||||
t.Fatalf("WantSetCookieRE failed: %v", err)
|
||||
} else if !m {
|
||||
t.Errorf("SSORedirect Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
|
||||
}
|
||||
|
||||
if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
|
||||
t.Fatalf("WantSetCookieRE failed: %v", err)
|
||||
} else if !m {
|
||||
t.Errorf("SSORedirect Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSORedirectError(t *testing.T) {
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Req http.Request
|
||||
IDPID string
|
||||
Auth fakeSSOAuthenticator
|
||||
Config config.SSO
|
||||
|
||||
WantCode int
|
||||
}{
|
||||
{
|
||||
Name: "missingRedirectUrl",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/redirect",
|
||||
RawQuery: url.Values{}.Encode(),
|
||||
},
|
||||
},
|
||||
WantCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
Name: "invalidRedirectUrl",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/redirect",
|
||||
RawQuery: url.Values{
|
||||
"redirectUrl": []string{"/continue"},
|
||||
}.Encode(),
|
||||
},
|
||||
},
|
||||
WantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config)
|
||||
|
||||
if got.Code != tst.WantCode {
|
||||
t.Errorf("SSORedirect Code: got %v, want %v", got.Code, tst.WantCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOCallback(t *testing.T) {
|
||||
nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))
|
||||
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Req http.Request
|
||||
UserAPI fakeUserAPIForSSO
|
||||
Auth fakeSSOAuthenticator
|
||||
Config config.SSO
|
||||
|
||||
WantLocationRE string
|
||||
WantSetCookieRE string
|
||||
|
||||
WantAccountCreation []*uapi.PerformAccountCreationRequest
|
||||
WantLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
|
||||
WantSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
|
||||
WantQueryLocalpart []*uapi.QueryLocalpartForSSORequest
|
||||
}{
|
||||
{
|
||||
Name: "logIn",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/callback",
|
||||
RawQuery: url.Values{
|
||||
"provider": []string{"aprovider"},
|
||||
}.Encode(),
|
||||
},
|
||||
Header: http.Header{
|
||||
"Cookie": []string{(&http.Cookie{
|
||||
Name: "sso_nonce",
|
||||
Value: nonce,
|
||||
}).String()},
|
||||
},
|
||||
},
|
||||
UserAPI: fakeUserAPIForSSO{
|
||||
localpart: "alocalpart",
|
||||
},
|
||||
Auth: fakeSSOAuthenticator{
|
||||
callbackResult: sso.CallbackResult{
|
||||
Identifier: &sso.UserIdentifier{
|
||||
Namespace: "anamespace",
|
||||
Issuer: "anissuer",
|
||||
Subject: "asubject",
|
||||
},
|
||||
},
|
||||
},
|
||||
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
|
||||
WantSetCookieRE: "sso_nonce=;",
|
||||
|
||||
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@alocalpart:aservername"}}},
|
||||
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
|
||||
},
|
||||
{
|
||||
Name: "registerSuggested",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/callback",
|
||||
RawQuery: url.Values{
|
||||
"provider": []string{"aprovider"},
|
||||
}.Encode(),
|
||||
},
|
||||
Header: http.Header{
|
||||
"Cookie": []string{(&http.Cookie{
|
||||
Name: "sso_nonce",
|
||||
Value: nonce,
|
||||
}).String()},
|
||||
},
|
||||
},
|
||||
Auth: fakeSSOAuthenticator{
|
||||
callbackResult: sso.CallbackResult{
|
||||
Identifier: &sso.UserIdentifier{
|
||||
Namespace: "anamespace",
|
||||
Issuer: "anissuer",
|
||||
Subject: "asubject",
|
||||
},
|
||||
SuggestedUserID: "asuggestedid",
|
||||
},
|
||||
},
|
||||
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
|
||||
WantSetCookieRE: "sso_nonce=;",
|
||||
|
||||
WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "asuggestedid", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
|
||||
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@asuggestedid:aservername"}}},
|
||||
WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "asuggestedid"}},
|
||||
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
|
||||
},
|
||||
{
|
||||
Name: "registerNumeric",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/callback",
|
||||
RawQuery: url.Values{
|
||||
"provider": []string{"aprovider"},
|
||||
}.Encode(),
|
||||
},
|
||||
Header: http.Header{
|
||||
"Cookie": []string{(&http.Cookie{
|
||||
Name: "sso_nonce",
|
||||
Value: nonce,
|
||||
}).String()},
|
||||
},
|
||||
},
|
||||
Auth: fakeSSOAuthenticator{
|
||||
callbackResult: sso.CallbackResult{
|
||||
Identifier: &sso.UserIdentifier{
|
||||
Namespace: "anamespace",
|
||||
Issuer: "anissuer",
|
||||
Subject: "asubject",
|
||||
},
|
||||
},
|
||||
},
|
||||
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
|
||||
WantSetCookieRE: "sso_nonce=;",
|
||||
|
||||
WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "12345", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
|
||||
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@12345:aservername"}}},
|
||||
WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "12345"}},
|
||||
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
|
||||
},
|
||||
{
|
||||
Name: "noIdentifierRedirectURL",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/callback",
|
||||
RawQuery: url.Values{
|
||||
"provider": []string{"aprovider"},
|
||||
}.Encode(),
|
||||
},
|
||||
Header: http.Header{
|
||||
"Cookie": []string{(&http.Cookie{
|
||||
Name: "sso_nonce",
|
||||
Value: nonce,
|
||||
}).String()},
|
||||
},
|
||||
},
|
||||
Auth: fakeSSOAuthenticator{
|
||||
callbackResult: sso.CallbackResult{
|
||||
RedirectURL: "http://auth.example.com/notdone",
|
||||
},
|
||||
},
|
||||
WantLocationRE: `http://auth.example.com/notdone`,
|
||||
WantSetCookieRE: "^$",
|
||||
},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername")
|
||||
|
||||
if want := http.StatusFound; got.Code != want {
|
||||
t.Log(got)
|
||||
t.Errorf("SSOCallback Code: got %v, want %v", got.Code, want)
|
||||
}
|
||||
|
||||
if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
|
||||
t.Fatalf("WantSetCookieRE failed: %v", err)
|
||||
} else if !m {
|
||||
t.Errorf("SSOCallback Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
|
||||
}
|
||||
|
||||
if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
|
||||
t.Fatalf("WantSetCookieRE failed: %v", err)
|
||||
} else if !m {
|
||||
t.Errorf("SSOCallback Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tst.WantAccountCreation, tst.UserAPI.gotAccountCreation); diff != "" {
|
||||
t.Errorf("PerformAccountCreation: +got -want:\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tst.WantLoginTokenCreation, tst.UserAPI.gotLoginTokenCreation); diff != "" {
|
||||
t.Errorf("PerformLoginTokenCreation: +got -want:\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tst.WantSaveSSOAssociation, tst.UserAPI.gotSaveSSOAssociation); diff != "" {
|
||||
t.Errorf("PerformSaveSSOAssociation: +got -want:\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tst.WantQueryLocalpart, tst.UserAPI.gotQueryLocalpart); diff != "" {
|
||||
t.Errorf("QueryLocalpartForSSO: +got -want:\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOCallbackError(t *testing.T) {
|
||||
nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))
|
||||
goodReq := http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/callback",
|
||||
RawQuery: url.Values{
|
||||
"provider": []string{"aprovider"},
|
||||
}.Encode(),
|
||||
},
|
||||
Header: http.Header{
|
||||
"Cookie": []string{(&http.Cookie{
|
||||
Name: "sso_nonce",
|
||||
Value: nonce,
|
||||
}).String()},
|
||||
},
|
||||
}
|
||||
goodAuth := fakeSSOAuthenticator{
|
||||
callbackResult: sso.CallbackResult{
|
||||
Identifier: &sso.UserIdentifier{
|
||||
Namespace: "anamespace",
|
||||
Issuer: "anissuer",
|
||||
Subject: "asubject",
|
||||
},
|
||||
},
|
||||
}
|
||||
errMocked := errors.New("mocked error")
|
||||
|
||||
tsts := []struct {
|
||||
Name string
|
||||
Req http.Request
|
||||
UserAPI fakeUserAPIForSSO
|
||||
Auth fakeSSOAuthenticator
|
||||
Config config.SSO
|
||||
|
||||
WantCode int
|
||||
}{
|
||||
{
|
||||
Name: "missingProvider",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/callback",
|
||||
},
|
||||
Header: http.Header{
|
||||
"Cookie": []string{(&http.Cookie{
|
||||
Name: "sso_nonce",
|
||||
Value: nonce,
|
||||
}).String()},
|
||||
},
|
||||
},
|
||||
WantCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
Name: "missingCookie",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/callback",
|
||||
RawQuery: url.Values{
|
||||
"provider": []string{"aprovider"},
|
||||
}.Encode(),
|
||||
},
|
||||
},
|
||||
WantCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
Name: "malformedCookie",
|
||||
Req: http.Request{
|
||||
Host: "matrix.example.com",
|
||||
URL: &url.URL{
|
||||
Path: "/_matrix/v4/login/sso/callback",
|
||||
RawQuery: url.Values{
|
||||
"provider": []string{"aprovider"},
|
||||
}.Encode(),
|
||||
},
|
||||
Header: http.Header{
|
||||
"Cookie": []string{(&http.Cookie{
|
||||
Name: "sso_nonce",
|
||||
Value: "badvalue",
|
||||
}).String()},
|
||||
},
|
||||
},
|
||||
WantCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
Name: "failedProcessCallback",
|
||||
Req: goodReq,
|
||||
Auth: fakeSSOAuthenticator{
|
||||
callbackErr: errMocked,
|
||||
},
|
||||
WantCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
Name: "failedQueryLocalpartForSSO",
|
||||
Req: goodReq,
|
||||
UserAPI: fakeUserAPIForSSO{
|
||||
localpartErr: errMocked,
|
||||
},
|
||||
Auth: goodAuth,
|
||||
WantCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
Name: "failedQueryNumericLocalpart",
|
||||
Req: goodReq,
|
||||
UserAPI: fakeUserAPIForSSO{
|
||||
numericLocalpartErr: errMocked,
|
||||
},
|
||||
Auth: goodAuth,
|
||||
WantCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
Name: "failedAccountCreation",
|
||||
Req: goodReq,
|
||||
UserAPI: fakeUserAPIForSSO{
|
||||
accountCreationErr: errMocked,
|
||||
},
|
||||
Auth: goodAuth,
|
||||
WantCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
Name: "failedSaveSSOAssociation",
|
||||
Req: goodReq,
|
||||
UserAPI: fakeUserAPIForSSO{
|
||||
saveSSOAssociationErr: errMocked,
|
||||
},
|
||||
Auth: goodAuth,
|
||||
WantCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
Name: "failedPerformLoginTokenCreation",
|
||||
Req: goodReq,
|
||||
UserAPI: fakeUserAPIForSSO{
|
||||
localpart: "alocalpart",
|
||||
tokenTokenCreationErr: errMocked,
|
||||
},
|
||||
Auth: goodAuth,
|
||||
WantCode: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
for _, tst := range tsts {
|
||||
t.Run(tst.Name, func(t *testing.T) {
|
||||
got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername")
|
||||
|
||||
if got.Code != tst.WantCode {
|
||||
t.Log(got)
|
||||
t.Errorf("SSOCallback Code: got %v, want %v", got.Code, tst.WantCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakeSSOAuthenticator struct {
|
||||
callbackResult sso.CallbackResult
|
||||
callbackErr error
|
||||
}
|
||||
|
||||
func (auth *fakeSSOAuthenticator) AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) {
|
||||
return (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: "auth.example.com",
|
||||
Path: "/authorize",
|
||||
}).ResolveReference(&url.URL{
|
||||
RawQuery: url.Values{
|
||||
"callbackURL": []string{callbackURL},
|
||||
"nonce": []string{nonce},
|
||||
"providerID": []string{providerID},
|
||||
}.Encode(),
|
||||
}).String(), nil
|
||||
}
|
||||
|
||||
func (auth *fakeSSOAuthenticator) ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error) {
|
||||
return &auth.callbackResult, auth.callbackErr
|
||||
}
|
||||
|
||||
type fakeUserAPIForSSO struct {
|
||||
userAPIForSSO
|
||||
|
||||
accountCreationErr error
|
||||
tokenTokenCreationErr error
|
||||
saveSSOAssociationErr error
|
||||
localpart string
|
||||
localpartErr error
|
||||
numericLocalpartErr error
|
||||
|
||||
gotAccountCreation []*uapi.PerformAccountCreationRequest
|
||||
gotLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
|
||||
gotSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
|
||||
gotQueryLocalpart []*uapi.QueryLocalpartForSSORequest
|
||||
}
|
||||
|
||||
func (userAPI *fakeUserAPIForSSO) PerformAccountCreation(ctx context.Context, req *uapi.PerformAccountCreationRequest, res *uapi.PerformAccountCreationResponse) error {
|
||||
userAPI.gotAccountCreation = append(userAPI.gotAccountCreation, req)
|
||||
return userAPI.accountCreationErr
|
||||
}
|
||||
|
||||
func (userAPI *fakeUserAPIForSSO) PerformLoginTokenCreation(ctx context.Context, req *uapi.PerformLoginTokenCreationRequest, res *uapi.PerformLoginTokenCreationResponse) error {
|
||||
userAPI.gotLoginTokenCreation = append(userAPI.gotLoginTokenCreation, req)
|
||||
res.Metadata = uapi.LoginTokenMetadata{
|
||||
Token: "atoken",
|
||||
}
|
||||
return userAPI.tokenTokenCreationErr
|
||||
}
|
||||
|
||||
func (userAPI *fakeUserAPIForSSO) PerformSaveSSOAssociation(ctx context.Context, req *uapi.PerformSaveSSOAssociationRequest, res *struct{}) error {
|
||||
userAPI.gotSaveSSOAssociation = append(userAPI.gotSaveSSOAssociation, req)
|
||||
return userAPI.saveSSOAssociationErr
|
||||
}
|
||||
|
||||
func (userAPI *fakeUserAPIForSSO) QueryLocalpartForSSO(ctx context.Context, req *uapi.QueryLocalpartForSSORequest, res *uapi.QueryLocalpartForSSOResponse) error {
|
||||
userAPI.gotQueryLocalpart = append(userAPI.gotQueryLocalpart, req)
|
||||
res.Localpart = userAPI.localpart
|
||||
return userAPI.localpartErr
|
||||
}
|
||||
|
||||
func (userAPI *fakeUserAPIForSSO) QueryNumericLocalpart(ctx context.Context, res *uapi.QueryNumericLocalpartResponse) error {
|
||||
res.ID = 12345
|
||||
return userAPI.numericLocalpartErr
|
||||
}
|
||||
|
|
@ -86,6 +86,22 @@ client_api:
|
|||
recaptcha_private_key: ""
|
||||
recaptcha_bypass_secret: ""
|
||||
recaptcha_siteverify_api: ""
|
||||
login:
|
||||
sso:
|
||||
enabled: true
|
||||
callback_url: http://example.com:8071/_matrix/v3/login/sso/callback
|
||||
default_provider: github
|
||||
providers:
|
||||
- brand: github
|
||||
- id: custom
|
||||
name: "Custom Provider"
|
||||
icon: "mxc://example.com/abc123"
|
||||
type: oidc
|
||||
oauth2:
|
||||
client_id: aclientid
|
||||
client_secret: aclientsecret
|
||||
oidc:
|
||||
discovery_url: http://auth.example.com/.well-known/openid-configuration
|
||||
turn:
|
||||
turn_user_lifetime: ""
|
||||
turn_uris: []
|
||||
|
|
|
|||
|
|
@ -429,6 +429,41 @@ func Test_Pusher(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_SSO(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
t.Log("Create SSO association")
|
||||
|
||||
ns := util.RandomString(8)
|
||||
issuer := util.RandomString(8)
|
||||
subject := util.RandomString(8)
|
||||
err := db.SaveSSOAssociation(ctx, ns, issuer, subject, aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to save SSO association")
|
||||
|
||||
t.Log("Retrieve localpart for association")
|
||||
|
||||
gotLocalpart, err := db.GetLocalpartForSSO(ctx, ns, issuer, subject)
|
||||
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
||||
|
||||
t.Log("Remove SSO association")
|
||||
|
||||
err = db.RemoveSSOAssociation(ctx, ns, issuer, subject)
|
||||
assert.NoError(t, err, "unexpected error")
|
||||
|
||||
t.Log("Verify the SSO association was removed")
|
||||
|
||||
gotLocalpart, err = db.GetLocalpartForSSO(ctx, ns, issuer, subject)
|
||||
assert.NoError(t, err, "unable to get localpart for SSO subject")
|
||||
assert.Equal(t, "", gotLocalpart)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_ThreePID(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
|
|
|
|||
Loading…
Reference in a new issue