mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-23 14:53:10 -06:00
Updated sso implementation
Changed with comments from Kegsay and Half-Shot - changed some return error codes - moved sso url creation &validation to startup time - added test to sytest whitelist
This commit is contained in:
parent
9dc798c5e4
commit
ef21bed096
|
|
@ -84,10 +84,9 @@ func Login(
|
||||||
// TODO: is the the right way to read the body and re-add it?
|
// TODO: is the the right way to read the body and re-add it?
|
||||||
body, err := ioutil.ReadAll(req.Body)
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: is this appropriate?
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusMethodNotAllowed,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.NotFound("Bad method"),
|
JSON: jsonerror.BadJSON("Bad JSON"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// add the body back to the request because ioutil.ReadAll consumes the body
|
// add the body back to the request because ioutil.ReadAll consumes the body
|
||||||
|
|
@ -97,12 +96,20 @@ func Login(
|
||||||
var jsonBody map[string]interface{}
|
var jsonBody map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(body), &jsonBody); err != nil {
|
if err := json.Unmarshal([]byte(body), &jsonBody); err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusMethodNotAllowed,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.NotFound("Bad method"),
|
JSON: jsonerror.BadJSON("Bad JSON"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
loginType := jsonBody["type"].(string)
|
var loginType string
|
||||||
|
if val, ok := jsonBody["type"]; ok {
|
||||||
|
loginType = val.(string)
|
||||||
|
} else {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON("No 'type' parameter"),
|
||||||
|
}
|
||||||
|
}
|
||||||
if loginType == "m.login.password" {
|
if loginType == "m.login.password" {
|
||||||
return doPasswordLogin(req, accountDB, userAPI, cfg)
|
return doPasswordLogin(req, accountDB, userAPI, cfg)
|
||||||
} else if loginType == "m.login.token" {
|
} else if loginType == "m.login.token" {
|
||||||
|
|
@ -164,7 +171,7 @@ func doTokenLogin(req *http.Request, accountDB accounts.Database, userAPI userap
|
||||||
// the login is successful, delete the login token before returning the access token to the client
|
// the login is successful, delete the login token before returning the access token to the client
|
||||||
if authResult.Code == http.StatusOK {
|
if authResult.Code == http.StatusOK {
|
||||||
if err := auth.DeleteLoginToken(r.(*auth.LoginTokenRequest).Token); err != nil {
|
if err := auth.DeleteLoginToken(r.(*auth.LoginTokenRequest).Token); err != nil {
|
||||||
// TODO: what to do here?
|
util.GetLogger(req.Context()).WithError(err).Error("Could not delete login ticket from DB")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return authResult
|
return authResult
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
|
|
@ -49,17 +49,8 @@ func SSORedirect(
|
||||||
// If dendrite is not configured to use SSO by the admin return bad method
|
// If dendrite is not configured to use SSO by the admin return bad method
|
||||||
if !cfg.CAS.Enabled || cfg.CAS.Server == "" {
|
if !cfg.CAS.Enabled || cfg.CAS.Server == "" {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusMethodNotAllowed,
|
Code: http.StatusNotImplemented,
|
||||||
JSON: jsonerror.NotFound("Bad method"),
|
JSON: jsonerror.NotFound("Method disabled"),
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to parse the SSO URL configured to a url.URL type
|
|
||||||
ssoURL, err := url.Parse(cfg.CAS.Server)
|
|
||||||
if err != nil {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: jsonerror.Unknown("Failed to parse SSO URL configured: " + err.Error()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -75,8 +66,8 @@ func SSORedirect(
|
||||||
redirectURL, err := url.Parse(redirectURLStr)
|
redirectURL, err := url.Parse(redirectURLStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.Unknown("Invalid redirectURL: " + err.Error()),
|
JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -86,10 +77,10 @@ func SSORedirect(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adding the params to the sso url
|
// Adding the params to the sso url
|
||||||
|
ssoURL := cfg.CAS.URL
|
||||||
ssoQueries := make(url.Values)
|
ssoQueries := make(url.Values)
|
||||||
// the service url that we send to CAS is homeserver.com/_matrix/client/r0/login/sso/redirect?redirectUrl=xyz
|
// the service url that we send to CAS is homeserver.com/_matrix/client/r0/login/sso/redirect?redirectUrl=xyz
|
||||||
ssoQueries.Set("service", req.RequestURI)
|
ssoQueries.Set("service", req.RequestURI)
|
||||||
|
|
||||||
ssoURL.RawQuery = ssoQueries.Encode()
|
ssoURL.RawQuery = ssoQueries.Encode()
|
||||||
|
|
||||||
return util.RedirectResponse(ssoURL.String())
|
return util.RedirectResponse(ssoURL.String())
|
||||||
|
|
@ -106,23 +97,16 @@ func ssoTicket(
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
// form the ticket validation URL from the config
|
// form the ticket validation URL from the config
|
||||||
ssoURL, err := url.Parse(cfg.CAS.Server + cfg.CAS.ValidateEndpoint)
|
validateURL := cfg.CAS.ValidateURL
|
||||||
if err != nil {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: jsonerror.Unknown("Failed to parse SSO URL configured: " + err.Error()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ticket := req.FormValue("ticket")
|
ticket := req.FormValue("ticket")
|
||||||
|
|
||||||
// append required params to the CAS validate endpoint
|
// append required params to the CAS validate endpoint
|
||||||
ssoQueries := make(url.Values)
|
validateQueries := make(url.Values)
|
||||||
ssoQueries.Set("ticket", ticket)
|
validateQueries.Set("ticket", ticket)
|
||||||
ssoURL.RawQuery = ssoQueries.Encode()
|
validateURL.RawQuery = validateQueries.Encode()
|
||||||
|
|
||||||
// validate the ticket
|
// validate the ticket
|
||||||
casUsername, err := validateTicket(ssoURL.String())
|
casUsername, err := validateTicket(validateURL.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: should I be logging these? What else should I log?
|
// TODO: should I be logging these? What else should I log?
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("CAS SSO ticket validation failed")
|
util.GetLogger(req.Context()).WithError(err).Error("CAS SSO ticket validation failed")
|
||||||
|
|
@ -182,10 +166,11 @@ func completeSSOAuth(
|
||||||
// if the user exists, then we pick that user, else we create a new user
|
// if the user exists, then we pick that user, else we create a new user
|
||||||
account, err := accountDB.CreateAccount(req.Context(), username, "", "")
|
account, err := accountDB.CreateAccount(req.Context(), username, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// some error
|
|
||||||
if err != sqlutil.ErrUserExists {
|
if err != sqlutil.ErrUserExists {
|
||||||
|
// some error
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("Could not create new user")
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusUnauthorized,
|
Code: http.StatusInternalServerError,
|
||||||
JSON: jsonerror.Unknown("Could not create new user"),
|
JSON: jsonerror.Unknown("Could not create new user"),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -193,7 +178,7 @@ func completeSSOAuth(
|
||||||
account, err = accountDB.GetAccountByLocalpart(req.Context(), username)
|
account, err = accountDB.GetAccountByLocalpart(req.Context(), username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusUnauthorized,
|
Code: http.StatusInternalServerError,
|
||||||
JSON: jsonerror.Unknown("Could not query user"),
|
JSON: jsonerror.Unknown("Could not query user"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -77,12 +79,21 @@ type CAS struct {
|
||||||
Enabled bool `yaml:"cas_enabled"`
|
Enabled bool `yaml:"cas_enabled"`
|
||||||
Server string `yaml:"cas_server"`
|
Server string `yaml:"cas_server"`
|
||||||
ValidateEndpoint string `yaml:"cas_validate_endpoint"`
|
ValidateEndpoint string `yaml:"cas_validate_endpoint"`
|
||||||
|
URL *url.URL
|
||||||
|
ValidateURL *url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cas *CAS) Verify(ConfigErrors *ConfigErrors) {
|
func (cas *CAS) Verify(ConfigErrors *ConfigErrors) {
|
||||||
if cas.Enabled {
|
if cas.Enabled {
|
||||||
checkURL(ConfigErrors, "client_api.cas.cas_server", cas.Server)
|
checkURL(ConfigErrors, "client_api.cas.cas_server", cas.Server)
|
||||||
checkNotEmpty(ConfigErrors, "client_api.cas.cas_validate_endpoint", cas.ValidateEndpoint)
|
checkNotEmpty(ConfigErrors, "client_api.cas.cas_validate_endpoint", cas.ValidateEndpoint)
|
||||||
|
var err error
|
||||||
|
cas.URL, err = url.Parse(cas.Server)
|
||||||
|
if err != nil {
|
||||||
|
ConfigErrors.Add(fmt.Sprintf("Couldn't parse %q (%q)to a URL", "client_api.cas.cas_server", cas.Server))
|
||||||
|
}
|
||||||
|
cas.ValidateURL.Path = path.Join(cas.URL.Path, cas.ValidateEndpoint)
|
||||||
|
checkURL(ConfigErrors, "client_api.cas.cas_validate_endpoint", cas.ValidateURL.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -470,4 +470,6 @@ We can't peek into rooms with shared history_visibility
|
||||||
We can't peek into rooms with invited history_visibility
|
We can't peek into rooms with invited history_visibility
|
||||||
We can't peek into rooms with joined history_visibility
|
We can't peek into rooms with joined history_visibility
|
||||||
Local users can peek by room alias
|
Local users can peek by room alias
|
||||||
Peeked rooms only turn up in the sync for the device who peeked them
|
Peeked rooms only turn up in the sync for the device who peeked them
|
||||||
|
Room state at a rejected message event is the same as its predecessor
|
||||||
|
Room state at a rejected state event is the same as its predecessor
|
||||||
Loading…
Reference in a new issue