Registration checks

This commit is contained in:
Neil Alexander 2022-11-14 16:35:57 +00:00
parent 7ff7c7eaba
commit 0615fea17b
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
2 changed files with 24 additions and 7 deletions

View file

@ -211,9 +211,10 @@ var (
// previous parameters with the ones supplied. This mean you cannot "build up" request params. // previous parameters with the ones supplied. This mean you cannot "build up" request params.
type registerRequest struct { type registerRequest struct {
// registration parameters // registration parameters
Password string `json:"password"` Password string `json:"password"`
Username string `json:"username"` Username string `json:"username"`
Admin bool `json:"admin"` ServerName gomatrixserverlib.ServerName `json:"-"`
Admin bool `json:"admin"`
// user-interactive auth params // user-interactive auth params
Auth authDict `json:"auth"` Auth authDict `json:"auth"`
@ -570,11 +571,14 @@ func Register(
JSON: response, JSON: response,
} }
} }
} }
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
return *resErr return *resErr
} }
r.ServerName = cfg.Matrix.ServerName
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
r.Username, r.ServerName = l, d
}
if req.URL.Query().Get("kind") == "guest" { if req.URL.Query().Get("kind") == "guest" {
return handleGuestRegistration(req, r, cfg, userAPI) return handleGuestRegistration(req, r, cfg, userAPI)
} }
@ -589,7 +593,7 @@ func Register(
// Auto generate a numeric username if r.Username is empty // Auto generate a numeric username if r.Username is empty
if r.Username == "" { if r.Username == "" {
nreq := &userapi.QueryNumericLocalpartRequest{ nreq := &userapi.QueryNumericLocalpartRequest{
ServerName: cfg.Matrix.ServerName, // TODO: might not be right ServerName: r.ServerName,
} }
nres := &userapi.QueryNumericLocalpartResponse{} nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
@ -609,7 +613,7 @@ func Register(
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
// Spec-compliant case (the access_token is specified and the login type // Spec-compliant case (the access_token is specified and the login type
// is correctly set, so it's an appservice registration) // is correctly set, so it's an appservice registration)
if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
case accessTokenErr == nil: case accessTokenErr == nil:
@ -622,7 +626,7 @@ func Register(
default: default:
// Spec-compliant case (neither the access_token nor the login type are // Spec-compliant case (neither the access_token nor the login type are
// specified, so it's a normal user registration) // specified, so it's a normal user registration)
if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
return *resErr return *resErr
} }
} }
@ -1027,6 +1031,16 @@ func RegisterAvailable(
if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil { if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil {
username, domain = u, l username, domain = u, l
} }
for _, v := range cfg.Matrix.VirtualHosts {
if v.ServerName == domain && !v.AllowRegistration {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(
fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)),
),
}
}
}
if err := validateUsername(username, domain); err != nil { if err := validateUsername(username, domain); err != nil {
return *err return *err

View file

@ -185,6 +185,9 @@ type VirtualHost struct {
// by remote servers. // by remote servers.
// Defaults to 24 hours. // Defaults to 24 hours.
KeyValidityPeriod time.Duration `yaml:"key_validity_period"` KeyValidityPeriod time.Duration `yaml:"key_validity_period"`
// Is registration enabled on this virtual host?
AllowRegistration bool `json:"allow_registration"`
} }
func (v *VirtualHost) Verify(configErrs *ConfigErrors) { func (v *VirtualHost) Verify(configErrs *ConfigErrors) {