mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-07 06:03:09 -06:00
Add SSO tests.
Renames cookie oidc_nonce to sso_nonce, since it's defined in a file that doesn't know about OIDC specifically.
This commit is contained in:
parent
b8844fb1e2
commit
210ab1eef6
|
|
@ -101,6 +101,34 @@ func TestLoginFromJSONReader(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoginFromJSONReaderTokenDisabled(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
var userAPI fakeUserInternalAPI
|
||||||
|
cfg := &config.ClientAPI{
|
||||||
|
Matrix: &config.Global{
|
||||||
|
ServerName: serverName,
|
||||||
|
},
|
||||||
|
Login: config.Login{
|
||||||
|
SSO: config.SSO{
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(`{
|
||||||
|
"type": "m.login.token",
|
||||||
|
"token": "atoken",
|
||||||
|
"device_id": "adevice"
|
||||||
|
}`), &userAPI, &userAPI, cfg)
|
||||||
|
wantCode := "M_INVALID_ARGUMENT_VALUE"
|
||||||
|
if err == nil {
|
||||||
|
cleanup(ctx, nil)
|
||||||
|
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", err, wantCode)
|
||||||
|
} else if merr, ok := err.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != wantCode {
|
||||||
|
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", err, wantCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBadLoginFromJSONReader(t *testing.T) {
|
func TestBadLoginFromJSONReader(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
|
|
||||||
226
clientapi/auth/sso/oauth2_test.go
Normal file
226
clientapi/auth/sso/oauth2_test.go
Normal file
|
|
@ -0,0 +1,226 @@
|
||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOAuth2IdentityProviderAuthorizationURL(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
idp := &oauth2IdentityProvider{
|
||||||
|
cfg: &config.IdentityProvider{
|
||||||
|
OAuth2: config.OAuth2{
|
||||||
|
ClientID: "aclientid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hc: http.DefaultClient,
|
||||||
|
|
||||||
|
authorizationURL: "https://oauth2.example.com/authorize",
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := idp.AuthorizationURL(ctx, "https://matrix.example.com/continue", "anonce")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AuthorizationURL failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if want := "https://oauth2.example.com/authorize?client_id=aclientid&redirect_uri=https%3A%2F%2Fmatrix.example.com%2Fcontinue&response_type=code&scope=&state=anonce"; got != want {
|
||||||
|
t.Errorf("AuthorizationURL: got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuth2IdentityProviderProcessCallback(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
const callbackURL = "https://matrix.example.com/continue"
|
||||||
|
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Query url.Values
|
||||||
|
|
||||||
|
Want *CallbackResult
|
||||||
|
WantTokenReq url.Values
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "gotEverything",
|
||||||
|
Query: url.Values{
|
||||||
|
"code": []string{"acode"},
|
||||||
|
"state": []string{"anonce"},
|
||||||
|
},
|
||||||
|
|
||||||
|
Want: &CallbackResult{
|
||||||
|
Identifier: &UserIdentifier{
|
||||||
|
Namespace: uapi.SSOIDNamespace,
|
||||||
|
Issuer: "anid",
|
||||||
|
Subject: "asub",
|
||||||
|
},
|
||||||
|
DisplayName: "aname",
|
||||||
|
SuggestedUserID: "auser",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`))
|
||||||
|
})
|
||||||
|
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_user":"auser"}`))
|
||||||
|
})
|
||||||
|
|
||||||
|
s := httptest.NewServer(mux)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
idp := &oauth2IdentityProvider{
|
||||||
|
cfg: &config.IdentityProvider{
|
||||||
|
ID: "anid",
|
||||||
|
OAuth2: config.OAuth2{
|
||||||
|
ClientID: "aclientid",
|
||||||
|
ClientSecret: "aclientsecret",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hc: s.Client(),
|
||||||
|
|
||||||
|
accessTokenURL: s.URL + "/token",
|
||||||
|
userInfoURL: s.URL + "/userinfo",
|
||||||
|
|
||||||
|
subPath: "sub",
|
||||||
|
displayNamePath: "name",
|
||||||
|
suggestedUserIDPath: "preferred_user",
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := idp.ProcessCallback(ctx, callbackURL, "anonce", tst.Query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessCallback failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, tst.Want) {
|
||||||
|
t.Errorf("ProcessCallback: got %+v, want %+v", got, tst.Want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuth2IdentityProviderGetAccessToken(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
const callbackURL = "https://matrix.example.com/continue"
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
var gotReq url.Values
|
||||||
|
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
gotReq = r.Form
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`))
|
||||||
|
})
|
||||||
|
|
||||||
|
s := httptest.NewServer(mux)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
idp := &oauth2IdentityProvider{
|
||||||
|
cfg: &config.IdentityProvider{
|
||||||
|
ID: "anid",
|
||||||
|
OAuth2: config.OAuth2{
|
||||||
|
ClientID: "aclientid",
|
||||||
|
ClientSecret: "aclientsecret",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hc: s.Client(),
|
||||||
|
|
||||||
|
accessTokenURL: s.URL + "/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := idp.getAccessToken(ctx, callbackURL, "acode")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getAccessToken failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if want := "atoken"; got != want {
|
||||||
|
t.Errorf("getAccessToken: got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantReq := url.Values{
|
||||||
|
"client_id": []string{"aclientid"},
|
||||||
|
"client_secret": []string{"aclientsecret"},
|
||||||
|
"code": []string{"acode"},
|
||||||
|
"grant_type": []string{"authorization_code"},
|
||||||
|
"redirect_uri": []string{callbackURL},
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(gotReq, wantReq) {
|
||||||
|
t.Errorf("getAccessToken request: got %+v, want %+v", gotReq, wantReq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuth2IdentityProviderGetUserInfo(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
var gotHeader http.Header
|
||||||
|
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotHeader = r.Header
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_user":"auser"}`))
|
||||||
|
})
|
||||||
|
|
||||||
|
s := httptest.NewServer(mux)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
idp := &oauth2IdentityProvider{
|
||||||
|
cfg: &config.IdentityProvider{
|
||||||
|
ID: "anid",
|
||||||
|
OAuth2: config.OAuth2{
|
||||||
|
ClientID: "aclientid",
|
||||||
|
ClientSecret: "aclientsecret",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
hc: s.Client(),
|
||||||
|
|
||||||
|
userInfoURL: s.URL + "/userinfo",
|
||||||
|
|
||||||
|
responseMimeType: "application/json",
|
||||||
|
subPath: "sub",
|
||||||
|
displayNamePath: "name",
|
||||||
|
suggestedUserIDPath: "preferred_user",
|
||||||
|
}
|
||||||
|
|
||||||
|
gotSub, gotName, gotSuggestedUser, err := idp.getUserInfo(ctx, "atoken")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getUserInfo failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if want := "asub"; gotSub != want {
|
||||||
|
t.Errorf("getUserInfo subject: got %q, want %q", gotSub, want)
|
||||||
|
}
|
||||||
|
if want := "aname"; gotName != want {
|
||||||
|
t.Errorf("getUserInfo displayName: got %q, want %q", gotName, want)
|
||||||
|
}
|
||||||
|
if want := "auser"; gotSuggestedUser != want {
|
||||||
|
t.Errorf("getUserInfo suggestedUser: got %q, want %q", gotSuggestedUser, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotHeader.Del("Accept-Encoding")
|
||||||
|
gotHeader.Del("User-Agent")
|
||||||
|
wantHeader := http.Header{
|
||||||
|
"Accept": []string{"application/json"},
|
||||||
|
"Authorization": []string{"Bearer atoken"},
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(gotHeader, wantHeader) {
|
||||||
|
t.Errorf("getUserInfo header: got %+v, want %+v", gotHeader, wantHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -27,6 +27,18 @@ import (
|
||||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// oidcDiscoveryMaxStaleness indicates how stale the Discovery
|
||||||
|
// information is allowed to be. This will very rarely change, so
|
||||||
|
// we're just making sure even a Dendrite that isn't restarting often
|
||||||
|
// is picking this up eventually.
|
||||||
|
const oidcDiscoveryMaxStaleness = 24 * time.Hour
|
||||||
|
|
||||||
|
// An oidcIdentityProvider wraps OAuth2 with OpenID Connect Discovery.
|
||||||
|
//
|
||||||
|
// The SSO identifier is the "sub." A suggested UserID is grabbed from
|
||||||
|
// "preferred_username", though this isn't commonly provided.
|
||||||
|
//
|
||||||
|
// See https://openid.net/specs/openid-connect-core-1_0.html and https://openid.net/specs/openid-connect-discovery-1_0.html.
|
||||||
type oidcIdentityProvider struct {
|
type oidcIdentityProvider struct {
|
||||||
*oauth2IdentityProvider
|
*oauth2IdentityProvider
|
||||||
|
|
||||||
|
|
@ -44,7 +56,7 @@ func newOIDCIdentityProvider(cfg *config.IdentityProvider, hc *http.Client) *oid
|
||||||
scopes: []string{"openid", "profile", "email"},
|
scopes: []string{"openid", "profile", "email"},
|
||||||
responseMimeType: "application/json",
|
responseMimeType: "application/json",
|
||||||
subPath: "sub",
|
subPath: "sub",
|
||||||
emailPath: "email",
|
emailPath: "email", // TODO: should this require email_verified?
|
||||||
displayNamePath: "name",
|
displayNamePath: "name",
|
||||||
suggestedUserIDPath: "preferred_username",
|
suggestedUserIDPath: "preferred_username",
|
||||||
},
|
},
|
||||||
|
|
@ -92,7 +104,7 @@ func (p *oidcIdentityProvider) get(ctx context.Context) (*oauth2IdentityProvider
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.exp = now.Add(24 * time.Hour)
|
p.exp = now.Add(oidcDiscoveryMaxStaleness)
|
||||||
newProvider := *p.oauth2IdentityProvider
|
newProvider := *p.oauth2IdentityProvider
|
||||||
newProvider.authorizationURL = disc.AuthorizationEndpoint
|
newProvider.authorizationURL = disc.AuthorizationEndpoint
|
||||||
newProvider.accessTokenURL = disc.TokenEndpoint
|
newProvider.accessTokenURL = disc.TokenEndpoint
|
||||||
|
|
|
||||||
118
clientapi/auth/sso/oidc_test.go
Normal file
118
clientapi/auth/sso/oidc_test.go
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOIDCIdentityProviderAuthorizationURL(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/discovery", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"authorization_endpoint":"http://oidc.example.com/authorize","token_endpoint":"http://oidc.example.com/token","userinfo_endpoint":"http://oidc.example.com/userinfo","issuer":"http://oidc.example.com/"}`))
|
||||||
|
})
|
||||||
|
|
||||||
|
s := httptest.NewServer(mux)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
idp := newOIDCIdentityProvider(&config.IdentityProvider{
|
||||||
|
OAuth2: config.OAuth2{
|
||||||
|
ClientID: "aclientid",
|
||||||
|
},
|
||||||
|
OIDC: config.OIDC{
|
||||||
|
DiscoveryURL: s.URL + "/discovery",
|
||||||
|
},
|
||||||
|
}, s.Client())
|
||||||
|
|
||||||
|
got, err := idp.AuthorizationURL(ctx, "https://matrix.example.com/continue", "anonce")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AuthorizationURL failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if want := "http://oidc.example.com/authorize?client_id=aclientid&redirect_uri=https%3A%2F%2Fmatrix.example.com%2Fcontinue&response_type=code&scope=openid+profile+email&state=anonce"; got != want {
|
||||||
|
t.Errorf("AuthorizationURL: got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOIDCIdentityProviderProcessCallback(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
const callbackURL = "https://matrix.example.com/continue"
|
||||||
|
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Query url.Values
|
||||||
|
|
||||||
|
Want *CallbackResult
|
||||||
|
WantTokenReq url.Values
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "gotEverything",
|
||||||
|
Query: url.Values{
|
||||||
|
"code": []string{"acode"},
|
||||||
|
"state": []string{"anonce"},
|
||||||
|
},
|
||||||
|
|
||||||
|
Want: &CallbackResult{
|
||||||
|
Identifier: &UserIdentifier{
|
||||||
|
Namespace: uapi.OIDCNamespace,
|
||||||
|
Issuer: "http://oidc.example.com/",
|
||||||
|
Subject: "asub",
|
||||||
|
},
|
||||||
|
DisplayName: "aname",
|
||||||
|
SuggestedUserID: "auser",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
var sURL string
|
||||||
|
mux.HandleFunc("/discovery", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(fmt.Sprintf(`{"authorization_endpoint":"%s/authorize","token_endpoint":"%s/token","userinfo_endpoint":"%s/userinfo","issuer":"http://oidc.example.com/"}`,
|
||||||
|
sURL, sURL, sURL)))
|
||||||
|
})
|
||||||
|
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"access_token":"atoken", "token_type":"Bearer"}`))
|
||||||
|
})
|
||||||
|
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"sub":"asub", "name":"aname", "preferred_username":"auser"}`))
|
||||||
|
})
|
||||||
|
|
||||||
|
s := httptest.NewServer(mux)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
sURL = s.URL
|
||||||
|
idp := newOIDCIdentityProvider(&config.IdentityProvider{
|
||||||
|
OAuth2: config.OAuth2{
|
||||||
|
ClientID: "aclientid",
|
||||||
|
},
|
||||||
|
OIDC: config.OIDC{
|
||||||
|
DiscoveryURL: sURL + "/discovery",
|
||||||
|
},
|
||||||
|
}, s.Client())
|
||||||
|
|
||||||
|
got, err := idp.ProcessCallback(ctx, callbackURL, "anonce", tst.Query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessCallback failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, tst.Want) {
|
||||||
|
t.Errorf("ProcessCallback: got %+v, want %+v", got, tst.Want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -25,13 +25,19 @@ import (
|
||||||
uapi "github.com/matrix-org/dendrite/userapi/api"
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// maxHTTPTimeout is an upper bound on an HTTP request to an SSO
|
||||||
|
// backend. The individual request context deadlines are also honored.
|
||||||
|
const maxHTTPTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// An Authenticator keeps a set of identity providers and dispatches
|
||||||
|
// calls to one of them, based on configured ID.
|
||||||
type Authenticator struct {
|
type Authenticator struct {
|
||||||
providers map[string]identityProvider
|
providers map[string]identityProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthenticator(ctx context.Context, cfg *config.SSO) (*Authenticator, error) {
|
func NewAuthenticator(cfg *config.SSO) (*Authenticator, error) {
|
||||||
hc := &http.Client{
|
hc := &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: maxHTTPTimeout,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
DisableKeepAlives: true,
|
DisableKeepAlives: true,
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
|
|
||||||
76
clientapi/auth/sso/sso_test.go
Normal file
76
clientapi/auth/sso/sso_test.go
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewAuthenticator(t *testing.T) {
|
||||||
|
_, err := NewAuthenticator(&config.SSO{
|
||||||
|
Providers: []config.IdentityProvider{
|
||||||
|
{
|
||||||
|
Type: config.SSOTypeGitHub,
|
||||||
|
OAuth2: config.OAuth2{
|
||||||
|
ClientID: "aclientid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: config.SSOTypeOIDC,
|
||||||
|
OAuth2: config.OAuth2{
|
||||||
|
ClientID: "aclientid",
|
||||||
|
},
|
||||||
|
OIDC: config.OIDC{
|
||||||
|
DiscoveryURL: "http://oidc.example.com/discovery",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAuthenticator failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthenticator(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
var idp fakeIdentityProvider
|
||||||
|
a := Authenticator{
|
||||||
|
providers: map[string]identityProvider{
|
||||||
|
"fake": &idp,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("authorizationURL", func(t *testing.T) {
|
||||||
|
got, err := a.AuthorizationURL(ctx, "fake", "http://matrix.example.com/continue", "anonce")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AuthorizationURL failed: %v", err)
|
||||||
|
}
|
||||||
|
if want := "aurl"; got != want {
|
||||||
|
t.Errorf("AuthorizationURL: got %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("processCallback", func(t *testing.T) {
|
||||||
|
got, err := a.ProcessCallback(ctx, "fake", "http://matrix.example.com/continue", "anonce", url.Values{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ProcessCallback failed: %v", err)
|
||||||
|
}
|
||||||
|
if want := (&CallbackResult{DisplayName: "aname"}); !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("ProcessCallback: got %+v, want %+v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeIdentityProvider struct{}
|
||||||
|
|
||||||
|
func (idp *fakeIdentityProvider) AuthorizationURL(ctx context.Context, callbackURL, nonce string) (string, error) {
|
||||||
|
return "aurl", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (idp *fakeIdentityProvider) ProcessCallback(ctx context.Context, callbackURL, nonce string, query url.Values) (*CallbackResult, error) {
|
||||||
|
return &CallbackResult{DisplayName: "aname"}, nil
|
||||||
|
}
|
||||||
|
|
@ -73,7 +73,7 @@ func Setup(
|
||||||
var ssoAuthenticator *sso.Authenticator
|
var ssoAuthenticator *sso.Authenticator
|
||||||
if cfg.Login.SSO.Enabled {
|
if cfg.Login.SSO.Enabled {
|
||||||
var err error
|
var err error
|
||||||
ssoAuthenticator, err = sso.NewAuthenticator(ctx, &cfg.Login.SSO)
|
ssoAuthenticator, err = sso.NewAuthenticator(&cfg.Login.SSO)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Fatal("failed to create SSO authenticator")
|
logrus.WithError(err).Fatal("failed to create SSO authenticator")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ import (
|
||||||
func SSORedirect(
|
func SSORedirect(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
idpID string,
|
idpID string,
|
||||||
auth *sso.Authenticator,
|
auth ssoAuthenticator,
|
||||||
cfg *config.SSO,
|
cfg *config.SSO,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
|
|
@ -58,12 +58,16 @@ func SSORedirect(
|
||||||
JSON: jsonerror.MissingArgument("redirectUrl parameter missing"),
|
JSON: jsonerror.MissingArgument("redirectUrl parameter missing"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err := url.Parse(redirectURL)
|
if ru, err := url.Parse(redirectURL); err != nil {
|
||||||
if err != nil {
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()),
|
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()),
|
||||||
}
|
}
|
||||||
|
} else if ru.Scheme == "" || ru.Host == "" || ru.Path == "" {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + redirectURL),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
callbackURL, err := buildCallbackURLFromOther(cfg, req, "/login/sso/redirect")
|
callbackURL, err := buildCallbackURLFromOther(cfg, req, "/login/sso/redirect")
|
||||||
|
|
@ -92,7 +96,7 @@ func SSORedirect(
|
||||||
|
|
||||||
resp := util.RedirectResponse(u)
|
resp := util.RedirectResponse(u)
|
||||||
cookie := &http.Cookie{
|
cookie := &http.Cookie{
|
||||||
Name: "oidc_nonce",
|
Name: "sso_nonce",
|
||||||
Value: nonce,
|
Value: nonce,
|
||||||
Path: path.Dir(callbackURL.Path),
|
Path: path.Dir(callbackURL.Path),
|
||||||
Expires: time.Now().Add(10 * time.Minute),
|
Expires: time.Now().Add(10 * time.Minute),
|
||||||
|
|
@ -113,7 +117,6 @@ func SSORedirect(
|
||||||
func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath string) (*url.URL, error) {
|
func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath string) (*url.URL, error) {
|
||||||
u := &url.URL{
|
u := &url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
User: req.URL.User,
|
|
||||||
Host: req.Host,
|
Host: req.Host,
|
||||||
Path: req.URL.Path,
|
Path: req.URL.Path,
|
||||||
}
|
}
|
||||||
|
|
@ -141,7 +144,7 @@ func buildCallbackURLFromOther(cfg *config.SSO, req *http.Request, expectedPath
|
||||||
func SSOCallback(
|
func SSOCallback(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
userAPI userAPIForSSO,
|
userAPI userAPIForSSO,
|
||||||
auth *sso.Authenticator,
|
auth ssoAuthenticator,
|
||||||
cfg *config.SSO,
|
cfg *config.SSO,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
@ -163,7 +166,7 @@ func SSOCallback(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce, err := req.Cookie("oidc_nonce")
|
nonce, err := req.Cookie("sso_nonce")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
|
|
@ -246,7 +249,7 @@ func SSOCallback(
|
||||||
rquery.Set("loginToken", token.Token)
|
rquery.Set("loginToken", token.Token)
|
||||||
resp := util.RedirectResponse(finalRedirectURL.ResolveReference(&url.URL{RawQuery: rquery.Encode()}).String())
|
resp := util.RedirectResponse(finalRedirectURL.ResolveReference(&url.URL{RawQuery: rquery.Encode()}).String())
|
||||||
resp.Headers["Set-Cookie"] = (&http.Cookie{
|
resp.Headers["Set-Cookie"] = (&http.Cookie{
|
||||||
Name: "oidc_nonce",
|
Name: "sso_nonce",
|
||||||
Value: "",
|
Value: "",
|
||||||
MaxAge: -1,
|
MaxAge: -1,
|
||||||
Secure: true,
|
Secure: true,
|
||||||
|
|
@ -254,6 +257,11 @@ func SSOCallback(
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ssoAuthenticator interface {
|
||||||
|
AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error)
|
||||||
|
ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
type userAPIForSSO interface {
|
type userAPIForSSO interface {
|
||||||
uapi.LoginTokenInternalAPI
|
uapi.LoginTokenInternalAPI
|
||||||
|
|
||||||
|
|
@ -273,21 +281,21 @@ func formatNonce(redirectURL string) string {
|
||||||
// function. The URL is not integrity protected.
|
// function. The URL is not integrity protected.
|
||||||
func parseNonce(s string) (redirectURL *url.URL, _ error) {
|
func parseNonce(s string) (redirectURL *url.URL, _ error) {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return nil, jsonerror.MissingArgument("empty OIDC nonce cookie")
|
return nil, jsonerror.MissingArgument("empty SSO nonce cookie")
|
||||||
}
|
}
|
||||||
|
|
||||||
ss := strings.Split(s, ".")
|
ss := strings.Split(s, ".")
|
||||||
if len(ss) < 2 {
|
if len(ss) < 2 {
|
||||||
return nil, jsonerror.InvalidArgumentValue("malformed OIDC nonce cookie")
|
return nil, jsonerror.InvalidArgumentValue("malformed SSO nonce cookie")
|
||||||
}
|
}
|
||||||
|
|
||||||
urlbs, err := base64.RawURLEncoding.DecodeString(ss[1])
|
urlbs, err := base64.RawURLEncoding.DecodeString(ss[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie")
|
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in SSO nonce cookie")
|
||||||
}
|
}
|
||||||
u, err := url.Parse(string(urlbs))
|
u, err := url.Parse(string(urlbs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie: " + err.Error())
|
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in SSO nonce cookie: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
return u, nil
|
return u, nil
|
||||||
|
|
@ -309,6 +317,9 @@ func verifySSOUserIdentifier(ctx context.Context, userAPI userAPIForSSO, id *sso
|
||||||
return res.Localpart, nil
|
return res.Localpart, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// registerSSOAccount creates an account and associates the SSO
|
||||||
|
// identifier with it. Note that SSO login account creation doesn't
|
||||||
|
// use the standard registration API, but happens ad-hoc.
|
||||||
func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.UserIdentifier, localpart string) (bool, util.JSONResponse) {
|
func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.UserIdentifier, localpart string) (bool, util.JSONResponse) {
|
||||||
var accRes uapi.PerformAccountCreationResponse
|
var accRes uapi.PerformAccountCreationResponse
|
||||||
err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||||
|
|
@ -347,6 +358,8 @@ func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.U
|
||||||
return true, util.JSONResponse{}
|
return true, util.JSONResponse{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createLoginToken produces a new login token, valid for the given
|
||||||
|
// user.
|
||||||
func createLoginToken(ctx context.Context, userAPI userAPIForSSO, userID string) (*uapi.LoginTokenMetadata, error) {
|
func createLoginToken(ctx context.Context, userAPI userAPIForSSO, userID string) (*uapi.LoginTokenMetadata, error) {
|
||||||
req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: userID}}
|
req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: userID}}
|
||||||
var resp uapi.PerformLoginTokenCreationResponse
|
var resp uapi.PerformLoginTokenCreationResponse
|
||||||
|
|
|
||||||
531
clientapi/routing/sso_test.go
Normal file
531
clientapi/routing/sso_test.go
Normal file
|
|
@ -0,0 +1,531 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/sso"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSSORedirect(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Req http.Request
|
||||||
|
IDPID string
|
||||||
|
Auth fakeSSOAuthenticator
|
||||||
|
Config config.SSO
|
||||||
|
|
||||||
|
WantLocationRE string
|
||||||
|
WantSetCookieRE string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "redirectDefault",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/redirect",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"redirectUrl": []string{"http://example.com/continue"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http%3A%2F%2Fmatrix.example.com%2F_matrix%2Fv4%2Flogin%2Fsso%2Fcallback%3Fprovider%3D&nonce=.+&providerID=`,
|
||||||
|
WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "redirectExplicitProvider",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/redirect",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"redirectUrl": []string{"http://example.com/continue"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IDPID: "someprovider",
|
||||||
|
WantLocationRE: `http://auth.example.com/authorize\?callbackURL=http.*%3Fprovider%3Dsomeprovider&nonce=.+&providerID=someprovider`,
|
||||||
|
WantSetCookieRE: "sso_nonce=[^;].*Path=/_matrix/v4/login/sso",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config)
|
||||||
|
|
||||||
|
if want := http.StatusFound; got.Code != want {
|
||||||
|
t.Errorf("SSORedirect Code: got %v, want %v", got.Code, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
|
||||||
|
t.Fatalf("WantSetCookieRE failed: %v", err)
|
||||||
|
} else if !m {
|
||||||
|
t.Errorf("SSORedirect Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
|
||||||
|
t.Fatalf("WantSetCookieRE failed: %v", err)
|
||||||
|
} else if !m {
|
||||||
|
t.Errorf("SSORedirect Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSORedirectError(t *testing.T) {
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Req http.Request
|
||||||
|
IDPID string
|
||||||
|
Auth fakeSSOAuthenticator
|
||||||
|
Config config.SSO
|
||||||
|
|
||||||
|
WantCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "missingRedirectUrl",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/redirect",
|
||||||
|
RawQuery: url.Values{}.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "invalidRedirectUrl",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/redirect",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"redirectUrl": []string{"/continue"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
got := SSORedirect(&tst.Req, tst.IDPID, &tst.Auth, &tst.Config)
|
||||||
|
|
||||||
|
if got.Code != tst.WantCode {
|
||||||
|
t.Errorf("SSORedirect Code: got %v, want %v", got.Code, tst.WantCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOCallback(t *testing.T) {
|
||||||
|
nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))
|
||||||
|
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Req http.Request
|
||||||
|
UserAPI fakeUserAPIForSSO
|
||||||
|
Auth fakeSSOAuthenticator
|
||||||
|
Config config.SSO
|
||||||
|
|
||||||
|
WantLocationRE string
|
||||||
|
WantSetCookieRE string
|
||||||
|
|
||||||
|
WantAccountCreation []*uapi.PerformAccountCreationRequest
|
||||||
|
WantLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
|
||||||
|
WantSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
|
||||||
|
WantQueryLocalpart []*uapi.QueryLocalpartForSSORequest
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "logIn",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/callback",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"provider": []string{"aprovider"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
Header: http.Header{
|
||||||
|
"Cookie": []string{(&http.Cookie{
|
||||||
|
Name: "sso_nonce",
|
||||||
|
Value: nonce,
|
||||||
|
}).String()},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserAPI: fakeUserAPIForSSO{
|
||||||
|
localpart: "alocalpart",
|
||||||
|
},
|
||||||
|
Auth: fakeSSOAuthenticator{
|
||||||
|
callbackResult: sso.CallbackResult{
|
||||||
|
Identifier: &sso.UserIdentifier{
|
||||||
|
Namespace: "anamespace",
|
||||||
|
Issuer: "anissuer",
|
||||||
|
Subject: "asubject",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
|
||||||
|
WantSetCookieRE: "sso_nonce=;",
|
||||||
|
|
||||||
|
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@alocalpart:aservername"}}},
|
||||||
|
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "registerSuggested",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/callback",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"provider": []string{"aprovider"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
Header: http.Header{
|
||||||
|
"Cookie": []string{(&http.Cookie{
|
||||||
|
Name: "sso_nonce",
|
||||||
|
Value: nonce,
|
||||||
|
}).String()},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Auth: fakeSSOAuthenticator{
|
||||||
|
callbackResult: sso.CallbackResult{
|
||||||
|
Identifier: &sso.UserIdentifier{
|
||||||
|
Namespace: "anamespace",
|
||||||
|
Issuer: "anissuer",
|
||||||
|
Subject: "asubject",
|
||||||
|
},
|
||||||
|
SuggestedUserID: "asuggestedid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
|
||||||
|
WantSetCookieRE: "sso_nonce=;",
|
||||||
|
|
||||||
|
WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "asuggestedid", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
|
||||||
|
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@asuggestedid:aservername"}}},
|
||||||
|
WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "asuggestedid"}},
|
||||||
|
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "registerNumeric",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/callback",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"provider": []string{"aprovider"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
Header: http.Header{
|
||||||
|
"Cookie": []string{(&http.Cookie{
|
||||||
|
Name: "sso_nonce",
|
||||||
|
Value: nonce,
|
||||||
|
}).String()},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Auth: fakeSSOAuthenticator{
|
||||||
|
callbackResult: sso.CallbackResult{
|
||||||
|
Identifier: &sso.UserIdentifier{
|
||||||
|
Namespace: "anamespace",
|
||||||
|
Issuer: "anissuer",
|
||||||
|
Subject: "asubject",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantLocationRE: `http://matrix.example.com/continue\?loginToken=atoken`,
|
||||||
|
WantSetCookieRE: "sso_nonce=;",
|
||||||
|
|
||||||
|
WantAccountCreation: []*uapi.PerformAccountCreationRequest{{Localpart: "12345", AccountType: uapi.AccountTypeUser, OnConflict: uapi.ConflictAbort}},
|
||||||
|
WantLoginTokenCreation: []*uapi.PerformLoginTokenCreationRequest{{Data: uapi.LoginTokenData{UserID: "@12345:aservername"}}},
|
||||||
|
WantSaveSSOAssociation: []*uapi.PerformSaveSSOAssociationRequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject", Localpart: "12345"}},
|
||||||
|
WantQueryLocalpart: []*uapi.QueryLocalpartForSSORequest{{Namespace: "anamespace", Issuer: "anissuer", Subject: "asubject"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "noIdentifierRedirectURL",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/callback",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"provider": []string{"aprovider"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
Header: http.Header{
|
||||||
|
"Cookie": []string{(&http.Cookie{
|
||||||
|
Name: "sso_nonce",
|
||||||
|
Value: nonce,
|
||||||
|
}).String()},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Auth: fakeSSOAuthenticator{
|
||||||
|
callbackResult: sso.CallbackResult{
|
||||||
|
RedirectURL: "http://auth.example.com/notdone",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantLocationRE: `http://auth.example.com/notdone`,
|
||||||
|
WantSetCookieRE: "^$",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername")
|
||||||
|
|
||||||
|
if want := http.StatusFound; got.Code != want {
|
||||||
|
t.Log(got)
|
||||||
|
t.Errorf("SSOCallback Code: got %v, want %v", got.Code, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m, err := regexp.MatchString(tst.WantLocationRE, got.Headers["Location"]); err != nil {
|
||||||
|
t.Fatalf("WantSetCookieRE failed: %v", err)
|
||||||
|
} else if !m {
|
||||||
|
t.Errorf("SSOCallback Location: got %q, want match %v", got.Headers["Location"], tst.WantLocationRE)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m, err := regexp.MatchString(tst.WantSetCookieRE, got.Headers["Set-Cookie"]); err != nil {
|
||||||
|
t.Fatalf("WantSetCookieRE failed: %v", err)
|
||||||
|
} else if !m {
|
||||||
|
t.Errorf("SSOCallback Set-Cookie: got %q, want match %v", got.Headers["Set-Cookie"], tst.WantSetCookieRE)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tst.WantAccountCreation, tst.UserAPI.gotAccountCreation); diff != "" {
|
||||||
|
t.Errorf("PerformAccountCreation: +got -want:\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tst.WantLoginTokenCreation, tst.UserAPI.gotLoginTokenCreation); diff != "" {
|
||||||
|
t.Errorf("PerformLoginTokenCreation: +got -want:\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tst.WantSaveSSOAssociation, tst.UserAPI.gotSaveSSOAssociation); diff != "" {
|
||||||
|
t.Errorf("PerformSaveSSOAssociation: +got -want:\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(tst.WantQueryLocalpart, tst.UserAPI.gotQueryLocalpart); diff != "" {
|
||||||
|
t.Errorf("QueryLocalpartForSSO: +got -want:\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSOCallbackError(t *testing.T) {
|
||||||
|
nonce := "1234." + base64.RawURLEncoding.EncodeToString([]byte("http://matrix.example.com/continue"))
|
||||||
|
goodReq := http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/callback",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"provider": []string{"aprovider"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
Header: http.Header{
|
||||||
|
"Cookie": []string{(&http.Cookie{
|
||||||
|
Name: "sso_nonce",
|
||||||
|
Value: nonce,
|
||||||
|
}).String()},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
goodAuth := fakeSSOAuthenticator{
|
||||||
|
callbackResult: sso.CallbackResult{
|
||||||
|
Identifier: &sso.UserIdentifier{
|
||||||
|
Namespace: "anamespace",
|
||||||
|
Issuer: "anissuer",
|
||||||
|
Subject: "asubject",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
errMocked := errors.New("mocked error")
|
||||||
|
|
||||||
|
tsts := []struct {
|
||||||
|
Name string
|
||||||
|
Req http.Request
|
||||||
|
UserAPI fakeUserAPIForSSO
|
||||||
|
Auth fakeSSOAuthenticator
|
||||||
|
Config config.SSO
|
||||||
|
|
||||||
|
WantCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "missingProvider",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/callback",
|
||||||
|
},
|
||||||
|
Header: http.Header{
|
||||||
|
"Cookie": []string{(&http.Cookie{
|
||||||
|
Name: "sso_nonce",
|
||||||
|
Value: nonce,
|
||||||
|
}).String()},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "missingCookie",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/callback",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"provider": []string{"aprovider"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "malformedCookie",
|
||||||
|
Req: http.Request{
|
||||||
|
Host: "matrix.example.com",
|
||||||
|
URL: &url.URL{
|
||||||
|
Path: "/_matrix/v4/login/sso/callback",
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"provider": []string{"aprovider"},
|
||||||
|
}.Encode(),
|
||||||
|
},
|
||||||
|
Header: http.Header{
|
||||||
|
"Cookie": []string{(&http.Cookie{
|
||||||
|
Name: "sso_nonce",
|
||||||
|
Value: "badvalue",
|
||||||
|
}).String()},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
WantCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "failedProcessCallback",
|
||||||
|
Req: goodReq,
|
||||||
|
Auth: fakeSSOAuthenticator{
|
||||||
|
callbackErr: errMocked,
|
||||||
|
},
|
||||||
|
WantCode: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "failedQueryLocalpartForSSO",
|
||||||
|
Req: goodReq,
|
||||||
|
UserAPI: fakeUserAPIForSSO{
|
||||||
|
localpartErr: errMocked,
|
||||||
|
},
|
||||||
|
Auth: goodAuth,
|
||||||
|
WantCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "failedQueryNumericLocalpart",
|
||||||
|
Req: goodReq,
|
||||||
|
UserAPI: fakeUserAPIForSSO{
|
||||||
|
numericLocalpartErr: errMocked,
|
||||||
|
},
|
||||||
|
Auth: goodAuth,
|
||||||
|
WantCode: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "failedAccountCreation",
|
||||||
|
Req: goodReq,
|
||||||
|
UserAPI: fakeUserAPIForSSO{
|
||||||
|
accountCreationErr: errMocked,
|
||||||
|
},
|
||||||
|
Auth: goodAuth,
|
||||||
|
WantCode: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "failedSaveSSOAssociation",
|
||||||
|
Req: goodReq,
|
||||||
|
UserAPI: fakeUserAPIForSSO{
|
||||||
|
saveSSOAssociationErr: errMocked,
|
||||||
|
},
|
||||||
|
Auth: goodAuth,
|
||||||
|
WantCode: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "failedPerformLoginTokenCreation",
|
||||||
|
Req: goodReq,
|
||||||
|
UserAPI: fakeUserAPIForSSO{
|
||||||
|
localpart: "alocalpart",
|
||||||
|
tokenTokenCreationErr: errMocked,
|
||||||
|
},
|
||||||
|
Auth: goodAuth,
|
||||||
|
WantCode: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tst := range tsts {
|
||||||
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
got := SSOCallback(&tst.Req, &tst.UserAPI, &tst.Auth, &tst.Config, "aservername")
|
||||||
|
|
||||||
|
if got.Code != tst.WantCode {
|
||||||
|
t.Log(got)
|
||||||
|
t.Errorf("SSOCallback Code: got %v, want %v", got.Code, tst.WantCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeSSOAuthenticator struct {
|
||||||
|
callbackResult sso.CallbackResult
|
||||||
|
callbackErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *fakeSSOAuthenticator) AuthorizationURL(ctx context.Context, providerID, callbackURL, nonce string) (string, error) {
|
||||||
|
return (&url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "auth.example.com",
|
||||||
|
Path: "/authorize",
|
||||||
|
}).ResolveReference(&url.URL{
|
||||||
|
RawQuery: url.Values{
|
||||||
|
"callbackURL": []string{callbackURL},
|
||||||
|
"nonce": []string{nonce},
|
||||||
|
"providerID": []string{providerID},
|
||||||
|
}.Encode(),
|
||||||
|
}).String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *fakeSSOAuthenticator) ProcessCallback(ctx context.Context, providerID, callbackURL, nonce string, query url.Values) (*sso.CallbackResult, error) {
|
||||||
|
return &auth.callbackResult, auth.callbackErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeUserAPIForSSO struct {
|
||||||
|
userAPIForSSO
|
||||||
|
|
||||||
|
accountCreationErr error
|
||||||
|
tokenTokenCreationErr error
|
||||||
|
saveSSOAssociationErr error
|
||||||
|
localpart string
|
||||||
|
localpartErr error
|
||||||
|
numericLocalpartErr error
|
||||||
|
|
||||||
|
gotAccountCreation []*uapi.PerformAccountCreationRequest
|
||||||
|
gotLoginTokenCreation []*uapi.PerformLoginTokenCreationRequest
|
||||||
|
gotSaveSSOAssociation []*uapi.PerformSaveSSOAssociationRequest
|
||||||
|
gotQueryLocalpart []*uapi.QueryLocalpartForSSORequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (userAPI *fakeUserAPIForSSO) PerformAccountCreation(ctx context.Context, req *uapi.PerformAccountCreationRequest, res *uapi.PerformAccountCreationResponse) error {
|
||||||
|
userAPI.gotAccountCreation = append(userAPI.gotAccountCreation, req)
|
||||||
|
return userAPI.accountCreationErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (userAPI *fakeUserAPIForSSO) PerformLoginTokenCreation(ctx context.Context, req *uapi.PerformLoginTokenCreationRequest, res *uapi.PerformLoginTokenCreationResponse) error {
|
||||||
|
userAPI.gotLoginTokenCreation = append(userAPI.gotLoginTokenCreation, req)
|
||||||
|
res.Metadata = uapi.LoginTokenMetadata{
|
||||||
|
Token: "atoken",
|
||||||
|
}
|
||||||
|
return userAPI.tokenTokenCreationErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (userAPI *fakeUserAPIForSSO) PerformSaveSSOAssociation(ctx context.Context, req *uapi.PerformSaveSSOAssociationRequest, res *struct{}) error {
|
||||||
|
userAPI.gotSaveSSOAssociation = append(userAPI.gotSaveSSOAssociation, req)
|
||||||
|
return userAPI.saveSSOAssociationErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (userAPI *fakeUserAPIForSSO) QueryLocalpartForSSO(ctx context.Context, req *uapi.QueryLocalpartForSSORequest, res *uapi.QueryLocalpartForSSOResponse) error {
|
||||||
|
userAPI.gotQueryLocalpart = append(userAPI.gotQueryLocalpart, req)
|
||||||
|
res.Localpart = userAPI.localpart
|
||||||
|
return userAPI.localpartErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (userAPI *fakeUserAPIForSSO) QueryNumericLocalpart(ctx context.Context, res *uapi.QueryNumericLocalpartResponse) error {
|
||||||
|
res.ID = 12345
|
||||||
|
return userAPI.numericLocalpartErr
|
||||||
|
}
|
||||||
|
|
@ -62,7 +62,7 @@ global:
|
||||||
local_part: "_server"
|
local_part: "_server"
|
||||||
display_name: "Server alerts"
|
display_name: "Server alerts"
|
||||||
avatar: ""
|
avatar: ""
|
||||||
room_name: "Server Alerts"
|
room_name: "Server Alerts"
|
||||||
app_service_api:
|
app_service_api:
|
||||||
internal_api:
|
internal_api:
|
||||||
listen: http://localhost:7777
|
listen: http://localhost:7777
|
||||||
|
|
@ -86,6 +86,22 @@ client_api:
|
||||||
recaptcha_private_key: ""
|
recaptcha_private_key: ""
|
||||||
recaptcha_bypass_secret: ""
|
recaptcha_bypass_secret: ""
|
||||||
recaptcha_siteverify_api: ""
|
recaptcha_siteverify_api: ""
|
||||||
|
login:
|
||||||
|
sso:
|
||||||
|
enabled: true
|
||||||
|
callback_url: http://example.com:8071/_matrix/v3/login/sso/callback
|
||||||
|
default_provider: github
|
||||||
|
providers:
|
||||||
|
- brand: github
|
||||||
|
- id: custom
|
||||||
|
name: "Custom Provider"
|
||||||
|
icon: "mxc://example.com/abc123"
|
||||||
|
type: oidc
|
||||||
|
oauth2:
|
||||||
|
client_id: aclientid
|
||||||
|
client_secret: aclientsecret
|
||||||
|
oidc:
|
||||||
|
discovery_url: http://auth.example.com/.well-known/openid-configuration
|
||||||
turn:
|
turn:
|
||||||
turn_user_lifetime: ""
|
turn_user_lifetime: ""
|
||||||
turn_uris: []
|
turn_uris: []
|
||||||
|
|
|
||||||
|
|
@ -429,6 +429,41 @@ func Test_Pusher(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_SSO(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
t.Log("Create SSO association")
|
||||||
|
|
||||||
|
ns := util.RandomString(8)
|
||||||
|
issuer := util.RandomString(8)
|
||||||
|
subject := util.RandomString(8)
|
||||||
|
err := db.SaveSSOAssociation(ctx, ns, issuer, subject, aliceLocalpart)
|
||||||
|
assert.NoError(t, err, "unable to save SSO association")
|
||||||
|
|
||||||
|
t.Log("Retrieve localpart for association")
|
||||||
|
|
||||||
|
gotLocalpart, err := db.GetLocalpartForSSO(ctx, ns, issuer, subject)
|
||||||
|
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
||||||
|
|
||||||
|
t.Log("Remove SSO association")
|
||||||
|
|
||||||
|
err = db.RemoveSSOAssociation(ctx, ns, issuer, subject)
|
||||||
|
assert.NoError(t, err, "unexpected error")
|
||||||
|
|
||||||
|
t.Log("Verify the SSO association was removed")
|
||||||
|
|
||||||
|
gotLocalpart, err = db.GetLocalpartForSSO(ctx, ns, issuer, subject)
|
||||||
|
assert.NoError(t, err, "unable to get localpart for SSO subject")
|
||||||
|
assert.Equal(t, "", gotLocalpart)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func Test_ThreePID(t *testing.T) {
|
func Test_ThreePID(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue