From f9c6fbab69363270173cc59d6b1a60e22e5415e0 Mon Sep 17 00:00:00 2001 From: Boris Rybalkin Date: Sat, 3 Jun 2023 19:12:39 +0100 Subject: [PATCH] basic ldap authentication support --- clientapi/auth/login.go | 4 +- clientapi/auth/login_test.go | 8 + clientapi/auth/password.go | 205 +++++++++++++++++++++--- clientapi/auth/user_interactive.go | 4 +- clientapi/auth/user_interactive_test.go | 8 + clientapi/routing/key_crosssigning.go | 6 +- clientapi/routing/password.go | 4 +- clientapi/routing/routing.go | 2 +- go.mod | 3 + setup/config/config_clientapi.go | 19 ++- userapi/api/api.go | 2 + 11 files changed, 229 insertions(+), 36 deletions(-) diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go index 58a27e593..6751a3fd7 100644 --- a/clientapi/auth/login.go +++ b/clientapi/auth/login.go @@ -61,8 +61,8 @@ func LoginFromJSONReader( switch header.Type { case authtypes.LoginTypePassword: typ = &LoginTypePassword{ - GetAccountByPassword: useraccountAPI.QueryAccountByPassword, - Config: cfg, + UserAPI: useraccountAPI, + Config: cfg, } case authtypes.LoginTypeToken: typ = &LoginTypeToken{ diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index a2c2a719c..a44824254 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -292,6 +292,14 @@ func (ua *fakeUserInternalAPI) QueryAccountByPassword(ctx context.Context, req * return nil } +func (ua *fakeUserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *uapi.QueryAccountByLocalpartRequest, res *uapi.QueryAccountByLocalpartResponse) error { + return nil +} + +func (ua *fakeUserInternalAPI) PerformAccountCreation(ctx context.Context, req *uapi.PerformAccountCreationRequest, res *uapi.PerformAccountCreationResponse) error { + return nil +} + func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error { ua.DeletedTokens = append(ua.DeletedTokens, req.Token) return nil diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index fb7def024..4557e62b7 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -16,6 +16,9 @@ package auth import ( "context" + "database/sql" + "github.com/go-ldap/ldap/v3" + "github.com/google/uuid" "net/http" "strings" @@ -28,8 +31,6 @@ import ( "github.com/matrix-org/util" ) -type GetAccountByPassword func(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error - type PasswordRequest struct { Login Password string `json:"password"` @@ -37,8 +38,8 @@ type PasswordRequest struct { // LoginTypePassword implements https://matrix.org/docs/spec/client_server/r0.6.1#password-based type LoginTypePassword struct { - GetAccountByPassword GetAccountByPassword - Config *config.ClientAPI + Config *config.ClientAPI + UserAPI api.UserLoginAPI } func (t *LoginTypePassword) Name() string { @@ -59,22 +60,21 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) return login, func(context.Context, *util.JSONResponse) {}, nil } -func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { - r := req.(*PasswordRequest) - username := r.Username() - if username == "" { +func (t *LoginTypePassword) Login(ctx context.Context, request *PasswordRequest) (*Login, *util.JSONResponse) { + fullUsername := request.Username() + if fullUsername == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, JSON: spec.BadJSON("A username must be supplied."), } } - if len(r.Password) == 0 { + if len(request.Password) == 0 { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, JSON: spec.BadJSON("A password must be supplied."), } } - localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix) + username, domain, err := userutil.ParseUsernameParam(fullUsername, t.Config.Matrix) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, @@ -87,12 +87,38 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: spec.InvalidUsername("The server name is not known."), } } - // Squash username to all lowercase letters + + var account *api.Account + if t.Config.Ldap.Enabled { + isAdmin, err := t.authenticateLdap(username, request.Password) + if err != nil { + return nil, err + } + acc, err := t.getOrCreateAccount(ctx, username, domain, isAdmin) + if err != nil { + return nil, err + } + account = acc + } else { + acc, err := t.authenticateDb(ctx, username, domain, request.Password) + if err != nil { + return nil, err + } + account = acc + } + + // Set the user, so login.Username() can do the right thing + request.Identifier.User = account.UserID + request.User = account.UserID + return &request.Login, nil +} + +func (t *LoginTypePassword) authenticateDb(ctx context.Context, username string, domain spec.ServerName, password string) (*api.Account, *util.JSONResponse) { res := &api.QueryAccountByPasswordResponse{} - err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ - Localpart: strings.ToLower(localpart), + err := t.UserAPI.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ + Localpart: strings.ToLower(username), ServerName: domain, - PlaintextPassword: r.Password, + PlaintextPassword: password, }, res) if err != nil { return nil, &util.JSONResponse{ @@ -101,13 +127,11 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } - // If we couldn't find the user by the lower cased localpart, try the provided - // localpart as is. if !res.Exists { - err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ - Localpart: localpart, + err = t.UserAPI.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ + Localpart: username, ServerName: domain, - PlaintextPassword: r.Password, + PlaintextPassword: password, }, res) if err != nil { return nil, &util.JSONResponse{ @@ -115,8 +139,6 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: spec.Unknown("Unable to fetch account by password."), } } - // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows - // but that would leak the existence of the user. if !res.Exists { return nil, &util.JSONResponse{ Code: http.StatusForbidden, @@ -124,8 +146,141 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } } - // Set the user, so login.Username() can do the right thing - r.Identifier.User = res.Account.UserID - r.User = res.Account.UserID - return &r.Login, nil + return res.Account, nil +} +func (t *LoginTypePassword) authenticateLdap(username, password string) (bool, *util.JSONResponse) { + var conn *ldap.Conn + conn, err := ldap.DialURL(t.Config.Ldap.Uri) + if err != nil { + return false, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("unable to connect to ldap: " + err.Error()), + } + } + defer conn.Close() + + if t.Config.Ldap.AdminBindEnabled { + err = conn.Bind(t.Config.Ldap.AdminBindDn, t.Config.Ldap.AdminBindPassword) + if err != nil { + return false, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("unable to bind to ldap: " + err.Error()), + } + } + filter := strings.ReplaceAll(t.Config.Ldap.SearchFilter, "{username}", username) + searchRequest := ldap.NewSearchRequest( + t.Config.Ldap.BaseDn, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, + 0, 0, false, filter, []string{t.Config.Ldap.SearchAttribute}, nil, + ) + result, err := conn.Search(searchRequest) + if err != nil { + return false, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("unable to bind to search ldap: " + err.Error()), + } + } + if len(result.Entries) > 1 { + return false, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.BadJSON("'user' must be duplicated."), + } + } + if len(result.Entries) < 1 { + return false, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.BadJSON("'user' not found."), + } + } + + userDN := result.Entries[0].DN + err = conn.Bind(userDN, password) + if err != nil { + return false, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.InvalidUsername(err.Error()), + } + } + } else { + bindDn := strings.ReplaceAll(t.Config.Ldap.UserBindDn, "{username}", username) + err = conn.Bind(bindDn, password) + if err != nil { + return false, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.InvalidUsername(err.Error()), + } + } + } + + isAdmin, err := t.isLdapAdmin(conn, username) + if err != nil { + return false, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.InvalidUsername(err.Error()), + } + } + return isAdmin, nil +} + +func (t *LoginTypePassword) isLdapAdmin(conn *ldap.Conn, username string) (bool, error) { + searchRequest := ldap.NewSearchRequest( + t.Config.Ldap.AdminGroupDn, + ldap.ScopeWholeSubtree, ldap.DerefAlways, 0, 0, false, + strings.ReplaceAll(t.Config.Ldap.AdminGroupFilter, "{username}", username), + []string{t.Config.Ldap.AdminGroupAttribute}, + nil) + + sr, err := conn.Search(searchRequest) + if err != nil { + return false, err + } + + if len(sr.Entries) < 1 { + return false, nil + } + return true, nil +} + +func (t *LoginTypePassword) getOrCreateAccount(ctx context.Context, username string, domain spec.ServerName, admin bool) (*api.Account, *util.JSONResponse) { + var existing api.QueryAccountByLocalpartResponse + err := t.UserAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: username, + ServerName: domain, + }, &existing) + + if err == nil { + return existing.Account, nil + } + if err != sql.ErrNoRows { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: spec.InvalidUsername(err.Error()), + } + } + + accountType := api.AccountTypeUser + if admin { + accountType = api.AccountTypeAdmin + } + var created api.PerformAccountCreationResponse + err = t.UserAPI.PerformAccountCreation(ctx, &api.PerformAccountCreationRequest{ + AppServiceID: "ldap", + Localpart: username, + Password: uuid.New().String(), + AccountType: accountType, + OnConflict: api.ConflictAbort, + }, &created) + + if err != nil { + if _, ok := err.(*api.ErrorConflict); ok { + return nil, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.UserInUse("Desired user ID is already taken."), + } + } + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("failed to create account: " + err.Error()), + } + } + return created.Account, nil } diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 9831450cc..1481693f6 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -113,8 +113,8 @@ type UserInteractive struct { func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ - GetAccountByPassword: userAccountAPI.QueryAccountByPassword, - Config: cfg, + UserAPI: userAccountAPI, + Config: cfg, } return &UserInteractive{ Flows: []userInteractiveFlow{ diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 4003e9647..22ea5abc6 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -45,6 +45,14 @@ func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *a return nil } +func (d *fakeAccountDatabase) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) error { + return nil +} + +func (d *fakeAccountDatabase) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + return nil +} + func setup() *UserInteractive { cfg := &config.ClientAPI{ Matrix: &config.Global{ diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 6bf7c58e3..7d1541abd 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -32,7 +32,7 @@ type crossSigningRequest struct { } func UploadCrossSigningDeviceKeys( - req *http.Request, userInteractiveAuth *auth.UserInteractive, + req *http.Request, keyserverAPI api.ClientKeyAPI, device *api.Device, accountAPI api.ClientUserAPI, cfg *config.ClientAPI, ) util.JSONResponse { @@ -62,8 +62,8 @@ func UploadCrossSigningDeviceKeys( } } typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountAPI.QueryAccountByPassword, - Config: cfg, + UserAPI: accountAPI, + Config: cfg, } if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { return *authErr diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 24c52b06d..4fcadc2c6 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -73,8 +73,8 @@ func Password( // Check if the existing password is correct. typePassword := auth.LoginTypePassword{ - GetAccountByPassword: userAPI.QueryAccountByPassword, - Config: cfg, + UserAPI: userAPI, + Config: cfg, } if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { return *authErr diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 60dad5433..3bfac3d1d 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -1448,7 +1448,7 @@ func Setup( // Cross-signing device keys postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, userAPI, device, userAPI, cfg) + return UploadCrossSigningDeviceKeys(req, userAPI, device, userAPI, cfg) }) postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/go.mod b/go.mod index 5c8155281..e627f0970 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/docker/docker v24.0.9+incompatible github.com/docker/go-connections v0.4.0 github.com/getsentry/sentry-go v0.14.0 + github.com/go-ldap/ldap/v3 v3.4.4 github.com/gologme/log v1.3.0 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.3.0 @@ -57,6 +58,7 @@ require ( ) require ( + github.com/Azure/go-ntlmssp v0.0.0-20220621081337-cb9428e4ac1e // indirect github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/Microsoft/go-winio v0.5.2 // indirect github.com/RoaringBitmap/roaring v1.2.3 // indirect @@ -84,6 +86,7 @@ require ( github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 85dfe0beb..755aea460 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -56,9 +56,26 @@ type ClientAPI struct { RateLimiting RateLimiting `yaml:"rate_limiting"` MSCs *MSCs `yaml:"-"` + + Ldap Ldap `yaml:"ldap"` } -func (c *ClientAPI) Defaults(opts DefaultOpts) { +type Ldap struct { + Enabled bool `yaml:"enabled"` + Uri string `yaml:"uri"` + BaseDn string `yaml:"base_dn"` + SearchFilter string `yaml:"search_filter"` + SearchAttribute string `yaml:"search_attribute"` + AdminBindEnabled bool `yaml:"admin_bind_enabled"` + AdminBindDn string `yaml:"admin_bind_dn"` + AdminBindPassword string `yaml:"admin_bind_password"` + UserBindDn string `yaml:"user_bind_dn"` + AdminGroupDn string `yaml:"admin_group_dn"` + AdminGroupFilter string `yaml:"admin_group_filter"` + AdminGroupAttribute string `yaml:"admin_group_attribute"` +} + +func (c *ClientAPI) Defaults(_ DefaultOpts) { c.RegistrationSharedSecret = "" c.RegistrationRequiresToken = false c.RecaptchaPublicKey = "" diff --git a/userapi/api/api.go b/userapi/api/api.go index d4daec820..e63873f68 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -144,6 +144,8 @@ type QueryAcccessTokenAPI interface { type UserLoginAPI interface { QueryAccountByPassword(ctx context.Context, req *QueryAccountByPasswordRequest, res *QueryAccountByPasswordResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error + PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error } type PerformKeyBackupRequest struct {