diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 1a41a5243..a92513b8b 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -211,9 +211,10 @@ var ( // previous parameters with the ones supplied. This mean you cannot "build up" request params. type registerRequest struct { // registration parameters - Password string `json:"password"` - Username string `json:"username"` - Admin bool `json:"admin"` + Password string `json:"password"` + Username string `json:"username"` + ServerName gomatrixserverlib.ServerName `json:"-"` + Admin bool `json:"admin"` // user-interactive auth params Auth authDict `json:"auth"` @@ -570,11 +571,14 @@ func Register( JSON: response, } } - } if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } + r.ServerName = cfg.Matrix.ServerName + if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { + r.Username, r.ServerName = l, d + } if req.URL.Query().Get("kind") == "guest" { return handleGuestRegistration(req, r, cfg, userAPI) } @@ -589,7 +593,7 @@ func Register( // Auto generate a numeric username if r.Username is empty if r.Username == "" { nreq := &userapi.QueryNumericLocalpartRequest{ - ServerName: cfg.Matrix.ServerName, // TODO: might not be right + ServerName: r.ServerName, } nres := &userapi.QueryNumericLocalpartResponse{} if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { @@ -609,7 +613,7 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { + if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { return *resErr } case accessTokenErr == nil: @@ -622,7 +626,7 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { + if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { return *resErr } } @@ -1027,6 +1031,16 @@ func RegisterAvailable( if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil { username, domain = u, l } + for _, v := range cfg.Matrix.VirtualHosts { + if v.ServerName == domain && !v.AllowRegistration { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden( + fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)), + ), + } + } + } if err := validateUsername(username, domain); err != nil { return *err diff --git a/setup/config/config_global.go b/setup/config/config_global.go index acf610218..c78b5c8d0 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -185,6 +185,9 @@ type VirtualHost struct { // by remote servers. // Defaults to 24 hours. KeyValidityPeriod time.Duration `yaml:"key_validity_period"` + + // Is registration enabled on this virtual host? + AllowRegistration bool `json:"allow_registration"` } func (v *VirtualHost) Verify(configErrs *ConfigErrors) {