Support for m.login.sso.

This is forked from @anandv96's #1374. Closes #1297.
This commit is contained in:
Tommie Gannert 2021-09-26 12:16:05 +02:00
parent 1d6501ae30
commit 43989aa017
11 changed files with 1062 additions and 11 deletions

View file

@ -43,6 +43,7 @@ type AccountDatabase interface {
// Look up the account matching the given localpart.
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error)
GetLocalpartForThreePID(ctx context.Context, address, medium string) (string, error)
}
// VerifyUserFromRequest authenticates the HTTP request,

View file

@ -10,5 +10,6 @@ const (
LoginTypeSharedSecret = "org.matrix.login.shared_secret"
LoginTypeRecaptcha = "m.login.recaptcha"
LoginTypeApplicationService = "m.login.application_service"
LoginTypeSSO = "m.login.sso"
LoginTypeToken = "m.login.token"
)

View file

@ -0,0 +1,37 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sso
import (
"github.com/matrix-org/dendrite/setup/config"
)
// GitHubIdentityProvider is a GitHub-flavored identity provider.
var GitHubIdentityProvider IdentityProvider = githubIdentityProvider{
baseOIDCIdentityProvider: &baseOIDCIdentityProvider{
AuthURL: mustParseURLTemplate("https://github.com/login/oauth/authorize?scope=user:email"),
AccessTokenURL: mustParseURLTemplate("https://github.com/login/oauth/access_token"),
UserInfoURL: mustParseURLTemplate("https://api.github.com/user"),
UserInfoAccept: "application/vnd.github.v3+json",
UserInfoEmailPath: "email",
UserInfoSuggestedUserIDPath: "login",
},
}
type githubIdentityProvider struct {
*baseOIDCIdentityProvider
}
func (githubIdentityProvider) DefaultBrand() string { return config.SSOBrandGitHub }

View file

@ -0,0 +1,262 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sso
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"mime"
"net/http"
"net/url"
"strings"
"text/template"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/tidwall/gjson"
)
type baseOIDCIdentityProvider struct {
AuthURL *urlTemplate
AccessTokenURL *urlTemplate
UserInfoURL *urlTemplate
UserInfoAccept string
UserInfoEmailPath string
UserInfoSuggestedUserIDPath string
}
func (p *baseOIDCIdentityProvider) AuthorizationURL(ctx context.Context, req *IdentityProviderRequest) (string, error) {
u, err := p.AuthURL.Execute(map[string]interface{}{
"Config": req.System,
"State": req.DendriteNonce,
"RedirectURI": req.CallbackURL,
}, url.Values{
"client_id": []string{req.System.OIDC.ClientID},
"response_type": []string{"code"},
"redirect_uri": []string{req.CallbackURL},
"state": []string{req.DendriteNonce},
})
if err != nil {
return "", err
}
return u.String(), nil
}
func (p *baseOIDCIdentityProvider) ProcessCallback(ctx context.Context, req *IdentityProviderRequest, values url.Values) (*CallbackResult, error) {
state := values.Get("state")
if state == "" {
return nil, jsonerror.MissingArgument("state parameter missing")
}
if state != req.DendriteNonce {
return nil, jsonerror.InvalidArgumentValue("state parameter not matching nonce")
}
if error := values.Get("error"); error != "" {
if euri := values.Get("error_uri"); euri != "" {
return &CallbackResult{RedirectURL: euri}, nil
}
desc := values.Get("error_description")
if desc == "" {
desc = error
}
switch error {
case "unauthorized_client", "access_denied":
return nil, jsonerror.Forbidden("SSO said no: " + desc)
default:
return nil, fmt.Errorf("SSO failed: %v", error)
}
}
code := values.Get("code")
if code == "" {
return nil, jsonerror.MissingArgument("code parameter missing")
}
oidcAccessToken, err := p.getOIDCAccessToken(ctx, req, code)
if err != nil {
return nil, err
}
id, userID, err := p.getUserInfo(ctx, req, oidcAccessToken)
if err != nil {
return nil, err
}
return &CallbackResult{Identifier: id, SuggestedUserID: userID}, nil
}
func (p *baseOIDCIdentityProvider) getOIDCAccessToken(ctx context.Context, req *IdentityProviderRequest, code string) (string, error) {
u, err := p.AccessTokenURL.Execute(nil, nil)
if err != nil {
return "", err
}
body := url.Values{
"grant_type": []string{"authorization_code"},
"code": []string{code},
"redirect_uri": []string{req.CallbackURL},
"client_id": []string{req.System.OIDC.ClientID},
}
hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), strings.NewReader(body.Encode()))
if err != nil {
return "", err
}
hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
hreq.Header.Set("Accept", "application/x-www-form-urlencoded")
hresp, err := http.DefaultClient.Do(hreq)
if err != nil {
return "", err
}
defer hresp.Body.Close()
ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type"))
if err != nil {
return "", err
}
if ctype != "application/json" {
return "", fmt.Errorf("expected URL encoded response, got content type %q", ctype)
}
var resp struct {
TokenType string `json:"token_type"`
AccessToken string `json:"access_token"`
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorURI string `json:"error_uri"`
}
if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil {
return "", err
}
if resp.Error != "" {
desc := resp.ErrorDescription
if desc == "" {
desc = resp.Error
}
return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc)
}
if strings.ToLower(resp.TokenType) != "bearer" {
return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType)
}
return resp.AccessToken, nil
}
func (p *baseOIDCIdentityProvider) getUserInfo(ctx context.Context, req *IdentityProviderRequest, oidcAccessToken string) (*userutil.ThirdPartyIdentifier, string, error) {
u, err := p.UserInfoURL.Execute(map[string]interface{}{
"Config": req.System,
}, nil)
if err != nil {
return nil, "", err
}
hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, "", err
}
hreq.Header.Set("Authorization", "token "+oidcAccessToken)
hreq.Header.Set("Accept", p.UserInfoAccept)
hresp, err := http.DefaultClient.Do(hreq)
if err != nil {
return nil, "", err
}
defer hresp.Body.Close()
ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type"))
if err != nil {
return nil, "", err
}
var email string
var suggestedUserID string
switch ctype {
case "application/json":
body, err := ioutil.ReadAll(hresp.Body)
if err != nil {
return nil, "", err
}
emailRes := gjson.GetBytes(body, p.UserInfoEmailPath)
if !emailRes.Exists() {
return nil, "", fmt.Errorf("no email in user info response body")
}
email = emailRes.String()
// This is optional.
userIDRes := gjson.GetBytes(body, p.UserInfoSuggestedUserIDPath)
suggestedUserID = userIDRes.String()
default:
return nil, "", fmt.Errorf("got unknown content type %q for user info", ctype)
}
if email == "" {
return nil, "", fmt.Errorf("no email address in user info")
}
return &userutil.ThirdPartyIdentifier{Medium: "email", Address: email}, suggestedUserID, nil
}
type urlTemplate struct {
base *template.Template
}
func parseURLTemplate(s string) (*urlTemplate, error) {
t, err := template.New("").Parse(s)
if err != nil {
return nil, err
}
return &urlTemplate{base: t}, nil
}
func mustParseURLTemplate(s string) *urlTemplate {
t, err := parseURLTemplate(s)
if err != nil {
panic(err)
}
return t
}
func (t *urlTemplate) Execute(params interface{}, defaultQuery url.Values) (*url.URL, error) {
var sb strings.Builder
err := t.base.Execute(&sb, params)
if err != nil {
return nil, err
}
u, err := url.Parse(sb.String())
if err != nil {
return nil, err
}
if defaultQuery != nil {
q := u.Query()
for k, vs := range defaultQuery {
if q.Get(k) == "" {
q[k] = vs
}
}
u.RawQuery = q.Encode()
}
return u, nil
}

57
clientapi/auth/sso/sso.go Normal file
View file

@ -0,0 +1,57 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sso
import (
"context"
"net/url"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config"
)
type IdentityProvider interface {
DefaultBrand() string
AuthorizationURL(context.Context, *IdentityProviderRequest) (string, error)
ProcessCallback(context.Context, *IdentityProviderRequest, url.Values) (*CallbackResult, error)
}
type IdentityProviderRequest struct {
System *config.IdentityProvider
CallbackURL string
DendriteNonce string
}
type CallbackResult struct {
RedirectURL string
Identifier *userutil.ThirdPartyIdentifier
SuggestedUserID string
}
type IdentityProviderType string
const (
TypeGitHub IdentityProviderType = config.SSOBrandGitHub
)
func GetIdentityProvider(t IdentityProviderType) IdentityProvider {
switch t {
case TypeGitHub:
return GitHubIdentityProvider
default:
return nil
}
}

View file

@ -19,6 +19,8 @@ import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/sso"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config"
@ -35,20 +37,54 @@ type loginResponse struct {
}
type flows struct {
Flows []flow `json:"flows"`
Flows []stage `json:"flows"`
}
type flow struct {
type stage struct {
Type string `json:"type"`
IdentityProviders []identityProvider `json:"identity_providers,omitempty"`
}
func passwordLogin() flows {
f := flows{}
s := flow{
Type: "m.login.password",
type identityProvider struct {
ID string `json:"id"`
Name string `json:"name"`
Brand string `json:"brand,omitempty"`
Icon string `json:"icon,omitempty"`
}
func passwordLogin() []stage {
return []stage{
{Type: authtypes.LoginTypePassword},
}
}
func ssoLogin(cfg *config.ClientAPI) []stage {
var idps []identityProvider
for _, idp := range cfg.Login.SSO.Providers {
brand := idp.Brand
if brand == "" {
typ := idp.Type
if typ == "" {
typ = idp.ID
}
idpType := sso.GetIdentityProvider(sso.IdentityProviderType(typ))
if idpType != nil {
brand = idpType.DefaultBrand()
}
}
idps = append(idps, identityProvider{
ID: idp.ID,
Name: idp.Name,
Brand: brand,
Icon: idp.Icon,
})
}
return []stage{
{
Type: authtypes.LoginTypeSSO,
IdentityProviders: idps,
},
}
f.Flows = append(f.Flows, s)
return f
}
// Login implements GET and POST /login
@ -57,10 +93,13 @@ func Login(
cfg *config.ClientAPI,
) util.JSONResponse {
if req.Method == http.MethodGet {
// TODO: support other forms of login other than password, depending on config options
allFlows := passwordLogin()
if cfg.Login.SSO.Enabled {
allFlows = append(allFlows, ssoLogin(cfg)...)
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: passwordLogin(),
JSON: flows{Flows: allFlows},
}
} else if req.Method == http.MethodPost {
login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, userAPI, userAPI, cfg)
@ -72,6 +111,7 @@ func Login(
cleanup(req.Context(), &authErr2)
return authErr2
}
return util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),

View file

@ -563,6 +563,25 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
v3mux.Handle("/login/sso/callback",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
return SSOCallback(req, userAPI, cfg)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/login/sso/redirect",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
return SSORedirect(req, "", cfg)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/login/sso/redirect/{idpID}",
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return SSORedirect(req, vars["idpID"], cfg)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
vars := mux.Vars(req)

283
clientapi/routing/sso.go Normal file
View file

@ -0,0 +1,283 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"encoding/base64"
"net/http"
"net/url"
"strings"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/sso"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// SSORedirect implements /login/sso/redirect
// https://spec.matrix.org/v1.2/client-server-api/#redirecting-to-the-authentication-server
func SSORedirect(
req *http.Request,
idpID string,
cfg *config.ClientAPI,
) util.JSONResponse {
if !cfg.Login.SSO.Enabled {
return util.JSONResponse{
Code: http.StatusNotImplemented,
JSON: jsonerror.NotFound("authentication method disabled"),
}
}
redirectURL := req.URL.Query().Get("redirectUrl")
if redirectURL == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("redirectUrl parameter missing"),
}
}
_, err := url.Parse(redirectURL)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()),
}
}
if idpID == "" {
// Check configuration if the client didn't provide an ID.
idpID = cfg.Login.SSO.DefaultProviderID
}
if idpID == "" && len(cfg.Login.SSO.Providers) > 0 {
// Fall back to the first provider. If there are no providers, getProvider("") will fail.
idpID = cfg.Login.SSO.Providers[0].ID
}
idpCfg, idpType := getProvider(cfg, idpID)
if idpType == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("unknown identity provider"),
}
}
idpReq := &sso.IdentityProviderRequest{
System: idpCfg,
CallbackURL: req.URL.ResolveReference(&url.URL{Path: "../callback", RawQuery: url.Values{"provider": []string{idpID}}.Encode()}).String(),
DendriteNonce: formatNonce(redirectURL),
}
u, err := idpType.AuthorizationURL(req.Context(), idpReq)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: err,
}
}
resp := util.RedirectResponse(u)
resp.Headers["Set-Cookie"] = (&http.Cookie{
Name: "oidc_nonce",
Value: idpReq.DendriteNonce,
Expires: time.Now().Add(10 * time.Minute),
Secure: true,
SameSite: http.SameSiteStrictMode,
}).String()
return resp
}
// SSOCallback implements /login/sso/callback.
// https://spec.matrix.org/v1.2/client-server-api/#handling-the-callback-from-the-authentication-server
func SSOCallback(
req *http.Request,
userAPI userAPIForSSO,
cfg *config.ClientAPI,
) util.JSONResponse {
ctx := req.Context()
query := req.URL.Query()
idpID := query.Get("provider")
if idpID == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("provider parameter missing"),
}
}
idpCfg, idpType := getProvider(cfg, idpID)
if idpType == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("unknown identity provider"),
}
}
nonce, err := req.Cookie("oidc_nonce")
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("no nonce cookie: " + err.Error()),
}
}
finalRedirectURL, err := parseNonce(nonce.Value)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: err,
}
}
idpReq := &sso.IdentityProviderRequest{
System: idpCfg,
CallbackURL: (&url.URL{
Scheme: req.URL.Scheme,
Host: req.URL.Host,
Path: req.URL.Path,
RawQuery: url.Values{
"provider": []string{idpID},
}.Encode(),
}).String(),
DendriteNonce: nonce.Value,
}
result, err := idpType.ProcessCallback(ctx, idpReq, query)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: err,
}
}
if result.Identifier == nil {
// Not authenticated yet.
return util.RedirectResponse(result.RedirectURL)
}
id, err := verifyThirdPartyUserIdentifier(ctx, userAPI, result.Identifier, cfg.Matrix.ServerName)
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier.String()).Error("failed to find user")
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.Forbidden("ID not associated with a local account"),
}
}
if id == nil {
// The user doesn't exist.
// TODO: let the user select a localpart and register an account.
util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier.String()).Error("failed to find user")
return util.JSONResponse{
Code: http.StatusNotImplemented,
JSON: jsonerror.Forbidden("SSO registration not implemented"),
}
}
token, err := createLoginToken(ctx, userAPI, id)
if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("PerformLoginTokenCreation failed")
return jsonerror.InternalServerError()
}
rquery := finalRedirectURL.Query()
rquery.Set("loginToken", token.Token)
resp := util.RedirectResponse(finalRedirectURL.ResolveReference(&url.URL{RawQuery: rquery.Encode()}).String())
resp.Headers["Set-Cookie"] = (&http.Cookie{
Name: "oidc_nonce",
Value: "",
MaxAge: -1,
Secure: true,
}).String()
return resp
}
type userAPIForSSO interface {
uapi.LoginTokenInternalAPI
QueryLocalpartForThreePID(ctx context.Context, req *uapi.QueryLocalpartForThreePIDRequest, res *uapi.QueryLocalpartForThreePIDResponse) error
}
// getProvider looks up the given provider in the
// configuration. Returns nil if it wasn't found or was of unknown
// type.
func getProvider(cfg *config.ClientAPI, id string) (*config.IdentityProvider, sso.IdentityProvider) {
for _, idp := range cfg.Login.SSO.Providers {
if idp.ID == id {
switch sso.IdentityProviderType(id) {
case sso.TypeGitHub:
return &idp, sso.GitHubIdentityProvider
default:
return nil, nil
}
}
}
return nil, nil
}
// formatNonce creates a random nonce that also contains the URL.
func formatNonce(redirectURL string) string {
return util.RandomString(16) + "." + base64.RawURLEncoding.EncodeToString([]byte(redirectURL))
}
// parseNonce extracts the embedded URL from the nonce. The nonce
// should have been validated to be the original before calling this
// function. The URL is not integrity protected.
func parseNonce(s string) (redirectURL *url.URL, _ error) {
if s == "" {
return nil, jsonerror.MissingArgument("empty OIDC nonce cookie")
}
ss := strings.Split(s, ".")
if len(ss) < 2 {
return nil, jsonerror.InvalidArgumentValue("malformed OIDC nonce cookie")
}
urlbs, err := base64.RawURLEncoding.DecodeString(ss[1])
if err != nil {
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie")
}
u, err := url.Parse(string(urlbs))
if err != nil {
return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie: " + err.Error())
}
return u, nil
}
// verifyThirdPartyUserIdentifier resolves a ThirdPartyIdentifier to a
// UserIdentifier using the User API. Returns nil if there is no
// associated user.
func verifyThirdPartyUserIdentifier(ctx context.Context, userAPI userAPIForSSO, id *userutil.ThirdPartyIdentifier, serverName gomatrixserverlib.ServerName) (*userutil.UserIdentifier, error) {
req := &uapi.QueryLocalpartForThreePIDRequest{
ThreePID: id.Address,
Medium: string(id.Medium),
}
var res uapi.QueryLocalpartForThreePIDResponse
if err := userAPI.QueryLocalpartForThreePID(ctx, req, &res); err != nil {
return nil, err
}
if res.Localpart == "" {
return nil, nil
}
return &userutil.UserIdentifier{UserID: userutil.MakeUserID(res.Localpart, serverName)}, nil
}
func createLoginToken(ctx context.Context, userAPI userAPIForSSO, id *userutil.UserIdentifier) (*uapi.LoginTokenMetadata, error) {
req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: id.UserID}}
var resp uapi.PerformLoginTokenCreationResponse
if err := userAPI.PerformLoginTokenCreation(ctx, &req, &resp); err != nil {
return nil, err
}
return &resp.Metadata, nil
}

View file

@ -0,0 +1,153 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package userutil
import (
"bytes"
"encoding/json"
"errors"
)
// An Identifier identifies a user. There are many kinds, and this is
// the common interface for them.
//
// If you need to handle an identifier as JSON, use the AnyIdentifier wrapper.
// Passing around identifiers in code, the raw Identifier is enough.
//
// See https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
type Identifier interface {
// IdentifierType returns the identifier type, like "m.id.user".
IdentifierType() IdentifierType
// String returns a debug-output string representation. The format
// is not specified.
String() string
}
// A UserIdentifier contains an MXID. It may be only the local part.
type UserIdentifier struct {
UserID string `json:"user"`
}
func (i *UserIdentifier) IdentifierType() IdentifierType { return IdentifierUser }
func (i *UserIdentifier) String() string { return i.UserID }
// A ThirdPartyIdentifier references an identifier in another system.
type ThirdPartyIdentifier struct {
// Medium is normally MediumEmail.
Medium Medium `json:"medium"`
// Address is the medium-specific identifier.
Address string `json:"address"`
}
func (i *ThirdPartyIdentifier) IdentifierType() IdentifierType { return IdentifierThirdParty }
func (i *ThirdPartyIdentifier) String() string { return string(i.Medium) + ":" + i.Address }
// A PhoneIdentifier references a phone number.
type PhoneIdentifier struct {
// Country is a ISO-3166-1 alpha-2 country code.
Country string `json:"country"`
// PhoneNumber is a country-specific phone number, as it would be dialled from.
PhoneNumber string `json:"phone"`
}
func (i *PhoneIdentifier) IdentifierType() IdentifierType { return IdentifierPhone }
func (i *PhoneIdentifier) String() string { return i.Country + ":" + i.PhoneNumber }
// UnknownIdentifier is the catch-all for identifiers this code doesn't know about.
// It simply stores raw JSON.
type UnknownIdentifier struct {
json.RawMessage
Type IdentifierType
}
func (i *UnknownIdentifier) IdentifierType() IdentifierType { return i.Type }
func (i *UnknownIdentifier) String() string { return "unknown/" + string(i.Type) }
// AnyIdentifier is a wrapper that allows marshalling and unmarshalling the various
// types of identifiers to/from JSON. Always use this in data types that will be
// used in JSON manipulation.
type AnyIdentifier struct {
Identifier
}
func (i AnyIdentifier) MarshalJSON() ([]byte, error) {
v := struct {
*UserIdentifier
*ThirdPartyIdentifier
*PhoneIdentifier
Type IdentifierType `json:"type"`
}{
Type: i.Identifier.IdentifierType(),
}
switch iid := i.Identifier.(type) {
case *UserIdentifier:
v.UserIdentifier = iid
case *ThirdPartyIdentifier:
v.ThirdPartyIdentifier = iid
case *PhoneIdentifier:
v.PhoneIdentifier = iid
case *UnknownIdentifier:
return iid.RawMessage, nil
}
return json.Marshal(v)
}
func (i *AnyIdentifier) UnmarshalJSON(bs []byte) error {
var hdr struct {
Type IdentifierType `json:"type"`
}
if err := json.Unmarshal(bs, &hdr); err != nil {
return err
}
switch hdr.Type {
case IdentifierUser:
var ui UserIdentifier
i.Identifier = &ui
return json.Unmarshal(bs, &ui)
case IdentifierThirdParty:
var tpi ThirdPartyIdentifier
i.Identifier = &tpi
return json.Unmarshal(bs, &tpi)
case IdentifierPhone:
var pi PhoneIdentifier
i.Identifier = &pi
return json.Unmarshal(bs, &pi)
case "":
return errors.New("missing identifier type")
default:
i.Identifier = &UnknownIdentifier{RawMessage: json.RawMessage(bytes.TrimSpace(bs)), Type: hdr.Type}
return nil
}
}
// IdentifierType describes the type of identifier.
type IdentifierType string
const (
IdentifierUser IdentifierType = "m.id.user"
IdentifierThirdParty IdentifierType = "m.id.thirdparty"
IdentifierPhone IdentifierType = "m.id.phone"
)
// Medium describes the interpretation of a third-party identifier.
type Medium string
const (
// MediumEmail signifies that the address is an email address.
MediumEmail Medium = "email"
)

View file

@ -0,0 +1,75 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package userutil
import (
"encoding/json"
"reflect"
"testing"
)
func TestAnyIdentifierJSON(t *testing.T) {
tsts := []struct {
Name string
JSON string
Want Identifier
}{
{Name: "empty", JSON: `{}`},
{Name: "user", JSON: `{"type":"m.id.user","user":"auser"}`, Want: &UserIdentifier{UserID: "auser"}},
{Name: "thirdparty", JSON: `{"type":"m.id.thirdparty","medium":"email","address":"auser@example.com"}`, Want: &ThirdPartyIdentifier{Medium: "email", Address: "auser@example.com"}},
{Name: "phone", JSON: `{"type":"m.id.phone","country":"GB","phone":"123456789"}`, Want: &PhoneIdentifier{Country: "GB", PhoneNumber: "123456789"}},
// This test is a little fragile since it compares the output of json.Marshal.
{Name: "unknown", JSON: `{"type":"other"}`, Want: &UnknownIdentifier{Type: "other", RawMessage: json.RawMessage(`{"type":"other"}`)}},
}
for _, tst := range tsts {
t.Run("Unmarshal/"+tst.Name, func(t *testing.T) {
var got AnyIdentifier
if err := json.Unmarshal([]byte(tst.JSON), &got); err != nil {
if tst.Want == nil {
return
}
t.Fatalf("Unmarshal failed: %v", err)
}
if !reflect.DeepEqual(got.Identifier, tst.Want) {
t.Errorf("got %+v, want %+v", got.Identifier, tst.Want)
}
})
if tst.Want == nil {
continue
}
t.Run("Marshal/"+tst.Name, func(t *testing.T) {
id := AnyIdentifier{Identifier: tst.Want}
bs, err := json.Marshal(id)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
t.Logf("Marshalled JSON: %q", string(bs))
var got AnyIdentifier
if err := json.Unmarshal(bs, &got); err != nil {
if tst.Want == nil {
return
}
t.Fatalf("Unmarshal failed: %v", err)
}
if !reflect.DeepEqual(got.Identifier, tst.Want) {
t.Errorf("got %+v, want %+v", got.Identifier, tst.Want)
}
})
}
}

View file

@ -42,6 +42,8 @@ type ClientAPI struct {
// was successful
RecaptchaSiteVerifyAPI string `yaml:"recaptcha_siteverify_api"`
Login Login `yaml:"login"`
// TURN options
TURN TURN `yaml:"turn"`
@ -64,9 +66,11 @@ func (c *ClientAPI) Defaults(generate bool) {
c.RegistrationDisabled = true
c.OpenRegistrationWithoutVerificationEnabled = false
c.RateLimiting.Defaults()
c.Login.SSO.Enabled = false
}
func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
c.Login.Verify(configErrs)
c.TURN.Verify(configErrs)
c.RateLimiting.Verify(configErrs)
if c.RecaptchaEnabled {
@ -95,6 +99,125 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkURL(configErrs, "client_api.external_api.listen", string(c.ExternalAPI.Listen))
}
type Login struct {
SSO SSO `yaml:"sso"`
}
func (l *Login) Verify(configErrs *ConfigErrors) {
l.SSO.Verify(configErrs)
}
type SSO struct {
// Enabled determines whether SSO should be allowed.
Enabled bool `yaml:"enabled"`
// Providers list the identity providers this server is capable of confirming an
// identity with.
Providers []IdentityProvider `yaml:"providers"`
// DefaultProviderID is the provider to use when the client doesn't indicate one.
// This is legacy support. If empty, the first provider listed is used.
DefaultProviderID string `yaml:"default_provider"`
}
func (sso *SSO) Verify(configErrs *ConfigErrors) {
var foundDefaultProvider bool
seenPIDs := make(map[string]bool, len(sso.Providers))
for _, p := range sso.Providers {
p.Verify(configErrs)
if p.ID == sso.DefaultProviderID {
foundDefaultProvider = true
}
if seenPIDs[p.ID] {
configErrs.Add(fmt.Sprintf("duplicate identity provider for config key %q: %s", "client_api.sso.providers", p.ID))
}
seenPIDs[p.ID] = true
}
if sso.DefaultProviderID != "" && !foundDefaultProvider {
configErrs.Add(fmt.Sprintf("identity provider ID not found for config key %q: %s", "client_api.sso.default_provider", sso.DefaultProviderID))
}
if sso.Enabled {
if len(sso.Providers) == 0 {
configErrs.Add(fmt.Sprintf("empty list for config key %q", "client_api.sso.providers"))
}
}
}
// See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md.
type IdentityProvider struct {
// ID is the unique identifier of this IdP. We use the brand identifiers as provider
// identifiers for simplicity.
ID string `yaml:"id"`
// Name is a human-friendly name of the provider.
Name string `yaml:"name"`
// Brand is a hint on how to display the IdP to the user. If this is empty, a default
// based on the type is used.
Brand string `yaml:"brand"`
// Icon is an MXC URI describing how to display the IdP to the user. Prefer using `brand`.
Icon string `yaml:"icon"`
// Type describes how this provider is implemented. It must match "github". If this is
// empty, the ID is used, which means there is a weak expectation that ID is also a
// valid type, unless you have a complicated setup.
Type string `yaml:"type"`
// OIDC contains settings for providers based on OpenID Connect (OAuth 2).
OIDC struct {
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
} `yaml:"oidc"`
}
func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) {
checkNotEmpty(configErrs, "client_api.sso.providers.id", idp.ID)
if !checkIdentityProviderBrand(idp.ID) {
configErrs.Add(fmt.Sprintf("unrecognized ID config key %q: %s", "client_api.sso.providers", idp.ID))
}
checkNotEmpty(configErrs, "client_api.sso.providers.name", idp.Name)
if idp.Brand != "" && !checkIdentityProviderBrand(idp.Brand) {
configErrs.Add(fmt.Sprintf("unrecognized brand in identity provider %q for config key %q: %s", idp.ID, "client_api.sso.providers", idp.Brand))
}
if idp.Icon != "" {
checkURL(configErrs, "client_api.sso.providers.icon", idp.Icon)
}
typ := idp.Type
if idp.Type == "" {
typ = idp.ID
}
switch typ {
case "github":
checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_id", idp.OIDC.ClientID)
checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_secret", idp.OIDC.ClientSecret)
default:
configErrs.Add(fmt.Sprintf("unrecognized type in identity provider %q for config key %q: %s", idp.ID, "client_api.sso.providers", typ))
}
}
// See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md.
func checkIdentityProviderBrand(s string) bool {
switch s {
case SSOBrandApple, SSOBrandFacebook, SSOBrandGitHub, SSOBrandGitLab, SSOBrandGoogle, SSOBrandTwitter:
return true
default:
return false
}
}
const (
SSOBrandApple = "apple"
SSOBrandFacebook = "facebook"
SSOBrandGitHub = "github"
SSOBrandGitLab = "gitlab"
SSOBrandGoogle = "google"
SSOBrandTwitter = "twitter"
)
type TURN struct {
// TODO Guest Support
// Whether or not guests can request TURN credentials