mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-23 14:53:10 -06:00
LoginToken and SSO Login initial changes
Database level changes not made
This commit is contained in:
parent
913020e4b7
commit
9dc798c5e4
209
clientapi/auth/login_token.go
Normal file
209
clientapi/auth/login_token.go
Normal file
|
|
@ -0,0 +1,209 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This file handles all the m.login.token logic
|
||||||
|
|
||||||
|
// GetAccountByLocalpart function implemented by the appropriate database type
|
||||||
|
type GetAccountByLocalpart func(ctx context.Context, localpart string) (*api.Account, error)
|
||||||
|
|
||||||
|
// LoginTokenRequest struct to hold the possible parameters from an m.login.token http request
|
||||||
|
type LoginTokenRequest struct {
|
||||||
|
Login
|
||||||
|
Token string `json:"token"`
|
||||||
|
TxnID string `json:"txn_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginTypeToken holds the configs and the appropriate GetAccountByLocalpart function for the database
|
||||||
|
type LoginTypeToken struct {
|
||||||
|
GetAccountByLocalpart GetAccountByLocalpart
|
||||||
|
Config *config.ClientAPI
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the expected type of "m.login.token"
|
||||||
|
func (t *LoginTypeToken) Name() string {
|
||||||
|
return "m.login.token"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request returns a struct of type LoginTokenRequest
|
||||||
|
func (t *LoginTypeToken) Request() interface{} {
|
||||||
|
return &LoginTokenRequest{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type of the LoginToken
|
||||||
|
type loginToken struct {
|
||||||
|
UserID string
|
||||||
|
CreationTime int64
|
||||||
|
RandomPart string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login completes the whole token validation, user verification for m.login.token
|
||||||
|
// returns a struct of type *auth.Login which has the users details
|
||||||
|
func (t *LoginTypeToken) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
|
||||||
|
r := req.(*LoginTokenRequest)
|
||||||
|
userID, err := validateLoginToken(r.Token, r.TxnID, &t.Config.Matrix.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.InvalidArgumentValue(err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.Login.Identifier.User = userID
|
||||||
|
r.Login.Identifier.Type = "m.id.user"
|
||||||
|
|
||||||
|
return &r.Login, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decodes and validates a LoginToken
|
||||||
|
// Accepts the base64 encoded token string as param
|
||||||
|
// Checks the time expiry, userID (only the format, doesn't check to see if the user exists)
|
||||||
|
// Also checks the DB to see if the token exists
|
||||||
|
// Returns the localpart if successful
|
||||||
|
func validateLoginToken(tokenStr string, txnID string, serverName *gomatrixserverlib.ServerName) (string, error) {
|
||||||
|
token, err := decodeLoginToken(tokenStr)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// check whether the token has a valid time.
|
||||||
|
// TODO: should this 5 second window be configurable?
|
||||||
|
if time.Now().Unix()-token.CreationTime > 5 {
|
||||||
|
return "", errors.New("Token has expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// check whether the UserID is malformed
|
||||||
|
if !strings.Contains(token.UserID, "@") {
|
||||||
|
// TODO: should we reveal details about the error with the token or give vague responses instead?
|
||||||
|
return "", errors.New("Invalid UserID")
|
||||||
|
}
|
||||||
|
if _, err := userutil.ParseUsernameParam(token.UserID, serverName); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// check in the database
|
||||||
|
if err := checkDBToken(tokenStr, txnID); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return token.UserID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateLoginToken creates a login token which is a base64 encoded string of (userID+time+random)
|
||||||
|
// returns an error if it cannot create a random string
|
||||||
|
func GenerateLoginToken(userID string) (string, error) {
|
||||||
|
// the time of token creation
|
||||||
|
timePart := []byte(strconv.FormatInt(time.Now().Unix(), 10))
|
||||||
|
|
||||||
|
// the random part of the token
|
||||||
|
randPart := make([]byte, 10)
|
||||||
|
if _, err := rand.Read(randPart); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// url-safe no padding
|
||||||
|
return base64.RawURLEncoding.EncodeToString([]byte(userID)) + "." + base64.RawURLEncoding.EncodeToString(timePart) + "." + base64.RawURLEncoding.EncodeToString(randPart), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decodes the given tokenStr into a LoginToken struct
|
||||||
|
func decodeLoginToken(tokenStr string) (*loginToken, error) {
|
||||||
|
// split the string into it's constituent parts
|
||||||
|
strParts := strings.Split(tokenStr, ".")
|
||||||
|
if len(strParts) != 3 {
|
||||||
|
return nil, errors.New("Malformed token string")
|
||||||
|
}
|
||||||
|
|
||||||
|
var token loginToken
|
||||||
|
// decode each of the strParts
|
||||||
|
userBytes, err := base64.RawURLEncoding.DecodeString(strParts[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Invalid user ID")
|
||||||
|
}
|
||||||
|
token.UserID = string(userBytes)
|
||||||
|
|
||||||
|
// first decode the time to a string
|
||||||
|
timeBytes, err := base64.RawURLEncoding.DecodeString(strParts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Invalid creation time")
|
||||||
|
}
|
||||||
|
// now convert the string to an integer
|
||||||
|
creationTime, err := strconv.ParseInt(string(timeBytes), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Invalid creation time")
|
||||||
|
}
|
||||||
|
token.CreationTime = creationTime
|
||||||
|
|
||||||
|
randomBytes, err := base64.RawURLEncoding.DecodeString(strParts[2])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Invalid random part")
|
||||||
|
}
|
||||||
|
token.UserID = string(randomBytes)
|
||||||
|
|
||||||
|
token = loginToken{
|
||||||
|
UserID: string(userBytes),
|
||||||
|
CreationTime: creationTime,
|
||||||
|
RandomPart: string(randomBytes),
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks whether the token exists in the DB and whether the token is assigned to the current transaction ID
|
||||||
|
// Does not validate the userID or the creation time expiry
|
||||||
|
// Returns nil if successful
|
||||||
|
func checkDBToken(tokenStr string, txnID string) error {
|
||||||
|
// if the client has provided a transaction id, try to lock the token to that ID
|
||||||
|
if txnID != "" {
|
||||||
|
if err := LinkToken(tokenStr, txnID); err != nil {
|
||||||
|
// TODO: should we abort the login attempt or something else?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreLoginToken stores the login token in the database
|
||||||
|
// Returns nil if successful
|
||||||
|
func StoreLoginToken(tokenStr string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteLoginToken Deletes a token from the DB
|
||||||
|
// used to delete a token that has already been used
|
||||||
|
// Returns nil if successful
|
||||||
|
func DeleteLoginToken(tokenStr string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkToken Links a token to a transaction ID so no other client can try to login using that token
|
||||||
|
// as specified in https://matrix.org/docs/spec/client_server/r0.6.1#token-based
|
||||||
|
func LinkToken(tokenStr string, txnID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -15,7 +15,10 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||||
|
|
@ -53,6 +56,15 @@ func passwordLogin() flows {
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ssoLogin() flows {
|
||||||
|
f := flows{}
|
||||||
|
s := flow{
|
||||||
|
Type: "m.login.sso",
|
||||||
|
}
|
||||||
|
f.Flows = append(f.Flows, s)
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
// Login implements GET and POST /login
|
// Login implements GET and POST /login
|
||||||
func Login(
|
func Login(
|
||||||
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
||||||
|
|
@ -60,33 +72,104 @@ func Login(
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
// TODO: support other forms of login other than password, depending on config options
|
// TODO: support other forms of login other than password, depending on config options
|
||||||
|
flows := passwordLogin()
|
||||||
|
if cfg.CAS.Enabled {
|
||||||
|
flows.Flows = append(flows.Flows, ssoLogin().Flows...)
|
||||||
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: passwordLogin(),
|
JSON: flows,
|
||||||
}
|
}
|
||||||
} else if req.Method == http.MethodPost {
|
} else if req.Method == http.MethodPost {
|
||||||
typePassword := auth.LoginTypePassword{
|
// TODO: is the the right way to read the body and re-add it?
|
||||||
GetAccountByPassword: accountDB.GetAccountByPassword,
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
Config: cfg,
|
if err != nil {
|
||||||
|
// TODO: is this appropriate?
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusMethodNotAllowed,
|
||||||
|
JSON: jsonerror.NotFound("Bad method"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
r := typePassword.Request()
|
// add the body back to the request because ioutil.ReadAll consumes the body
|
||||||
resErr := httputil.UnmarshalJSONRequest(req, r)
|
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||||
if resErr != nil {
|
|
||||||
return *resErr
|
// marshall the body into an unstructured json map
|
||||||
|
var jsonBody map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &jsonBody); err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusMethodNotAllowed,
|
||||||
|
JSON: jsonerror.NotFound("Bad method"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
login, authErr := typePassword.Login(req.Context(), r)
|
|
||||||
if authErr != nil {
|
loginType := jsonBody["type"].(string)
|
||||||
return *authErr
|
if loginType == "m.login.password" {
|
||||||
|
return doPasswordLogin(req, accountDB, userAPI, cfg)
|
||||||
|
} else if loginType == "m.login.token" {
|
||||||
|
return doTokenLogin(req, accountDB, userAPI, cfg)
|
||||||
}
|
}
|
||||||
// make a device/access token
|
|
||||||
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusMethodNotAllowed,
|
Code: http.StatusMethodNotAllowed,
|
||||||
JSON: jsonerror.NotFound("Bad method"),
|
JSON: jsonerror.NotFound("Bad method"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handles a m.login.password login type request
|
||||||
|
func doPasswordLogin(
|
||||||
|
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
typePassword := auth.LoginTypePassword{
|
||||||
|
GetAccountByPassword: accountDB.GetAccountByPassword,
|
||||||
|
Config: cfg,
|
||||||
|
}
|
||||||
|
r := typePassword.Request()
|
||||||
|
resErr := httputil.UnmarshalJSONRequest(req, r)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
login, authErr := typePassword.Login(req.Context(), r)
|
||||||
|
if authErr != nil {
|
||||||
|
return *authErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// make a device/access token
|
||||||
|
return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles a m.login.token login type request
|
||||||
|
func doTokenLogin(req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
// create a struct with the appropriate DB(postgres/sqlite) function and the configs
|
||||||
|
typeToken := auth.LoginTypeToken{
|
||||||
|
GetAccountByLocalpart: accountDB.GetAccountByLocalpart,
|
||||||
|
Config: cfg,
|
||||||
|
}
|
||||||
|
r := typeToken.Request()
|
||||||
|
resErr := httputil.UnmarshalJSONRequest(req, r)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
login, authErr := typeToken.Login(req.Context(), r)
|
||||||
|
if authErr != nil {
|
||||||
|
return *authErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// make a device/access token
|
||||||
|
authResult := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login)
|
||||||
|
|
||||||
|
// the login is successful, delete the login token before returning the access token to the client
|
||||||
|
if authResult.Code == http.StatusOK {
|
||||||
|
if err := auth.DeleteLoginToken(r.(*auth.LoginTokenRequest).Token); err != nil {
|
||||||
|
// TODO: what to do here?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return authResult
|
||||||
|
}
|
||||||
|
|
||||||
func completeAuth(
|
func completeAuth(
|
||||||
ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login,
|
ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
|
||||||
|
|
@ -446,6 +446,12 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
r0mux.Handle("/login/sso/redirect",
|
||||||
|
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
||||||
|
return SSORedirect(req, accountDB, cfg)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
r0mux.Handle("/auth/{authType}/fallback/web",
|
r0mux.Handle("/auth/{authType}/fallback/web",
|
||||||
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
|
|
|
||||||
218
clientapi/routing/sso.go
Normal file
218
clientapi/routing/sso.go
Normal file
|
|
@ -0,0 +1,218 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// the XML response structure of CAS ticket validation
|
||||||
|
type casValidateResponse struct {
|
||||||
|
XMLName xml.Name `xml:"serviceResponse"`
|
||||||
|
Cas string `xml:"cas,attr"`
|
||||||
|
AuthenticationSuccess struct {
|
||||||
|
User string `xml:"user"`
|
||||||
|
} `xml:"authenticationSuccess"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSORedirect implements GET /login/sso/redirect
|
||||||
|
// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-login-sso-redirect
|
||||||
|
// If the incoming request doesn't contain a SSO token, it will redirect to the SSO server
|
||||||
|
// Else it will validate the SSO token, and redirect to the "redirectURL" provided with an extra "loginToken" param
|
||||||
|
func SSORedirect(
|
||||||
|
req *http.Request,
|
||||||
|
accountDB accounts.Database,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
// If dendrite is not configured to use SSO by the admin return bad method
|
||||||
|
if !cfg.CAS.Enabled || cfg.CAS.Server == "" {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusMethodNotAllowed,
|
||||||
|
JSON: jsonerror.NotFound("Bad method"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to parse the SSO URL configured to a url.URL type
|
||||||
|
ssoURL, err := url.Parse(cfg.CAS.Server)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: jsonerror.Unknown("Failed to parse SSO URL configured: " + err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A redirect URL is required for this endpoint
|
||||||
|
redirectURLStr := req.FormValue("redirectUrl")
|
||||||
|
if redirectURLStr == "" {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.MissingArgument("redirectUrl parameter missing"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if the redirect url is a valid URL
|
||||||
|
redirectURL, err := url.Parse(redirectURLStr)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: jsonerror.Unknown("Invalid redirectURL: " + err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the request has a ticket param, validate the ticket instead of redirecting to SSO server
|
||||||
|
if ticket := req.FormValue("ticket"); ticket != "" {
|
||||||
|
return ssoTicket(req, redirectURL, accountDB, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding the params to the sso url
|
||||||
|
ssoQueries := make(url.Values)
|
||||||
|
// the service url that we send to CAS is homeserver.com/_matrix/client/r0/login/sso/redirect?redirectUrl=xyz
|
||||||
|
ssoQueries.Set("service", req.RequestURI)
|
||||||
|
|
||||||
|
ssoURL.RawQuery = ssoQueries.Encode()
|
||||||
|
|
||||||
|
return util.RedirectResponse(ssoURL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ssoTicket handles the m.login.sso login attempt after the user had completed auth at the SSO server
|
||||||
|
// - gets the ticket from the SSO server (this is different from the matrix login/access token)
|
||||||
|
// - calls validateTicket to validate the ticket
|
||||||
|
// - calls completeSSOAuth
|
||||||
|
func ssoTicket(
|
||||||
|
req *http.Request,
|
||||||
|
redirectURL *url.URL,
|
||||||
|
accountDB accounts.Database,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
// form the ticket validation URL from the config
|
||||||
|
ssoURL, err := url.Parse(cfg.CAS.Server + cfg.CAS.ValidateEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: jsonerror.Unknown("Failed to parse SSO URL configured: " + err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ticket := req.FormValue("ticket")
|
||||||
|
|
||||||
|
// append required params to the CAS validate endpoint
|
||||||
|
ssoQueries := make(url.Values)
|
||||||
|
ssoQueries.Set("ticket", ticket)
|
||||||
|
ssoURL.RawQuery = ssoQueries.Encode()
|
||||||
|
|
||||||
|
// validate the ticket
|
||||||
|
casUsername, err := validateTicket(ssoURL.String())
|
||||||
|
if err != nil {
|
||||||
|
// TODO: should I be logging these? What else should I log?
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("CAS SSO ticket validation failed")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.Unknown("Could not validate SSO token: " + err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if casUsername == "" {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("CAS SSO returned no user")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.Unknown("CAS SSO returned no user"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ticket validated. Login the user
|
||||||
|
return completeSSOAuth(req, casUsername, redirectURL, accountDB)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateTicket sends the ticket to the sso server to get it validated
|
||||||
|
// the CAS server responds with an xml which contains the username
|
||||||
|
// validateTicket returns the SSO User
|
||||||
|
func validateTicket(
|
||||||
|
ssoURL string,
|
||||||
|
) (string, error) {
|
||||||
|
// make the call to the sso server to validate
|
||||||
|
response, err := http.Get(ssoURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the response from the sso server
|
||||||
|
data, err := ioutil.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse the response to the CAS XML format
|
||||||
|
var res casValidateResponse
|
||||||
|
if err := xml.Unmarshal([]byte(data), &res); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.AuthenticationSuccess.User, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// completeSSOAuth completes the SSO auth and returns a m.login.token for the client to authenticate with
|
||||||
|
// if the user doesn't exist, a new user is created
|
||||||
|
func completeSSOAuth(
|
||||||
|
req *http.Request,
|
||||||
|
username string,
|
||||||
|
redirectURL *url.URL,
|
||||||
|
accountDB accounts.Database,
|
||||||
|
) util.JSONResponse {
|
||||||
|
// try to create an account with that username
|
||||||
|
// if the user exists, then we pick that user, else we create a new user
|
||||||
|
account, err := accountDB.CreateAccount(req.Context(), username, "", "")
|
||||||
|
if err != nil {
|
||||||
|
// some error
|
||||||
|
if err != sqlutil.ErrUserExists {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.Unknown("Could not create new user"),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// user already exists, so just pick up their details
|
||||||
|
account, err = accountDB.GetAccountByLocalpart(req.Context(), username)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.Unknown("Could not query user"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
token, err := auth.GenerateLoginToken(account.UserID)
|
||||||
|
if err != nil || token == "" {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: jsonerror.Unknown("Could not generate login token"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add the params to the sso url
|
||||||
|
redirectQueries := make(url.Values)
|
||||||
|
// the service url that we send to CAS is homeserver.com/_matrix/client/r0/login/sso/redirect?redirectUrl=xyz
|
||||||
|
redirectQueries.Set("loginToken", token)
|
||||||
|
|
||||||
|
redirectURL.RawQuery = redirectQueries.Encode()
|
||||||
|
|
||||||
|
return util.RedirectResponse(redirectURL.String())
|
||||||
|
}
|
||||||
|
|
@ -32,6 +32,9 @@ type ClientAPI struct {
|
||||||
// was successful
|
// was successful
|
||||||
RecaptchaSiteVerifyAPI string `yaml:"recaptcha_siteverify_api"`
|
RecaptchaSiteVerifyAPI string `yaml:"recaptcha_siteverify_api"`
|
||||||
|
|
||||||
|
// CAS server settings
|
||||||
|
CAS CAS `yaml:"cas"`
|
||||||
|
|
||||||
// TURN options
|
// TURN options
|
||||||
TURN TURN `yaml:"turn"`
|
TURN TURN `yaml:"turn"`
|
||||||
|
|
||||||
|
|
@ -51,6 +54,7 @@ func (c *ClientAPI) Defaults() {
|
||||||
c.RecaptchaSiteVerifyAPI = ""
|
c.RecaptchaSiteVerifyAPI = ""
|
||||||
c.RegistrationDisabled = false
|
c.RegistrationDisabled = false
|
||||||
c.RateLimiting.Defaults()
|
c.RateLimiting.Defaults()
|
||||||
|
c.CAS.Enabled = false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
|
|
@ -64,10 +68,24 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", string(c.RecaptchaPrivateKey))
|
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", string(c.RecaptchaPrivateKey))
|
||||||
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", string(c.RecaptchaSiteVerifyAPI))
|
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", string(c.RecaptchaSiteVerifyAPI))
|
||||||
}
|
}
|
||||||
|
c.CAS.Verify(configErrs)
|
||||||
c.TURN.Verify(configErrs)
|
c.TURN.Verify(configErrs)
|
||||||
c.RateLimiting.Verify(configErrs)
|
c.RateLimiting.Verify(configErrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CAS struct {
|
||||||
|
Enabled bool `yaml:"cas_enabled"`
|
||||||
|
Server string `yaml:"cas_server"`
|
||||||
|
ValidateEndpoint string `yaml:"cas_validate_endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cas *CAS) Verify(ConfigErrors *ConfigErrors) {
|
||||||
|
if cas.Enabled {
|
||||||
|
checkURL(ConfigErrors, "client_api.cas.cas_server", cas.Server)
|
||||||
|
checkNotEmpty(ConfigErrors, "client_api.cas.cas_validate_endpoint", cas.ValidateEndpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type TURN struct {
|
type TURN struct {
|
||||||
// TODO Guest Support
|
// TODO Guest Support
|
||||||
// Whether or not guests can request TURN credentials
|
// Whether or not guests can request TURN credentials
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue