mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-22 05:23:09 -06:00
Shuffle Validate* functions arround
This commit is contained in:
parent
ad5009dbcc
commit
adedb38f8c
|
|
@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
|||
request := struct {
|
||||
Password string `json:"password"`
|
||||
}{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
|
||||
|
|
@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
|||
}
|
||||
}
|
||||
|
||||
if resErr := internal.ValidatePassword(request.Password); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidatePassword(request.Password); err != nil {
|
||||
return *internal.PasswordResponse(err)
|
||||
}
|
||||
|
||||
updateReq := &userapi.PerformPasswordUpdateRequest{
|
||||
|
|
|
|||
|
|
@ -82,8 +82,8 @@ func Password(
|
|||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||
|
||||
// Check the new password strength.
|
||||
if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil {
|
||||
return *resErr
|
||||
if err := internal.ValidatePassword(r.NewPassword); err != nil {
|
||||
return *internal.PasswordResponse(err)
|
||||
}
|
||||
|
||||
// Get the local part.
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -61,10 +60,7 @@ var (
|
|||
)
|
||||
)
|
||||
|
||||
const (
|
||||
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
||||
sessionIDLength = 24
|
||||
)
|
||||
const sessionIDLength = 24
|
||||
|
||||
// sessionsDict keeps track of completed auth stages for each session.
|
||||
// It shouldn't be passed by value because it contains a mutex.
|
||||
|
|
@ -199,8 +195,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
|
|||
}
|
||||
|
||||
var (
|
||||
sessions = newSessionsDict()
|
||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||
sessions = newSessionsDict()
|
||||
)
|
||||
|
||||
// registerRequest represents the submitted registration request.
|
||||
|
|
@ -276,44 +271,6 @@ type recaptchaResponse struct {
|
|||
ErrorCodes []int `json:"error-codes"`
|
||||
}
|
||||
|
||||
// validateUsername returns an error response if the username is invalid
|
||||
func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
||||
}
|
||||
} else if !validUsernameRegex.MatchString(localpart) {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
||||
}
|
||||
} else if localpart[0] == '_' { // Regex checks its not a zero length string
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service
|
||||
func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
||||
}
|
||||
} else if !validUsernameRegex.MatchString(localpart) {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidCaptcha = errors.New("invalid captcha response")
|
||||
ErrMissingResponse = errors.New("captcha response is required")
|
||||
|
|
@ -496,8 +453,8 @@ func validateApplicationService(
|
|||
}
|
||||
|
||||
// Check username application service is trying to register is valid
|
||||
if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
|
||||
return "", err
|
||||
if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
|
||||
return "", internal.UsernameResponse(err)
|
||||
}
|
||||
|
||||
// No errors, registration valid
|
||||
|
|
@ -552,7 +509,9 @@ func Register(
|
|||
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
|
||||
var l string
|
||||
var d gomatrixserverlib.ServerName
|
||||
if l, d, err = cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
|
||||
r.Username, r.ServerName = l, d
|
||||
}
|
||||
if req.URL.Query().Get("kind") == "guest" {
|
||||
|
|
@ -560,7 +519,7 @@ func Register(
|
|||
}
|
||||
|
||||
// Don't allow numeric usernames less than MAX_INT64.
|
||||
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil {
|
||||
if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
|
||||
|
|
@ -572,7 +531,7 @@ func Register(
|
|||
ServerName: r.ServerName,
|
||||
}
|
||||
nres := &userapi.QueryNumericLocalpartResponse{}
|
||||
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||
if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
|
@ -589,8 +548,8 @@ func Register(
|
|||
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
|
||||
// Spec-compliant case (the access_token is specified and the login type
|
||||
// is correctly set, so it's an appservice registration)
|
||||
if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil {
|
||||
return *internal.UsernameResponse(err)
|
||||
}
|
||||
case accessTokenErr == nil:
|
||||
// Non-spec-compliant case (the access_token is specified but the login
|
||||
|
|
@ -602,12 +561,12 @@ func Register(
|
|||
default:
|
||||
// Spec-compliant case (neither the access_token nor the login type are
|
||||
// specified, so it's a normal user registration)
|
||||
if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil {
|
||||
return *internal.UsernameResponse(err)
|
||||
}
|
||||
}
|
||||
if resErr := internal.ValidatePassword(r.Password); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidatePassword(r.Password); err != nil {
|
||||
return *internal.PasswordResponse(err)
|
||||
}
|
||||
|
||||
logger := util.GetLogger(req.Context())
|
||||
|
|
@ -1048,8 +1007,8 @@ func RegisterAvailable(
|
|||
}
|
||||
}
|
||||
|
||||
if err := validateUsername(username, domain); err != nil {
|
||||
return *err
|
||||
if err := internal.ValidateUsername(username, domain); err != nil {
|
||||
return *internal.UsernameResponse(err)
|
||||
}
|
||||
|
||||
// Check if this username is reserved by an application service
|
||||
|
|
@ -1111,11 +1070,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
|
|||
// downcase capitals
|
||||
ssrr.User = strings.ToLower(ssrr.User)
|
||||
|
||||
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil {
|
||||
return *internal.UsernameResponse(err)
|
||||
}
|
||||
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil {
|
||||
return *resErr
|
||||
if err = internal.ValidatePassword(ssrr.Password); err != nil {
|
||||
return *internal.PasswordResponse(err)
|
||||
}
|
||||
deviceID := "shared_secret_registration"
|
||||
|
||||
|
|
|
|||
|
|
@ -28,13 +28,13 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver"
|
||||
"github.com/matrix-org/dendrite/roomserver"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
|
|
@ -280,102 +280,6 @@ func TestSessionCleanUp(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_validateUsername(t *testing.T) {
|
||||
tooLongUsername := strings.Repeat("a", maxUsernameLength)
|
||||
tests := []struct {
|
||||
name string
|
||||
localpart string
|
||||
domain gomatrixserverlib.ServerName
|
||||
want *util.JSONResponse
|
||||
}{
|
||||
{
|
||||
name: "empty username",
|
||||
localpart: "",
|
||||
domain: "localhost",
|
||||
want: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid username",
|
||||
localpart: "INVALIDUSERNAME",
|
||||
domain: "localhost",
|
||||
want: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "username too long",
|
||||
localpart: tooLongUsername,
|
||||
domain: "localhost",
|
||||
want: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", fmt.Sprintf("@%s:%s", tooLongUsername, "localhost"), maxUsernameLength)),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "localpart starting with an underscore",
|
||||
localpart: "_notvalid",
|
||||
domain: "localhost",
|
||||
want: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid username",
|
||||
localpart: "valid",
|
||||
domain: "localhost",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := validateUsername(tt.localpart, tt.domain); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("validateUsername() = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got := validateApplicationServiceUsername(tt.localpart, tt.domain); !reflect.DeepEqual(got, tt.want) {
|
||||
if got != nil && got.JSON != jsonerror.InvalidUsername("Username cannot start with a '_'") {
|
||||
t.Errorf("validateUsername() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validatePassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
want *util.JSONResponse
|
||||
}{
|
||||
{
|
||||
name: "password too short",
|
||||
password: "shortpw",
|
||||
want: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password too long",
|
||||
password: strings.Repeat("a", maxPasswordLength+1),
|
||||
want: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := validatePassword(tt.password); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("validatePassword() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_register(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
|
@ -440,12 +344,9 @@ func Test_register(t *testing.T) {
|
|||
username: "LOWERCASED", // this is going to be lower-cased
|
||||
},
|
||||
{
|
||||
name: "invalid username",
|
||||
username: "#totalyNotValid",
|
||||
wantResponse: util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
||||
},
|
||||
name: "invalid username",
|
||||
username: "#totalyNotValid",
|
||||
wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid),
|
||||
},
|
||||
{
|
||||
name: "numeric username is forbidden",
|
||||
|
|
@ -471,126 +372,128 @@ func Test_register(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
if tc.kind == "" {
|
||||
tc.kind = "user"
|
||||
}
|
||||
if tc.password == "" && !tc.forceEmpty {
|
||||
tc.password = "someRandomPassword"
|
||||
}
|
||||
if tc.username == "" && !tc.forceEmpty {
|
||||
tc.username = "valid"
|
||||
}
|
||||
if tc.loginType == "" {
|
||||
tc.loginType = "m.login.dummy"
|
||||
}
|
||||
|
||||
reg := registerRequest{
|
||||
Password: tc.password,
|
||||
Username: tc.username,
|
||||
}
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
|
||||
base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled
|
||||
base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled
|
||||
|
||||
err := json.NewEncoder(body).Encode(reg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body)
|
||||
|
||||
resp := Register(req, userAPI, &base.Cfg.ClientAPI)
|
||||
t.Logf("Resp: %+v", resp)
|
||||
|
||||
// The first request should return a userInteractiveResponse
|
||||
switch r := resp.JSON.(type) {
|
||||
case userInteractiveResponse:
|
||||
// Check that the flows are the ones we configured
|
||||
if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) {
|
||||
t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.kind == "" {
|
||||
tc.kind = "user"
|
||||
}
|
||||
case *jsonerror.MatrixError:
|
||||
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||
t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse)
|
||||
if tc.password == "" && !tc.forceEmpty {
|
||||
tc.password = "someRandomPassword"
|
||||
}
|
||||
continue
|
||||
case registerResponse:
|
||||
// this should only be possible on guest user registration, never for normal users
|
||||
if tc.kind != "guest" {
|
||||
t.Fatalf("got register response on first request: %+v", r)
|
||||
if tc.username == "" && !tc.forceEmpty {
|
||||
tc.username = "valid"
|
||||
}
|
||||
// assert we've got a UserID, AccessToken and DeviceID
|
||||
if r.UserID == "" {
|
||||
t.Fatalf("missing userID in response")
|
||||
if tc.loginType == "" {
|
||||
tc.loginType = "m.login.dummy"
|
||||
}
|
||||
if r.AccessToken == "" {
|
||||
|
||||
reg := registerRequest{
|
||||
Password: tc.password,
|
||||
Username: tc.username,
|
||||
}
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
|
||||
base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled
|
||||
base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled
|
||||
|
||||
err := json.NewEncoder(body).Encode(reg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body)
|
||||
|
||||
resp := Register(req, userAPI, &base.Cfg.ClientAPI)
|
||||
t.Logf("Resp: %+v", resp)
|
||||
|
||||
// The first request should return a userInteractiveResponse
|
||||
switch r := resp.JSON.(type) {
|
||||
case userInteractiveResponse:
|
||||
// Check that the flows are the ones we configured
|
||||
if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) {
|
||||
t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows)
|
||||
}
|
||||
case *jsonerror.MatrixError:
|
||||
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||
t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse)
|
||||
}
|
||||
return
|
||||
case registerResponse:
|
||||
// this should only be possible on guest user registration, never for normal users
|
||||
if tc.kind != "guest" {
|
||||
t.Fatalf("got register response on first request: %+v", r)
|
||||
}
|
||||
// assert we've got a UserID, AccessToken and DeviceID
|
||||
if r.UserID == "" {
|
||||
t.Fatalf("missing userID in response")
|
||||
}
|
||||
if r.AccessToken == "" {
|
||||
t.Fatalf("missing accessToken in response")
|
||||
}
|
||||
if r.DeviceID == "" {
|
||||
t.Fatalf("missing deviceID in response")
|
||||
}
|
||||
return
|
||||
default:
|
||||
t.Logf("Got response: %T", resp.JSON)
|
||||
}
|
||||
|
||||
// If we reached this, we should have received a UIA response
|
||||
uia, ok := resp.JSON.(userInteractiveResponse)
|
||||
if !ok {
|
||||
t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON)
|
||||
}
|
||||
t.Logf("%+v", uia)
|
||||
|
||||
// Register the user
|
||||
reg.Auth = authDict{
|
||||
Type: authtypes.LoginType(tc.loginType),
|
||||
Session: uia.Session,
|
||||
}
|
||||
dummy := "dummy"
|
||||
reg.DeviceID = &dummy
|
||||
reg.InitialDisplayName = &dummy
|
||||
reg.Type = authtypes.LoginType(tc.loginType)
|
||||
|
||||
err = json.NewEncoder(body).Encode(reg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodPost, "/", body)
|
||||
|
||||
resp = Register(req, userAPI, &base.Cfg.ClientAPI)
|
||||
|
||||
switch resp.JSON.(type) {
|
||||
case *jsonerror.MatrixError:
|
||||
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rr, ok := resp.JSON.(registerResponse)
|
||||
if !ok {
|
||||
t.Fatalf("expected a registerresponse, got %T", resp.JSON)
|
||||
}
|
||||
|
||||
// validate the response
|
||||
if tc.forceEmpty {
|
||||
// when not supplying a username, one will be generated. Given this _SHOULD_ be
|
||||
// the second user, set the username accordingly
|
||||
reg.Username = "2"
|
||||
}
|
||||
wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test"))
|
||||
if wantUserID != rr.UserID {
|
||||
t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID)
|
||||
}
|
||||
if rr.DeviceID != *reg.DeviceID {
|
||||
t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID)
|
||||
}
|
||||
if rr.AccessToken == "" {
|
||||
t.Fatalf("missing accessToken in response")
|
||||
}
|
||||
if r.DeviceID == "" {
|
||||
t.Fatalf("missing deviceID in response")
|
||||
}
|
||||
continue
|
||||
default:
|
||||
t.Logf("Got response: %T", resp.JSON)
|
||||
}
|
||||
|
||||
// If we reached this, we should have received a UIA response
|
||||
uia, ok := resp.JSON.(userInteractiveResponse)
|
||||
if !ok {
|
||||
t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON)
|
||||
}
|
||||
t.Logf("%+v", uia)
|
||||
|
||||
// Register the user
|
||||
reg.Auth = authDict{
|
||||
Type: authtypes.LoginType(tc.loginType),
|
||||
Session: uia.Session,
|
||||
}
|
||||
dummy := "dummy"
|
||||
reg.DeviceID = &dummy
|
||||
reg.InitialDisplayName = &dummy
|
||||
reg.Type = authtypes.LoginType(tc.loginType)
|
||||
|
||||
err = json.NewEncoder(body).Encode(reg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodPost, "/", body)
|
||||
|
||||
resp = Register(req, userAPI, &base.Cfg.ClientAPI)
|
||||
|
||||
switch resp.JSON.(type) {
|
||||
case *jsonerror.MatrixError:
|
||||
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
rr, ok := resp.JSON.(registerResponse)
|
||||
if !ok {
|
||||
t.Fatalf("expected a registerresponse, got %T", resp.JSON)
|
||||
}
|
||||
|
||||
// validate the response
|
||||
if tc.forceEmpty {
|
||||
// when not supplying a username, one will be generated. Given this _SHOULD_ be
|
||||
// the second user, set the username accordingly
|
||||
reg.Username = "2"
|
||||
}
|
||||
wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test"))
|
||||
if wantUserID != rr.UserID {
|
||||
t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID)
|
||||
}
|
||||
if rr.DeviceID != *reg.DeviceID {
|
||||
t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID)
|
||||
}
|
||||
if rr.AccessToken == "" {
|
||||
t.Fatalf("missing accessToken in response")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
|
@ -58,15 +58,14 @@ Arguments:
|
|||
`
|
||||
|
||||
var (
|
||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||
password = flag.String("password", "", "The password to associate with the account")
|
||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
||||
resetPassword = flag.Bool("reset-password", false, "Deprecated")
|
||||
serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.")
|
||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
|
||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||
password = flag.String("password", "", "The password to associate with the account")
|
||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
||||
resetPassword = flag.Bool("reset-password", false, "Deprecated")
|
||||
serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.")
|
||||
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
|
||||
)
|
||||
|
||||
var cl = http.Client{
|
||||
|
|
@ -95,20 +94,21 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !validUsernameRegex.MatchString(*username) {
|
||||
logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='")
|
||||
if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil {
|
||||
logrus.WithError(err).Error("Specified username is invalid")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 {
|
||||
logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName))
|
||||
}
|
||||
|
||||
pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
|
||||
if err != nil {
|
||||
logrus.Fatalln(err)
|
||||
}
|
||||
|
||||
if err = internal.ValidatePassword(pass); err != nil {
|
||||
logrus.WithError(err).Error("Specified password is invalid")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cl.Timeout = *timeout
|
||||
|
||||
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
|
||||
|
|
|
|||
|
|
@ -15,30 +15,96 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
||||
const (
|
||||
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
||||
|
||||
const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
||||
maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength)
|
||||
ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength)
|
||||
ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength)
|
||||
ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='")
|
||||
ErrUsernameUnderscore = errors.New("username cannot start with a '_'")
|
||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||
)
|
||||
|
||||
// ValidatePassword returns an error response if the password is invalid
|
||||
func ValidatePassword(password string) *util.JSONResponse {
|
||||
func ValidatePassword(password string) error {
|
||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
if len(password) > maxPasswordLength {
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)),
|
||||
}
|
||||
return ErrPasswordTooLong
|
||||
} else if len(password) > 0 && len(password) < minPasswordLength {
|
||||
return ErrPasswordWeak
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PasswordResponse returns a util.JSONResponse for a given error, if any.
|
||||
func PasswordResponse(err error) *util.JSONResponse {
|
||||
switch err {
|
||||
case ErrPasswordWeak:
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
|
||||
JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()),
|
||||
}
|
||||
case ErrPasswordTooLong:
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUsername returns an error if the username is invalid
|
||||
func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error {
|
||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||
return ErrUsernameTooLong
|
||||
} else if !validUsernameRegex.MatchString(localpart) {
|
||||
return ErrUsernameInvalid
|
||||
} else if localpart[0] == '_' { // Regex checks its not a zero length string
|
||||
return ErrUsernameUnderscore
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UsernameResponse returns a util.JSONResponse for the given error, if any.
|
||||
func UsernameResponse(err error) *util.JSONResponse {
|
||||
switch err {
|
||||
case ErrUsernameTooLong:
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(err.Error()),
|
||||
}
|
||||
case ErrUsernameInvalid, ErrUsernameUnderscore:
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername(err.Error()),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service
|
||||
func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error {
|
||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||
return ErrUsernameTooLong
|
||||
} else if !validUsernameRegex.MatchString(localpart) {
|
||||
return ErrUsernameInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
170
internal/validate_test.go
Normal file
170
internal/validate_test.go
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
func Test_validatePassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantError error
|
||||
wantJSON *util.JSONResponse
|
||||
}{
|
||||
{
|
||||
name: "password too short",
|
||||
password: "shortpw",
|
||||
wantError: ErrPasswordWeak,
|
||||
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())},
|
||||
},
|
||||
{
|
||||
name: "password too long",
|
||||
password: strings.Repeat("a", maxPasswordLength+1),
|
||||
wantError: ErrPasswordTooLong,
|
||||
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())},
|
||||
},
|
||||
{
|
||||
name: "password OK",
|
||||
password: util.RandomString(10),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotErr := ValidatePassword(tt.password)
|
||||
if !reflect.DeepEqual(gotErr, tt.wantError) {
|
||||
t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError)
|
||||
}
|
||||
|
||||
if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) {
|
||||
t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validateUsername(t *testing.T) {
|
||||
tooLongUsername := strings.Repeat("a", maxUsernameLength)
|
||||
tests := []struct {
|
||||
name string
|
||||
localpart string
|
||||
domain gomatrixserverlib.ServerName
|
||||
wantErr error
|
||||
wantJSON *util.JSONResponse
|
||||
}{
|
||||
{
|
||||
name: "empty username",
|
||||
localpart: "",
|
||||
domain: "localhost",
|
||||
wantErr: ErrUsernameInvalid,
|
||||
wantJSON: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid username",
|
||||
localpart: "INVALIDUSERNAME",
|
||||
domain: "localhost",
|
||||
wantErr: ErrUsernameInvalid,
|
||||
wantJSON: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "username too long",
|
||||
localpart: tooLongUsername,
|
||||
domain: "localhost",
|
||||
wantErr: ErrUsernameTooLong,
|
||||
wantJSON: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "localpart starting with an underscore",
|
||||
localpart: "_notvalid",
|
||||
domain: "localhost",
|
||||
wantErr: ErrUsernameUnderscore,
|
||||
wantJSON: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid username",
|
||||
localpart: "valid",
|
||||
domain: "localhost",
|
||||
},
|
||||
{
|
||||
name: "complex username",
|
||||
localpart: "f00_bar-baz.=40/",
|
||||
domain: "localhost",
|
||||
},
|
||||
{
|
||||
name: "rejects emoji username 💥",
|
||||
localpart: "💥",
|
||||
domain: "localhost",
|
||||
wantErr: ErrUsernameInvalid,
|
||||
wantJSON: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "special characters are allowed",
|
||||
localpart: "/dev/null",
|
||||
domain: "localhost",
|
||||
},
|
||||
{
|
||||
name: "special characters are allowed 2",
|
||||
localpart: "i_am_allowed=1",
|
||||
domain: "localhost",
|
||||
},
|
||||
{
|
||||
name: "not all special characters are allowed",
|
||||
localpart: "notallowed#", // contains #
|
||||
domain: "localhost",
|
||||
wantErr: ErrUsernameInvalid,
|
||||
wantJSON: &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "username containing numbers",
|
||||
localpart: "hello1337",
|
||||
domain: "localhost",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotErr := ValidateUsername(tt.localpart, tt.domain)
|
||||
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
|
||||
}
|
||||
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
|
||||
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
|
||||
}
|
||||
|
||||
// Application services are allowed usernames starting with an underscore
|
||||
if tt.wantErr == ErrUsernameUnderscore {
|
||||
return
|
||||
}
|
||||
gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain)
|
||||
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
|
||||
}
|
||||
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
|
||||
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue