Shuffle Validate* functions arround

This commit is contained in:
Till Faelligen 2022-12-23 07:56:36 +01:00
parent ad5009dbcc
commit adedb38f8c
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
7 changed files with 406 additions and 308 deletions

View file

@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
request := struct { request := struct {
Password string `json:"password"` Password string `json:"password"`
}{} }{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()), JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
} }
} }
if resErr := internal.ValidatePassword(request.Password); resErr != nil { if err = internal.ValidatePassword(request.Password); err != nil {
return *resErr return *internal.PasswordResponse(err)
} }
updateReq := &userapi.PerformPasswordUpdateRequest{ updateReq := &userapi.PerformPasswordUpdateRequest{

View file

@ -82,8 +82,8 @@ func Password(
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength. // Check the new password strength.
if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil { if err := internal.ValidatePassword(r.NewPassword); err != nil {
return *resErr return *internal.PasswordResponse(err)
} }
// Get the local part. // Get the local part.

View file

@ -24,7 +24,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"regexp"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -61,10 +60,7 @@ var (
) )
) )
const ( const sessionIDLength = 24
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
sessionIDLength = 24
)
// sessionsDict keeps track of completed auth stages for each session. // sessionsDict keeps track of completed auth stages for each session.
// It shouldn't be passed by value because it contains a mutex. // It shouldn't be passed by value because it contains a mutex.
@ -200,7 +196,6 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
var ( var (
sessions = newSessionsDict() sessions = newSessionsDict()
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
) )
// registerRequest represents the submitted registration request. // registerRequest represents the submitted registration request.
@ -276,44 +271,6 @@ type recaptchaResponse struct {
ErrorCodes []int `json:"error-codes"` ErrorCodes []int `json:"error-codes"`
} }
// validateUsername returns an error response if the username is invalid
func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
}
} else if !validUsernameRegex.MatchString(localpart) {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
}
} else if localpart[0] == '_' { // Regex checks its not a zero length string
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
}
}
return nil
}
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service
func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
}
} else if !validUsernameRegex.MatchString(localpart) {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
}
}
return nil
}
var ( var (
ErrInvalidCaptcha = errors.New("invalid captcha response") ErrInvalidCaptcha = errors.New("invalid captcha response")
ErrMissingResponse = errors.New("captcha response is required") ErrMissingResponse = errors.New("captcha response is required")
@ -496,8 +453,8 @@ func validateApplicationService(
} }
// Check username application service is trying to register is valid // Check username application service is trying to register is valid
if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
return "", err return "", internal.UsernameResponse(err)
} }
// No errors, registration valid // No errors, registration valid
@ -552,7 +509,9 @@ func Register(
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
return *resErr return *resErr
} }
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { var l string
var d gomatrixserverlib.ServerName
if l, d, err = cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
r.Username, r.ServerName = l, d r.Username, r.ServerName = l, d
} }
if req.URL.Query().Get("kind") == "guest" { if req.URL.Query().Get("kind") == "guest" {
@ -560,7 +519,7 @@ func Register(
} }
// Don't allow numeric usernames less than MAX_INT64. // Don't allow numeric usernames less than MAX_INT64.
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil { if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
@ -572,7 +531,7 @@ func Register(
ServerName: r.ServerName, ServerName: r.ServerName,
} }
nres := &userapi.QueryNumericLocalpartResponse{} nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -589,8 +548,8 @@ func Register(
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
// Spec-compliant case (the access_token is specified and the login type // Spec-compliant case (the access_token is specified and the login type
// is correctly set, so it's an appservice registration) // is correctly set, so it's an appservice registration)
if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil {
return *resErr return *internal.UsernameResponse(err)
} }
case accessTokenErr == nil: case accessTokenErr == nil:
// Non-spec-compliant case (the access_token is specified but the login // Non-spec-compliant case (the access_token is specified but the login
@ -602,12 +561,12 @@ func Register(
default: default:
// Spec-compliant case (neither the access_token nor the login type are // Spec-compliant case (neither the access_token nor the login type are
// specified, so it's a normal user registration) // specified, so it's a normal user registration)
if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil {
return *resErr return *internal.UsernameResponse(err)
} }
} }
if resErr := internal.ValidatePassword(r.Password); resErr != nil { if err = internal.ValidatePassword(r.Password); err != nil {
return *resErr return *internal.PasswordResponse(err)
} }
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
@ -1048,8 +1007,8 @@ func RegisterAvailable(
} }
} }
if err := validateUsername(username, domain); err != nil { if err := internal.ValidateUsername(username, domain); err != nil {
return *err return *internal.UsernameResponse(err)
} }
// Check if this username is reserved by an application service // Check if this username is reserved by an application service
@ -1111,11 +1070,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
// downcase capitals // downcase capitals
ssrr.User = strings.ToLower(ssrr.User) ssrr.User = strings.ToLower(ssrr.User)
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil {
return *resErr return *internal.UsernameResponse(err)
} }
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil { if err = internal.ValidatePassword(ssrr.Password); err != nil {
return *resErr return *internal.PasswordResponse(err)
} }
deviceID := "shared_secret_registration" deviceID := "shared_secret_registration"

View file

@ -28,13 +28,13 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -280,102 +280,6 @@ func TestSessionCleanUp(t *testing.T) {
}) })
} }
func Test_validateUsername(t *testing.T) {
tooLongUsername := strings.Repeat("a", maxUsernameLength)
tests := []struct {
name string
localpart string
domain gomatrixserverlib.ServerName
want *util.JSONResponse
}{
{
name: "empty username",
localpart: "",
domain: "localhost",
want: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
},
},
{
name: "invalid username",
localpart: "INVALIDUSERNAME",
domain: "localhost",
want: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
},
},
{
name: "username too long",
localpart: tooLongUsername,
domain: "localhost",
want: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", fmt.Sprintf("@%s:%s", tooLongUsername, "localhost"), maxUsernameLength)),
},
},
{
name: "localpart starting with an underscore",
localpart: "_notvalid",
domain: "localhost",
want: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
},
},
{
name: "valid username",
localpart: "valid",
domain: "localhost",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := validateUsername(tt.localpart, tt.domain); !reflect.DeepEqual(got, tt.want) {
t.Errorf("validateUsername() = %v, want %v", got, tt.want)
}
if got := validateApplicationServiceUsername(tt.localpart, tt.domain); !reflect.DeepEqual(got, tt.want) {
if got != nil && got.JSON != jsonerror.InvalidUsername("Username cannot start with a '_'") {
t.Errorf("validateUsername() = %v, want %v", got, tt.want)
}
}
})
}
}
func Test_validatePassword(t *testing.T) {
tests := []struct {
name string
password string
want *util.JSONResponse
}{
{
name: "password too short",
password: "shortpw",
want: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
},
},
{
name: "password too long",
password: strings.Repeat("a", maxPasswordLength+1),
want: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := validatePassword(tt.password); !reflect.DeepEqual(got, tt.want) {
t.Errorf("validatePassword() = %v, want %v", got, tt.want)
}
})
}
}
func Test_register(t *testing.T) { func Test_register(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@ -442,10 +346,7 @@ func Test_register(t *testing.T) {
{ {
name: "invalid username", name: "invalid username",
username: "#totalyNotValid", username: "#totalyNotValid",
wantResponse: util.JSONResponse{ wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid),
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
},
}, },
{ {
name: "numeric username is forbidden", name: "numeric username is forbidden",
@ -471,6 +372,7 @@ func Test_register(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.kind == "" { if tc.kind == "" {
tc.kind = "user" tc.kind = "user"
} }
@ -515,7 +417,7 @@ func Test_register(t *testing.T) {
if !reflect.DeepEqual(tc.wantResponse, resp) { if !reflect.DeepEqual(tc.wantResponse, resp) {
t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse)
} }
continue return
case registerResponse: case registerResponse:
// this should only be possible on guest user registration, never for normal users // this should only be possible on guest user registration, never for normal users
if tc.kind != "guest" { if tc.kind != "guest" {
@ -531,7 +433,7 @@ func Test_register(t *testing.T) {
if r.DeviceID == "" { if r.DeviceID == "" {
t.Fatalf("missing deviceID in response") t.Fatalf("missing deviceID in response")
} }
continue return
default: default:
t.Logf("Got response: %T", resp.JSON) t.Logf("Got response: %T", resp.JSON)
} }
@ -567,7 +469,7 @@ func Test_register(t *testing.T) {
if !reflect.DeepEqual(tc.wantResponse, resp) { if !reflect.DeepEqual(tc.wantResponse, resp) {
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
} }
continue return
} }
rr, ok := resp.JSON.(registerResponse) rr, ok := resp.JSON.(registerResponse)
@ -591,6 +493,7 @@ func Test_register(t *testing.T) {
if rr.AccessToken == "" { if rr.AccessToken == "" {
t.Fatalf("missing accessToken in response") t.Fatalf("missing accessToken in response")
} }
})
} }
}) })
} }

View file

@ -25,10 +25,10 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"regexp"
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/internal"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -65,7 +65,6 @@ var (
isAdmin = flag.Bool("admin", false, "Create an admin account") isAdmin = flag.Bool("admin", false, "Create an admin account")
resetPassword = flag.Bool("reset-password", false, "Deprecated") resetPassword = flag.Bool("reset-password", false, "Deprecated")
serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.")
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
) )
@ -95,20 +94,21 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if !validUsernameRegex.MatchString(*username) { if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil {
logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") logrus.WithError(err).Error("Specified username is invalid")
os.Exit(1) os.Exit(1)
} }
if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 {
logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName))
}
pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin) pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
if err != nil { if err != nil {
logrus.Fatalln(err) logrus.Fatalln(err)
} }
if err = internal.ValidatePassword(pass); err != nil {
logrus.WithError(err).Error("Specified password is invalid")
os.Exit(1)
}
cl.Timeout = *timeout cl.Timeout = *timeout
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)

View file

@ -15,30 +15,96 @@
package internal package internal
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"regexp"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based const (
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
)
var (
ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength)
ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength)
ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength)
ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='")
ErrUsernameUnderscore = errors.New("username cannot start with a '_'")
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
)
// ValidatePassword returns an error response if the password is invalid // ValidatePassword returns an error response if the password is invalid
func ValidatePassword(password string) *util.JSONResponse { func ValidatePassword(password string) error {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if len(password) > maxPasswordLength { if len(password) > maxPasswordLength {
return &util.JSONResponse{ return ErrPasswordTooLong
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)),
}
} else if len(password) > 0 && len(password) < minPasswordLength { } else if len(password) > 0 && len(password) < minPasswordLength {
return ErrPasswordWeak
}
return nil
}
// PasswordResponse returns a util.JSONResponse for a given error, if any.
func PasswordResponse(err error) *util.JSONResponse {
switch err {
case ErrPasswordWeak:
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()),
}
case ErrPasswordTooLong:
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()),
} }
} }
return nil return nil
} }
// ValidateUsername returns an error if the username is invalid
func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return ErrUsernameTooLong
} else if !validUsernameRegex.MatchString(localpart) {
return ErrUsernameInvalid
} else if localpart[0] == '_' { // Regex checks its not a zero length string
return ErrUsernameUnderscore
}
return nil
}
// UsernameResponse returns a util.JSONResponse for the given error, if any.
func UsernameResponse(err error) *util.JSONResponse {
switch err {
case ErrUsernameTooLong:
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
case ErrUsernameInvalid, ErrUsernameUnderscore:
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(err.Error()),
}
}
return nil
}
// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service
func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error {
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return ErrUsernameTooLong
} else if !validUsernameRegex.MatchString(localpart) {
return ErrUsernameInvalid
}
return nil
}

170
internal/validate_test.go Normal file
View file

@ -0,0 +1,170 @@
package internal
import (
"net/http"
"reflect"
"strings"
"testing"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
func Test_validatePassword(t *testing.T) {
tests := []struct {
name string
password string
wantError error
wantJSON *util.JSONResponse
}{
{
name: "password too short",
password: "shortpw",
wantError: ErrPasswordWeak,
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())},
},
{
name: "password too long",
password: strings.Repeat("a", maxPasswordLength+1),
wantError: ErrPasswordTooLong,
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())},
},
{
name: "password OK",
password: util.RandomString(10),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotErr := ValidatePassword(tt.password)
if !reflect.DeepEqual(gotErr, tt.wantError) {
t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError)
}
if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) {
t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON)
}
})
}
}
func Test_validateUsername(t *testing.T) {
tooLongUsername := strings.Repeat("a", maxUsernameLength)
tests := []struct {
name string
localpart string
domain gomatrixserverlib.ServerName
wantErr error
wantJSON *util.JSONResponse
}{
{
name: "empty username",
localpart: "",
domain: "localhost",
wantErr: ErrUsernameInvalid,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
},
},
{
name: "invalid username",
localpart: "INVALIDUSERNAME",
domain: "localhost",
wantErr: ErrUsernameInvalid,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
},
},
{
name: "username too long",
localpart: tooLongUsername,
domain: "localhost",
wantErr: ErrUsernameTooLong,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()),
},
},
{
name: "localpart starting with an underscore",
localpart: "_notvalid",
domain: "localhost",
wantErr: ErrUsernameUnderscore,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()),
},
},
{
name: "valid username",
localpart: "valid",
domain: "localhost",
},
{
name: "complex username",
localpart: "f00_bar-baz.=40/",
domain: "localhost",
},
{
name: "rejects emoji username 💥",
localpart: "💥",
domain: "localhost",
wantErr: ErrUsernameInvalid,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
},
},
{
name: "special characters are allowed",
localpart: "/dev/null",
domain: "localhost",
},
{
name: "special characters are allowed 2",
localpart: "i_am_allowed=1",
domain: "localhost",
},
{
name: "not all special characters are allowed",
localpart: "notallowed#", // contains #
domain: "localhost",
wantErr: ErrUsernameInvalid,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
},
},
{
name: "username containing numbers",
localpart: "hello1337",
domain: "localhost",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotErr := ValidateUsername(tt.localpart, tt.domain)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
}
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
}
// Application services are allowed usernames starting with an underscore
if tt.wantErr == ErrUsernameUnderscore {
return
}
gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
}
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
}
})
}
}