From a6b13a703d2eee554b7c149af3b382aa73ed8e53 Mon Sep 17 00:00:00 2001 From: Tommie Gannert Date: Mon, 23 May 2022 18:14:45 +0200 Subject: [PATCH] Add automatic registration of SSO accounts. --- clientapi/routing/sso.go | 77 +++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 20 deletions(-) diff --git a/clientapi/routing/sso.go b/clientapi/routing/sso.go index 17c95a8ca..d1b224fbe 100644 --- a/clientapi/routing/sso.go +++ b/clientapi/routing/sso.go @@ -165,7 +165,7 @@ func SSOCallback( return util.RedirectResponse(result.RedirectURL) } - id, err := verifySSOUserIdentifier(ctx, userAPI, result.Identifier, cfg.Matrix.ServerName) + localpart, err := verifySSOUserIdentifier(ctx, userAPI, result.Identifier, cfg.Matrix.ServerName) if err != nil { util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier).Error("failed to find user") return util.JSONResponse{ @@ -173,17 +173,18 @@ func SSOCallback( JSON: jsonerror.Forbidden("ID not associated with a local account"), } } - if id == nil { + if localpart == "" { // The user doesn't exist. - // TODO: let the user select a localpart and register an account. - util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier).Error("failed to find user") - return util.JSONResponse{ - Code: http.StatusNotImplemented, - JSON: jsonerror.Forbidden("SSO registration not implemented"), + // TODO: let the user select the local part, and whether to associate email addresses. + localpart = result.SuggestedUserID + ok, resp := registerSSOAccount(ctx, userAPI, result.Identifier, localpart) + if !ok { + util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier).WithField("localpart", localpart).Error("failed to create account") + return resp } } - token, err := createLoginToken(ctx, userAPI, id) + token, err := createLoginToken(ctx, userAPI, userutil.MakeUserID(localpart, cfg.Matrix.ServerName)) if err != nil { util.GetLogger(ctx).WithError(err).Errorf("PerformLoginTokenCreation failed") return jsonerror.InternalServerError() @@ -204,6 +205,8 @@ func SSOCallback( type userAPIForSSO interface { uapi.LoginTokenInternalAPI + PerformAccountCreation(ctx context.Context, req *uapi.PerformAccountCreationRequest, res *uapi.PerformAccountCreationResponse) error + PerformSaveSSOAssociation(ctx context.Context, req *uapi.PerformSaveSSOAssociationRequest, res *struct{}) error QueryLocalpartForSSO(ctx context.Context, req *uapi.QueryLocalpartForSSORequest, res *uapi.QueryLocalpartForSSOResponse) error } @@ -254,10 +257,10 @@ func parseNonce(s string) (redirectURL *url.URL, _ error) { return u, nil } -// verifySSOUserIdentifier resolves an sso.UserIdentifier to a -// UserIdentifier using the User API. Returns nil if there is no -// associated user. -func verifySSOUserIdentifier(ctx context.Context, userAPI userAPIForSSO, id *sso.UserIdentifier, serverName gomatrixserverlib.ServerName) (*userutil.UserIdentifier, error) { +// verifySSOUserIdentifier resolves an sso.UserIdentifier to a local +// part using the User API. Returns empty if there is no associated +// user. +func verifySSOUserIdentifier(ctx context.Context, userAPI userAPIForSSO, id *sso.UserIdentifier, serverName gomatrixserverlib.ServerName) (localpart string, _ error) { req := &uapi.QueryLocalpartForSSORequest{ Namespace: id.Namespace, Issuer: id.Issuer, @@ -265,17 +268,51 @@ func verifySSOUserIdentifier(ctx context.Context, userAPI userAPIForSSO, id *sso } var res uapi.QueryLocalpartForSSOResponse if err := userAPI.QueryLocalpartForSSO(ctx, req, &res); err != nil { - return nil, err + return "", err } - if res.Localpart == "" { - return nil, nil - } - - return &userutil.UserIdentifier{UserID: userutil.MakeUserID(res.Localpart, serverName)}, nil + return res.Localpart, nil } -func createLoginToken(ctx context.Context, userAPI userAPIForSSO, id *userutil.UserIdentifier) (*uapi.LoginTokenMetadata, error) { - req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: id.UserID}} +func registerSSOAccount(ctx context.Context, userAPI userAPIForSSO, ssoID *sso.UserIdentifier, localpart string) (bool, util.JSONResponse) { + var accRes uapi.PerformAccountCreationResponse + err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + Localpart: localpart, + AccountType: uapi.AccountTypeUser, + OnConflict: uapi.ConflictAbort, + }, &accRes) + if err != nil { + if _, ok := err.(*uapi.ErrorConflict); ok { + return false, util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UserInUse("Desired user ID is already taken."), + } + } + return false, util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("failed to create account: " + err.Error()), + } + } + + amtRegUsers.Inc() + + err = userAPI.PerformSaveSSOAssociation(ctx, &uapi.PerformSaveSSOAssociationRequest{ + Namespace: ssoID.Namespace, + Issuer: ssoID.Issuer, + Subject: ssoID.Subject, + Localpart: localpart, + }, &struct{}{}) + if err != nil { + return false, util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("failed to associate SSO credentials with account: " + err.Error()), + } + } + + return true, util.JSONResponse{} +} + +func createLoginToken(ctx context.Context, userAPI userAPIForSSO, userID string) (*uapi.LoginTokenMetadata, error) { + req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: userID}} var resp uapi.PerformLoginTokenCreationResponse if err := userAPI.PerformLoginTokenCreation(ctx, &req, &resp); err != nil { return nil, err