diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 801000f61..f216b777a 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -323,7 +323,7 @@ func validatePassword(password string) *util.JSONResponse { Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)), } - } else if len(password) > 0 && len(password) < minPasswordLength { + } else if len(password) < minPasswordLength { return &util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 85846c7d6..d1b59d3ea 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -15,12 +15,19 @@ package routing import ( + "fmt" + "net/http" + "reflect" "regexp" + "strings" "testing" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) var ( @@ -264,3 +271,106 @@ func TestSessionCleanUp(t *testing.T) { } }) } + +func Test_validateUsername(t *testing.T) { + tooLongUsername := strings.Repeat("a", maxUsernameLength) + tests := []struct { + name string + localpart string + domain gomatrixserverlib.ServerName + want *util.JSONResponse + }{ + { + name: "empty username", + localpart: "", + domain: "localhost", + want: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), + }, + }, + { + name: "invalid username", + localpart: "INVALIDUSERNAME", + domain: "localhost", + want: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), + }, + }, + { + name: "username too long", + localpart: tooLongUsername, + domain: "localhost", + want: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", fmt.Sprintf("@%s:%s", tooLongUsername, "localhost"), maxUsernameLength)), + }, + }, + { + name: "localpart starting with an underscore", + localpart: "_notvalid", + domain: "localhost", + want: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"), + }, + }, + { + name: "valid username", + localpart: "valid", + domain: "localhost", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := validateUsername(tt.localpart, tt.domain); !reflect.DeepEqual(got, tt.want) { + t.Errorf("validateUsername() = %v, want %v", got, tt.want) + } + if got := validateApplicationServiceUsername(tt.localpart, tt.domain); !reflect.DeepEqual(got, tt.want) { + if got != nil && got.JSON != jsonerror.InvalidUsername("Username cannot start with a '_'") { + t.Errorf("validateUsername() = %v, want %v", got, tt.want) + } + } + }) + } +} + +func Test_validatePassword(t *testing.T) { + tests := []struct { + name string + password string + want *util.JSONResponse + }{ + { + name: "no password supplied", + want: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + }, + }, + { + name: "password too short", + password: "shortpw", + want: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + }, + }, + { + name: "password too long", + password: strings.Repeat("a", maxPasswordLength+1), + want: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := validatePassword(tt.password); !reflect.DeepEqual(got, tt.want) { + t.Errorf("validatePassword() = %v, want %v", got, tt.want) + } + }) + } +}