Updated sso implementation

Changed with comments from Kegsay and Half-Shot
- changed some return error codes
- moved sso url creation &validation to startup time
- added test to sytest whitelist
This commit is contained in:
Anand Vasudevan 2020-09-21 17:41:37 +05:30
parent 9dc798c5e4
commit ef21bed096
4 changed files with 43 additions and 38 deletions

View file

@ -84,10 +84,9 @@ func Login(
// TODO: is the the right way to read the body and re-add it? // TODO: is the the right way to read the body and re-add it?
body, err := ioutil.ReadAll(req.Body) body, err := ioutil.ReadAll(req.Body)
if err != nil { if err != nil {
// TODO: is this appropriate?
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusMethodNotAllowed, Code: http.StatusBadRequest,
JSON: jsonerror.NotFound("Bad method"), JSON: jsonerror.BadJSON("Bad JSON"),
} }
} }
// add the body back to the request because ioutil.ReadAll consumes the body // add the body back to the request because ioutil.ReadAll consumes the body
@ -97,12 +96,20 @@ func Login(
var jsonBody map[string]interface{} var jsonBody map[string]interface{}
if err := json.Unmarshal([]byte(body), &jsonBody); err != nil { if err := json.Unmarshal([]byte(body), &jsonBody); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusMethodNotAllowed, Code: http.StatusBadRequest,
JSON: jsonerror.NotFound("Bad method"), JSON: jsonerror.BadJSON("Bad JSON"),
} }
} }
loginType := jsonBody["type"].(string) var loginType string
if val, ok := jsonBody["type"]; ok {
loginType = val.(string)
} else {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("No 'type' parameter"),
}
}
if loginType == "m.login.password" { if loginType == "m.login.password" {
return doPasswordLogin(req, accountDB, userAPI, cfg) return doPasswordLogin(req, accountDB, userAPI, cfg)
} else if loginType == "m.login.token" { } else if loginType == "m.login.token" {
@ -164,7 +171,7 @@ func doTokenLogin(req *http.Request, accountDB accounts.Database, userAPI userap
// the login is successful, delete the login token before returning the access token to the client // the login is successful, delete the login token before returning the access token to the client
if authResult.Code == http.StatusOK { if authResult.Code == http.StatusOK {
if err := auth.DeleteLoginToken(r.(*auth.LoginTokenRequest).Token); err != nil { if err := auth.DeleteLoginToken(r.(*auth.LoginTokenRequest).Token); err != nil {
// TODO: what to do here? util.GetLogger(req.Context()).WithError(err).Error("Could not delete login ticket from DB")
} }
} }
return authResult return authResult

View file

@ -1,4 +1,4 @@
// Copyright 2017 Vector Creations Ltd // Copyright 2020 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -49,17 +49,8 @@ func SSORedirect(
// If dendrite is not configured to use SSO by the admin return bad method // If dendrite is not configured to use SSO by the admin return bad method
if !cfg.CAS.Enabled || cfg.CAS.Server == "" { if !cfg.CAS.Enabled || cfg.CAS.Server == "" {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusMethodNotAllowed, Code: http.StatusNotImplemented,
JSON: jsonerror.NotFound("Bad method"), JSON: jsonerror.NotFound("Method disabled"),
}
}
// Try to parse the SSO URL configured to a url.URL type
ssoURL, err := url.Parse(cfg.CAS.Server)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("Failed to parse SSO URL configured: " + err.Error()),
} }
} }
@ -75,8 +66,8 @@ func SSORedirect(
redirectURL, err := url.Parse(redirectURLStr) redirectURL, err := url.Parse(redirectURLStr)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Invalid redirectURL: " + err.Error()), JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()),
} }
} }
@ -86,10 +77,10 @@ func SSORedirect(
} }
// Adding the params to the sso url // Adding the params to the sso url
ssoURL := cfg.CAS.URL
ssoQueries := make(url.Values) ssoQueries := make(url.Values)
// the service url that we send to CAS is homeserver.com/_matrix/client/r0/login/sso/redirect?redirectUrl=xyz // the service url that we send to CAS is homeserver.com/_matrix/client/r0/login/sso/redirect?redirectUrl=xyz
ssoQueries.Set("service", req.RequestURI) ssoQueries.Set("service", req.RequestURI)
ssoURL.RawQuery = ssoQueries.Encode() ssoURL.RawQuery = ssoQueries.Encode()
return util.RedirectResponse(ssoURL.String()) return util.RedirectResponse(ssoURL.String())
@ -106,23 +97,16 @@ func ssoTicket(
cfg *config.ClientAPI, cfg *config.ClientAPI,
) util.JSONResponse { ) util.JSONResponse {
// form the ticket validation URL from the config // form the ticket validation URL from the config
ssoURL, err := url.Parse(cfg.CAS.Server + cfg.CAS.ValidateEndpoint) validateURL := cfg.CAS.ValidateURL
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("Failed to parse SSO URL configured: " + err.Error()),
}
}
ticket := req.FormValue("ticket") ticket := req.FormValue("ticket")
// append required params to the CAS validate endpoint // append required params to the CAS validate endpoint
ssoQueries := make(url.Values) validateQueries := make(url.Values)
ssoQueries.Set("ticket", ticket) validateQueries.Set("ticket", ticket)
ssoURL.RawQuery = ssoQueries.Encode() validateURL.RawQuery = validateQueries.Encode()
// validate the ticket // validate the ticket
casUsername, err := validateTicket(ssoURL.String()) casUsername, err := validateTicket(validateURL.String())
if err != nil { if err != nil {
// TODO: should I be logging these? What else should I log? // TODO: should I be logging these? What else should I log?
util.GetLogger(req.Context()).WithError(err).Error("CAS SSO ticket validation failed") util.GetLogger(req.Context()).WithError(err).Error("CAS SSO ticket validation failed")
@ -182,10 +166,11 @@ func completeSSOAuth(
// if the user exists, then we pick that user, else we create a new user // if the user exists, then we pick that user, else we create a new user
account, err := accountDB.CreateAccount(req.Context(), username, "", "") account, err := accountDB.CreateAccount(req.Context(), username, "", "")
if err != nil { if err != nil {
// some error
if err != sqlutil.ErrUserExists { if err != sqlutil.ErrUserExists {
// some error
util.GetLogger(req.Context()).WithError(err).Error("Could not create new user")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("Could not create new user"), JSON: jsonerror.Unknown("Could not create new user"),
} }
} else { } else {
@ -193,7 +178,7 @@ func completeSSOAuth(
account, err = accountDB.GetAccountByLocalpart(req.Context(), username) account, err = accountDB.GetAccountByLocalpart(req.Context(), username)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("Could not query user"), JSON: jsonerror.Unknown("Could not query user"),
} }
} }

View file

@ -2,6 +2,8 @@ package config
import ( import (
"fmt" "fmt"
"net/url"
"path"
"time" "time"
) )
@ -77,12 +79,21 @@ type CAS struct {
Enabled bool `yaml:"cas_enabled"` Enabled bool `yaml:"cas_enabled"`
Server string `yaml:"cas_server"` Server string `yaml:"cas_server"`
ValidateEndpoint string `yaml:"cas_validate_endpoint"` ValidateEndpoint string `yaml:"cas_validate_endpoint"`
URL *url.URL
ValidateURL *url.URL
} }
func (cas *CAS) Verify(ConfigErrors *ConfigErrors) { func (cas *CAS) Verify(ConfigErrors *ConfigErrors) {
if cas.Enabled { if cas.Enabled {
checkURL(ConfigErrors, "client_api.cas.cas_server", cas.Server) checkURL(ConfigErrors, "client_api.cas.cas_server", cas.Server)
checkNotEmpty(ConfigErrors, "client_api.cas.cas_validate_endpoint", cas.ValidateEndpoint) checkNotEmpty(ConfigErrors, "client_api.cas.cas_validate_endpoint", cas.ValidateEndpoint)
var err error
cas.URL, err = url.Parse(cas.Server)
if err != nil {
ConfigErrors.Add(fmt.Sprintf("Couldn't parse %q (%q)to a URL", "client_api.cas.cas_server", cas.Server))
}
cas.ValidateURL.Path = path.Join(cas.URL.Path, cas.ValidateEndpoint)
checkURL(ConfigErrors, "client_api.cas.cas_validate_endpoint", cas.ValidateURL.String())
} }
} }

View file

@ -470,4 +470,6 @@ We can't peek into rooms with shared history_visibility
We can't peek into rooms with invited history_visibility We can't peek into rooms with invited history_visibility
We can't peek into rooms with joined history_visibility We can't peek into rooms with joined history_visibility
Local users can peek by room alias Local users can peek by room alias
Peeked rooms only turn up in the sync for the device who peeked them Peeked rooms only turn up in the sync for the device who peeked them
Room state at a rejected message event is the same as its predecessor
Room state at a rejected state event is the same as its predecessor