From adedb38f8cf05a90f251a33cc5131f28cf9643ed Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 07:56:36 +0100 Subject: [PATCH] Shuffle Validate* functions arround --- clientapi/routing/admin.go | 6 +- clientapi/routing/password.go | 4 +- clientapi/routing/register.go | 83 ++----- clientapi/routing/register_test.go | 337 ++++++++++------------------- cmd/create-account/main.go | 32 +-- internal/validate.go | 82 ++++++- internal/validate_test.go | 170 +++++++++++++++ 7 files changed, 406 insertions(+), 308 deletions(-) create mode 100644 internal/validate_test.go diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 8419622df..dbd913376 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap request := struct { Password string `json:"password"` }{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()), @@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } } - if resErr := internal.ValidatePassword(request.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(request.Password); err != nil { + return *internal.PasswordResponse(err) } updateReq := &userapi.PerformPasswordUpdateRequest{ diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index cd88b025a..f7f9da622 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -82,8 +82,8 @@ func Password( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. - if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil { - return *resErr + if err := internal.ValidatePassword(r.NewPassword); err != nil { + return *internal.PasswordResponse(err) } // Get the local part. diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 9a7084a34..b678256f6 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -24,7 +24,6 @@ import ( "net" "net/http" "net/url" - "regexp" "sort" "strconv" "strings" @@ -61,10 +60,7 @@ var ( ) ) -const ( - maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain - sessionIDLength = 24 -) +const sessionIDLength = 24 // sessionsDict keeps track of completed auth stages for each session. // It shouldn't be passed by value because it contains a mutex. @@ -199,8 +195,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) { } var ( - sessions = newSessionsDict() - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) + sessions = newSessionsDict() ) // registerRequest represents the submitted registration request. @@ -276,44 +271,6 @@ type recaptchaResponse struct { ErrorCodes []int `json:"error-codes"` } -// validateUsername returns an error response if the username is invalid -func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } else if localpart[0] == '_' { // Regex checks its not a zero length string - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"), - } - } - return nil -} - -// validateApplicationServiceUsername returns an error response if the username is invalid for an application service -func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } - return nil -} - var ( ErrInvalidCaptcha = errors.New("invalid captcha response") ErrMissingResponse = errors.New("captcha response is required") @@ -496,8 +453,8 @@ func validateApplicationService( } // Check username application service is trying to register is valid - if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { - return "", err + if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { + return "", internal.UsernameResponse(err) } // No errors, registration valid @@ -552,7 +509,9 @@ func Register( if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } - if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { + var l string + var d gomatrixserverlib.ServerName + if l, d, err = cfg.Matrix.SplitLocalID('@', r.Username); err == nil { r.Username, r.ServerName = l, d } if req.URL.Query().Get("kind") == "guest" { @@ -560,7 +519,7 @@ func Register( } // Don't allow numeric usernames less than MAX_INT64. - if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil { + if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), @@ -572,7 +531,7 @@ func Register( ServerName: r.ServerName, } nres := &userapi.QueryNumericLocalpartResponse{} - if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { + if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") return jsonerror.InternalServerError() } @@ -589,8 +548,8 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } case accessTokenErr == nil: // Non-spec-compliant case (the access_token is specified but the login @@ -602,12 +561,12 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } } - if resErr := internal.ValidatePassword(r.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(r.Password); err != nil { + return *internal.PasswordResponse(err) } logger := util.GetLogger(req.Context()) @@ -1048,8 +1007,8 @@ func RegisterAvailable( } } - if err := validateUsername(username, domain); err != nil { - return *err + if err := internal.ValidateUsername(username, domain); err != nil { + return *internal.UsernameResponse(err) } // Check if this username is reserved by an application service @@ -1111,11 +1070,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien // downcase capitals ssrr.User = strings.ToLower(ssrr.User) - if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil { + return *internal.UsernameResponse(err) } - if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(ssrr.Password); err != nil { + return *internal.PasswordResponse(err) } deviceID := "shared_secret_registration" diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index d44231028..9d6dd88c8 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -28,13 +28,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -280,102 +280,6 @@ func TestSessionCleanUp(t *testing.T) { }) } -func Test_validateUsername(t *testing.T) { - tooLongUsername := strings.Repeat("a", maxUsernameLength) - tests := []struct { - name string - localpart string - domain gomatrixserverlib.ServerName - want *util.JSONResponse - }{ - { - name: "empty username", - localpart: "", - domain: "localhost", - want: &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - }, - }, - { - name: "invalid username", - localpart: "INVALIDUSERNAME", - domain: "localhost", - want: &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - }, - }, - { - name: "username too long", - localpart: tooLongUsername, - domain: "localhost", - want: &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", fmt.Sprintf("@%s:%s", tooLongUsername, "localhost"), maxUsernameLength)), - }, - }, - { - name: "localpart starting with an underscore", - localpart: "_notvalid", - domain: "localhost", - want: &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"), - }, - }, - { - name: "valid username", - localpart: "valid", - domain: "localhost", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := validateUsername(tt.localpart, tt.domain); !reflect.DeepEqual(got, tt.want) { - t.Errorf("validateUsername() = %v, want %v", got, tt.want) - } - if got := validateApplicationServiceUsername(tt.localpart, tt.domain); !reflect.DeepEqual(got, tt.want) { - if got != nil && got.JSON != jsonerror.InvalidUsername("Username cannot start with a '_'") { - t.Errorf("validateUsername() = %v, want %v", got, tt.want) - } - } - }) - } -} - -func Test_validatePassword(t *testing.T) { - tests := []struct { - name string - password string - want *util.JSONResponse - }{ - { - name: "password too short", - password: "shortpw", - want: &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), - }, - }, - { - name: "password too long", - password: strings.Repeat("a", maxPasswordLength+1), - want: &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := validatePassword(tt.password); !reflect.DeepEqual(got, tt.want) { - t.Errorf("validatePassword() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_register(t *testing.T) { testCases := []struct { name string @@ -440,12 +344,9 @@ func Test_register(t *testing.T) { username: "LOWERCASED", // this is going to be lower-cased }, { - name: "invalid username", - username: "#totalyNotValid", - wantResponse: util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - }, + name: "invalid username", + username: "#totalyNotValid", + wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid), }, { name: "numeric username is forbidden", @@ -471,126 +372,128 @@ func Test_register(t *testing.T) { } for _, tc := range testCases { - if tc.kind == "" { - tc.kind = "user" - } - if tc.password == "" && !tc.forceEmpty { - tc.password = "someRandomPassword" - } - if tc.username == "" && !tc.forceEmpty { - tc.username = "valid" - } - if tc.loginType == "" { - tc.loginType = "m.login.dummy" - } - - reg := registerRequest{ - Password: tc.password, - Username: tc.username, - } - - body := &bytes.Buffer{} - - base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled - base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled - - err := json.NewEncoder(body).Encode(reg) - if err != nil { - t.Fatal(err) - } - - req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body) - - resp := Register(req, userAPI, &base.Cfg.ClientAPI) - t.Logf("Resp: %+v", resp) - - // The first request should return a userInteractiveResponse - switch r := resp.JSON.(type) { - case userInteractiveResponse: - // Check that the flows are the ones we configured - if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) { - t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows) + t.Run(tc.name, func(t *testing.T) { + if tc.kind == "" { + tc.kind = "user" } - case *jsonerror.MatrixError: - if !reflect.DeepEqual(tc.wantResponse, resp) { - t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) + if tc.password == "" && !tc.forceEmpty { + tc.password = "someRandomPassword" } - continue - case registerResponse: - // this should only be possible on guest user registration, never for normal users - if tc.kind != "guest" { - t.Fatalf("got register response on first request: %+v", r) + if tc.username == "" && !tc.forceEmpty { + tc.username = "valid" } - // assert we've got a UserID, AccessToken and DeviceID - if r.UserID == "" { - t.Fatalf("missing userID in response") + if tc.loginType == "" { + tc.loginType = "m.login.dummy" } - if r.AccessToken == "" { + + reg := registerRequest{ + Password: tc.password, + Username: tc.username, + } + + body := &bytes.Buffer{} + + base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled + base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled + + err := json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body) + + resp := Register(req, userAPI, &base.Cfg.ClientAPI) + t.Logf("Resp: %+v", resp) + + // The first request should return a userInteractiveResponse + switch r := resp.JSON.(type) { + case userInteractiveResponse: + // Check that the flows are the ones we configured + if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) { + t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows) + } + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) + } + return + case registerResponse: + // this should only be possible on guest user registration, never for normal users + if tc.kind != "guest" { + t.Fatalf("got register response on first request: %+v", r) + } + // assert we've got a UserID, AccessToken and DeviceID + if r.UserID == "" { + t.Fatalf("missing userID in response") + } + if r.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + if r.DeviceID == "" { + t.Fatalf("missing deviceID in response") + } + return + default: + t.Logf("Got response: %T", resp.JSON) + } + + // If we reached this, we should have received a UIA response + uia, ok := resp.JSON.(userInteractiveResponse) + if !ok { + t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON) + } + t.Logf("%+v", uia) + + // Register the user + reg.Auth = authDict{ + Type: authtypes.LoginType(tc.loginType), + Session: uia.Session, + } + dummy := "dummy" + reg.DeviceID = &dummy + reg.InitialDisplayName = &dummy + reg.Type = authtypes.LoginType(tc.loginType) + + err = json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req = httptest.NewRequest(http.MethodPost, "/", body) + + resp = Register(req, userAPI, &base.Cfg.ClientAPI) + + switch resp.JSON.(type) { + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + } + + rr, ok := resp.JSON.(registerResponse) + if !ok { + t.Fatalf("expected a registerresponse, got %T", resp.JSON) + } + + // validate the response + if tc.forceEmpty { + // when not supplying a username, one will be generated. Given this _SHOULD_ be + // the second user, set the username accordingly + reg.Username = "2" + } + wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test")) + if wantUserID != rr.UserID { + t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID) + } + if rr.DeviceID != *reg.DeviceID { + t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID) + } + if rr.AccessToken == "" { t.Fatalf("missing accessToken in response") } - if r.DeviceID == "" { - t.Fatalf("missing deviceID in response") - } - continue - default: - t.Logf("Got response: %T", resp.JSON) - } - - // If we reached this, we should have received a UIA response - uia, ok := resp.JSON.(userInteractiveResponse) - if !ok { - t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON) - } - t.Logf("%+v", uia) - - // Register the user - reg.Auth = authDict{ - Type: authtypes.LoginType(tc.loginType), - Session: uia.Session, - } - dummy := "dummy" - reg.DeviceID = &dummy - reg.InitialDisplayName = &dummy - reg.Type = authtypes.LoginType(tc.loginType) - - err = json.NewEncoder(body).Encode(reg) - if err != nil { - t.Fatal(err) - } - - req = httptest.NewRequest(http.MethodPost, "/", body) - - resp = Register(req, userAPI, &base.Cfg.ClientAPI) - - switch resp.JSON.(type) { - case *jsonerror.MatrixError: - if !reflect.DeepEqual(tc.wantResponse, resp) { - t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) - } - continue - } - - rr, ok := resp.JSON.(registerResponse) - if !ok { - t.Fatalf("expected a registerresponse, got %T", resp.JSON) - } - - // validate the response - if tc.forceEmpty { - // when not supplying a username, one will be generated. Given this _SHOULD_ be - // the second user, set the username accordingly - reg.Username = "2" - } - wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test")) - if wantUserID != rr.UserID { - t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID) - } - if rr.DeviceID != *reg.DeviceID { - t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID) - } - if rr.AccessToken == "" { - t.Fatalf("missing accessToken in response") - } + }) } }) } diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 15b043ed5..772778680 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -25,10 +25,10 @@ import ( "io" "net/http" "os" - "regexp" "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/tidwall/gjson" "github.com/sirupsen/logrus" @@ -58,15 +58,14 @@ Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - isAdmin = flag.Bool("admin", false, "Create an admin account") - resetPassword = flag.Bool("reset-password", false, "Deprecated") - serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) - timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Deprecated") + serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") + timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") ) var cl = http.Client{ @@ -95,20 +94,21 @@ func main() { os.Exit(1) } - if !validUsernameRegex.MatchString(*username) { - logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") + if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil { + logrus.WithError(err).Error("Specified username is invalid") os.Exit(1) } - if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 { - logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) - } - pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin) if err != nil { logrus.Fatalln(err) } + if err = internal.ValidatePassword(pass); err != nil { + logrus.WithError(err).Error("Specified password is invalid") + os.Exit(1) + } + cl.Timeout = *timeout accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) diff --git a/internal/validate.go b/internal/validate.go index fc685ad50..c8301f834 100644 --- a/internal/validate.go +++ b/internal/validate.go @@ -15,30 +15,96 @@ package internal import ( + "errors" "fmt" "net/http" + "regexp" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based +const ( + maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain -const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based + maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 +) + +var ( + ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength) + ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength) + ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength) + ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='") + ErrUsernameUnderscore = errors.New("username cannot start with a '_'") + validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) +) // ValidatePassword returns an error response if the password is invalid -func ValidatePassword(password string) *util.JSONResponse { +func ValidatePassword(password string) error { // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 if len(password) > maxPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)), - } + return ErrPasswordTooLong } else if len(password) > 0 && len(password) < minPasswordLength { + return ErrPasswordWeak + } + return nil +} + +// PasswordResponse returns a util.JSONResponse for a given error, if any. +func PasswordResponse(err error) *util.JSONResponse { + switch err { + case ErrPasswordWeak: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()), + } + case ErrPasswordTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()), } } return nil } + +// ValidateUsername returns an error if the username is invalid +func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } else if localpart[0] == '_' { // Regex checks its not a zero length string + return ErrUsernameUnderscore + } + return nil +} + +// UsernameResponse returns a util.JSONResponse for the given error, if any. +func UsernameResponse(err error) *util.JSONResponse { + switch err { + case ErrUsernameTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } + case ErrUsernameInvalid, ErrUsernameUnderscore: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(err.Error()), + } + } + return nil +} + +// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service +func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error { + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } + return nil +} diff --git a/internal/validate_test.go b/internal/validate_test.go new file mode 100644 index 000000000..d0ad04707 --- /dev/null +++ b/internal/validate_test.go @@ -0,0 +1,170 @@ +package internal + +import ( + "net/http" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func Test_validatePassword(t *testing.T) { + tests := []struct { + name string + password string + wantError error + wantJSON *util.JSONResponse + }{ + { + name: "password too short", + password: "shortpw", + wantError: ErrPasswordWeak, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())}, + }, + { + name: "password too long", + password: strings.Repeat("a", maxPasswordLength+1), + wantError: ErrPasswordTooLong, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())}, + }, + { + name: "password OK", + password: util.RandomString(10), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidatePassword(tt.password) + if !reflect.DeepEqual(gotErr, tt.wantError) { + t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError) + } + + if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) { + t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON) + } + }) + } +} + +func Test_validateUsername(t *testing.T) { + tooLongUsername := strings.Repeat("a", maxUsernameLength) + tests := []struct { + name string + localpart string + domain gomatrixserverlib.ServerName + wantErr error + wantJSON *util.JSONResponse + }{ + { + name: "empty username", + localpart: "", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "invalid username", + localpart: "INVALIDUSERNAME", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username too long", + localpart: tooLongUsername, + domain: "localhost", + wantErr: ErrUsernameTooLong, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()), + }, + }, + { + name: "localpart starting with an underscore", + localpart: "_notvalid", + domain: "localhost", + wantErr: ErrUsernameUnderscore, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()), + }, + }, + { + name: "valid username", + localpart: "valid", + domain: "localhost", + }, + { + name: "complex username", + localpart: "f00_bar-baz.=40/", + domain: "localhost", + }, + { + name: "rejects emoji username 💥", + localpart: "💥", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "special characters are allowed", + localpart: "/dev/null", + domain: "localhost", + }, + { + name: "special characters are allowed 2", + localpart: "i_am_allowed=1", + domain: "localhost", + }, + { + name: "not all special characters are allowed", + localpart: "notallowed#", // contains # + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username containing numbers", + localpart: "hello1337", + domain: "localhost", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidateUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + + // Application services are allowed usernames starting with an underscore + if tt.wantErr == ErrUsernameUnderscore { + return + } + gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + }) + } +}