mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-26 00:01:55 -06:00
Add clientapi tests (#2916)
This PR - adds several tests for the clientapi, mostly around `/register` and auth fallback. - removes the now deprecated `homeserver` field from responses to `/register` and `/login` - slightly refactors auth fallback handling
This commit is contained in:
parent
f47515e38b
commit
f762ce1050
3
.github/workflows/dendrite.yml
vendored
3
.github/workflows/dendrite.yml
vendored
|
@ -331,8 +331,7 @@ jobs:
|
||||||
postgres: postgres
|
postgres: postgres
|
||||||
api: full-http
|
api: full-http
|
||||||
container:
|
container:
|
||||||
# Temporary for debugging to see if this image is working better.
|
image: matrixdotorg/sytest-dendrite
|
||||||
image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73
|
|
||||||
volumes:
|
volumes:
|
||||||
- ${{ github.workspace }}:/src
|
- ${{ github.workspace }}:/src
|
||||||
- /root/.cache/go-build:/github/home/.cache/go-build
|
- /root/.cache/go-build:/github/home/.cache/go-build
|
||||||
|
|
|
@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||||
request := struct {
|
request := struct {
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
}{}
|
}{}
|
||||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
|
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
|
||||||
|
@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if resErr := internal.ValidatePassword(request.Password); resErr != nil {
|
if err = internal.ValidatePassword(request.Password); err != nil {
|
||||||
return *resErr
|
return *internal.PasswordResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
updateReq := &userapi.PerformPasswordUpdateRequest{
|
updateReq := &userapi.PerformPasswordUpdateRequest{
|
||||||
|
|
|
@ -15,11 +15,11 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
@ -101,14 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
|
||||||
func AuthFallback(
|
func AuthFallback(
|
||||||
w http.ResponseWriter, req *http.Request, authType string,
|
w http.ResponseWriter, req *http.Request, authType string,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) *util.JSONResponse {
|
) {
|
||||||
sessionID := req.URL.Query().Get("session")
|
// We currently only support "m.login.recaptcha", so fail early if that's not requested
|
||||||
|
if authType == authtypes.LoginTypeRecaptcha {
|
||||||
|
if !cfg.RecaptchaEnabled {
|
||||||
|
writeHTTPMessage(w, req,
|
||||||
|
"Recaptcha login is disabled on this Homeserver",
|
||||||
|
http.StatusBadRequest,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := req.URL.Query().Get("session")
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return writeHTTPMessage(w, req,
|
writeHTTPMessage(w, req,
|
||||||
"Session ID not provided",
|
"Session ID not provided",
|
||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
serveRecaptcha := func() {
|
serveRecaptcha := func() {
|
||||||
|
@ -130,70 +144,44 @@ func AuthFallback(
|
||||||
|
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
// Handle Recaptcha
|
// Handle Recaptcha
|
||||||
if authType == authtypes.LoginTypeRecaptcha {
|
|
||||||
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
serveRecaptcha()
|
serveRecaptcha()
|
||||||
return nil
|
return
|
||||||
}
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusNotFound,
|
|
||||||
JSON: jsonerror.NotFound("Unknown auth stage type"),
|
|
||||||
}
|
|
||||||
} else if req.Method == http.MethodPost {
|
} else if req.Method == http.MethodPost {
|
||||||
// Handle Recaptcha
|
// Handle Recaptcha
|
||||||
if authType == authtypes.LoginTypeRecaptcha {
|
|
||||||
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientIP := req.RemoteAddr
|
clientIP := req.RemoteAddr
|
||||||
err := req.ParseForm()
|
err := req.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
|
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
|
||||||
res := jsonerror.InternalServerError()
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return &res
|
serveRecaptcha()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := req.Form.Get(cfg.RecaptchaFormField)
|
response := req.Form.Get(cfg.RecaptchaFormField)
|
||||||
if err := validateRecaptcha(cfg, response, clientIP); err != nil {
|
err = validateRecaptcha(cfg, response, clientIP)
|
||||||
util.GetLogger(req.Context()).Error(err)
|
switch err {
|
||||||
return err
|
case ErrMissingResponse:
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
serveRecaptcha() // serve the initial page again, instead of nothing
|
||||||
|
return
|
||||||
|
case ErrInvalidCaptcha:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
serveRecaptcha()
|
||||||
|
return
|
||||||
|
case nil:
|
||||||
|
default: // something else failed
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
|
||||||
|
serveRecaptcha()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Success. Add recaptcha as a completed login flow
|
// Success. Add recaptcha as a completed login flow
|
||||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||||
|
|
||||||
serveSuccess()
|
serveSuccess()
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed)
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusNotFound,
|
|
||||||
JSON: jsonerror.NotFound("Unknown auth stage type"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusMethodNotAllowed,
|
|
||||||
JSON: jsonerror.NotFound("Bad method"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver.
|
|
||||||
func checkRecaptchaEnabled(
|
|
||||||
cfg *config.ClientAPI,
|
|
||||||
w http.ResponseWriter,
|
|
||||||
req *http.Request,
|
|
||||||
) *util.JSONResponse {
|
|
||||||
if !cfg.RecaptchaEnabled {
|
|
||||||
return writeHTTPMessage(w, req,
|
|
||||||
"Recaptcha login is disabled on this Homeserver",
|
|
||||||
http.StatusBadRequest,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeHTTPMessage writes the given header and message to the HTTP response writer.
|
// writeHTTPMessage writes the given header and message to the HTTP response writer.
|
||||||
|
@ -201,13 +189,10 @@ func checkRecaptchaEnabled(
|
||||||
func writeHTTPMessage(
|
func writeHTTPMessage(
|
||||||
w http.ResponseWriter, req *http.Request,
|
w http.ResponseWriter, req *http.Request,
|
||||||
message string, header int,
|
message string, header int,
|
||||||
) *util.JSONResponse {
|
) {
|
||||||
w.WriteHeader(header)
|
w.WriteHeader(header)
|
||||||
_, err := w.Write([]byte(message))
|
_, err := w.Write([]byte(message))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
|
util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
|
||||||
res := jsonerror.InternalServerError()
|
|
||||||
return &res
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
149
clientapi/routing/auth_fallback_test.go
Normal file
149
clientapi/routing/auth_fallback_test.go
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_AuthFallback(t *testing.T) {
|
||||||
|
base, _, _ := testrig.Base(nil)
|
||||||
|
defer base.Close()
|
||||||
|
|
||||||
|
for _, useHCaptcha := range []bool{false, true} {
|
||||||
|
for _, recaptchaEnabled := range []bool{false, true} {
|
||||||
|
for _, wantErr := range []bool{false, true} {
|
||||||
|
t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) {
|
||||||
|
// Set the defaults for each test
|
||||||
|
base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, Monolithic: true})
|
||||||
|
base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled
|
||||||
|
base.Cfg.ClientAPI.RecaptchaPublicKey = "pub"
|
||||||
|
base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv"
|
||||||
|
if useHCaptcha {
|
||||||
|
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify"
|
||||||
|
base.Cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js"
|
||||||
|
base.Cfg.ClientAPI.RecaptchaFormField = "h-captcha-response"
|
||||||
|
base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha"
|
||||||
|
}
|
||||||
|
cfgErrs := &config.ConfigErrors{}
|
||||||
|
base.Cfg.ClientAPI.Verify(cfgErrs, true)
|
||||||
|
if len(*cfgErrs) > 0 {
|
||||||
|
t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/?session=1337", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if !recaptchaEnabled {
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" {
|
||||||
|
t.Fatalf("unexpected response body: %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if !strings.Contains(rec.Body.String(), base.Cfg.ClientAPI.RecaptchaSitekeyClass) {
|
||||||
|
t.Fatalf("body does not contain %s: %s", base.Cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if wantErr {
|
||||||
|
_, _ = w.Write([]byte(`{"success":false}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte(`{"success":true}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL
|
||||||
|
|
||||||
|
// check the result after sending the captcha
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
|
||||||
|
req.Form = url.Values{}
|
||||||
|
req.Form.Add(base.Cfg.ClientAPI.RecaptchaFormField, "someRandomValue")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if recaptchaEnabled {
|
||||||
|
if !wantErr {
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if rec.Body.String() != successTemplate {
|
||||||
|
t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), successTemplate)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
wantString := "Authentication"
|
||||||
|
if !strings.Contains(rec.Body.String(), wantString) {
|
||||||
|
t.Fatalf("expected response to contain '%s', but didn't: %s", wantString, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" {
|
||||||
|
t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), "successTemplate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("unknown fallbacks are handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, "DoesNotExist", &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusNotImplemented {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown methods are handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/?session=1337", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing 'response' is handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -23,14 +23,12 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type loginResponse struct {
|
type loginResponse struct {
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
HomeServer gomatrixserverlib.ServerName `json:"home_server"`
|
|
||||||
DeviceID string `json:"device_id"`
|
DeviceID string `json:"device_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,7 +114,6 @@ func completeAuth(
|
||||||
JSON: loginResponse{
|
JSON: loginResponse{
|
||||||
UserID: performRes.Device.UserID,
|
UserID: performRes.Device.UserID,
|
||||||
AccessToken: performRes.Device.AccessToken,
|
AccessToken: performRes.Device.AccessToken,
|
||||||
HomeServer: serverName,
|
|
||||||
DeviceID: performRes.Device.ID,
|
DeviceID: performRes.Device.ID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,8 +82,8 @@ func Password(
|
||||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||||
|
|
||||||
// Check the new password strength.
|
// Check the new password strength.
|
||||||
if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil {
|
if err := internal.ValidatePassword(r.NewPassword); err != nil {
|
||||||
return *resErr
|
return *internal.PasswordResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the local part.
|
// Get the local part.
|
||||||
|
|
|
@ -18,12 +18,12 @@ package routing
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -60,10 +60,7 @@ var (
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const sessionIDLength = 24
|
||||||
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
|
||||||
sessionIDLength = 24
|
|
||||||
)
|
|
||||||
|
|
||||||
// sessionsDict keeps track of completed auth stages for each session.
|
// sessionsDict keeps track of completed auth stages for each session.
|
||||||
// It shouldn't be passed by value because it contains a mutex.
|
// It shouldn't be passed by value because it contains a mutex.
|
||||||
|
@ -199,7 +196,6 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
sessions = newSessionsDict()
|
sessions = newSessionsDict()
|
||||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// registerRequest represents the submitted registration request.
|
// registerRequest represents the submitted registration request.
|
||||||
|
@ -264,7 +260,6 @@ func newUserInteractiveResponse(
|
||||||
type registerResponse struct {
|
type registerResponse struct {
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
AccessToken string `json:"access_token,omitempty"`
|
AccessToken string `json:"access_token,omitempty"`
|
||||||
HomeServer gomatrixserverlib.ServerName `json:"home_server"`
|
|
||||||
DeviceID string `json:"device_id,omitempty"`
|
DeviceID string `json:"device_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -276,66 +271,28 @@ type recaptchaResponse struct {
|
||||||
ErrorCodes []int `json:"error-codes"`
|
ErrorCodes []int `json:"error-codes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateUsername returns an error response if the username is invalid
|
var (
|
||||||
func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
ErrInvalidCaptcha = errors.New("invalid captcha response")
|
||||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
ErrMissingResponse = errors.New("captcha response is required")
|
||||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
ErrCaptchaDisabled = errors.New("captcha registration is disabled")
|
||||||
return &util.JSONResponse{
|
)
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
|
||||||
}
|
|
||||||
} else if !validUsernameRegex.MatchString(localpart) {
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
|
||||||
}
|
|
||||||
} else if localpart[0] == '_' { // Regex checks its not a zero length string
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service
|
|
||||||
func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
|
||||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
|
||||||
}
|
|
||||||
} else if !validUsernameRegex.MatchString(localpart) {
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateRecaptcha returns an error response if the captcha response is invalid
|
// validateRecaptcha returns an error response if the captcha response is invalid
|
||||||
func validateRecaptcha(
|
func validateRecaptcha(
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
response string,
|
response string,
|
||||||
clientip string,
|
clientip string,
|
||||||
) *util.JSONResponse {
|
) error {
|
||||||
ip, _, _ := net.SplitHostPort(clientip)
|
ip, _, _ := net.SplitHostPort(clientip)
|
||||||
if !cfg.RecaptchaEnabled {
|
if !cfg.RecaptchaEnabled {
|
||||||
return &util.JSONResponse{
|
return ErrCaptchaDisabled
|
||||||
Code: http.StatusConflict,
|
|
||||||
JSON: jsonerror.Unknown("Captcha registration is disabled"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if response == "" {
|
if response == "" {
|
||||||
return &util.JSONResponse{
|
return ErrMissingResponse
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON("Captcha response is required"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a POST request to Google's API to check the captcha response
|
// Make a POST request to the captcha provider API to check the captcha response
|
||||||
resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI,
|
resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI,
|
||||||
url.Values{
|
url.Values{
|
||||||
"secret": {cfg.RecaptchaPrivateKey},
|
"secret": {cfg.RecaptchaPrivateKey},
|
||||||
|
@ -345,10 +302,7 @@ func validateRecaptcha(
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &util.JSONResponse{
|
return err
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the request once we're finishing reading from it
|
// Close the request once we're finishing reading from it
|
||||||
|
@ -358,25 +312,16 @@ func validateRecaptcha(
|
||||||
var r recaptchaResponse
|
var r recaptchaResponse
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &util.JSONResponse{
|
return err
|
||||||
Code: http.StatusGatewayTimeout,
|
|
||||||
JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &r)
|
err = json.Unmarshal(body, &r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &util.JSONResponse{
|
return err
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that we received a "success"
|
// Check that we received a "success"
|
||||||
if !r.Success {
|
if !r.Success {
|
||||||
return &util.JSONResponse{
|
return ErrInvalidCaptcha
|
||||||
Code: http.StatusUnauthorized,
|
|
||||||
JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -508,8 +453,8 @@ func validateApplicationService(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check username application service is trying to register is valid
|
// Check username application service is trying to register is valid
|
||||||
if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
|
if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
|
||||||
return "", err
|
return "", internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// No errors, registration valid
|
// No errors, registration valid
|
||||||
|
@ -564,15 +509,12 @@ func Register(
|
||||||
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
|
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
|
|
||||||
r.Username, r.ServerName = l, d
|
|
||||||
}
|
|
||||||
if req.URL.Query().Get("kind") == "guest" {
|
if req.URL.Query().Get("kind") == "guest" {
|
||||||
return handleGuestRegistration(req, r, cfg, userAPI)
|
return handleGuestRegistration(req, r, cfg, userAPI)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't allow numeric usernames less than MAX_INT64.
|
// Don't allow numeric usernames less than MAX_INT64.
|
||||||
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil {
|
if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
|
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
|
||||||
|
@ -584,7 +526,7 @@ func Register(
|
||||||
ServerName: r.ServerName,
|
ServerName: r.ServerName,
|
||||||
}
|
}
|
||||||
nres := &userapi.QueryNumericLocalpartResponse{}
|
nres := &userapi.QueryNumericLocalpartResponse{}
|
||||||
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -601,8 +543,8 @@ func Register(
|
||||||
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
|
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
|
||||||
// Spec-compliant case (the access_token is specified and the login type
|
// Spec-compliant case (the access_token is specified and the login type
|
||||||
// is correctly set, so it's an appservice registration)
|
// is correctly set, so it's an appservice registration)
|
||||||
if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
|
if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil {
|
||||||
return *resErr
|
return *internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
case accessTokenErr == nil:
|
case accessTokenErr == nil:
|
||||||
// Non-spec-compliant case (the access_token is specified but the login
|
// Non-spec-compliant case (the access_token is specified but the login
|
||||||
|
@ -614,12 +556,12 @@ func Register(
|
||||||
default:
|
default:
|
||||||
// Spec-compliant case (neither the access_token nor the login type are
|
// Spec-compliant case (neither the access_token nor the login type are
|
||||||
// specified, so it's a normal user registration)
|
// specified, so it's a normal user registration)
|
||||||
if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
|
if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil {
|
||||||
return *resErr
|
return *internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if resErr := internal.ValidatePassword(r.Password); resErr != nil {
|
if err = internal.ValidatePassword(r.Password); err != nil {
|
||||||
return *resErr
|
return *internal.PasswordResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := util.GetLogger(req.Context())
|
logger := util.GetLogger(req.Context())
|
||||||
|
@ -697,7 +639,6 @@ func handleGuestRegistration(
|
||||||
JSON: registerResponse{
|
JSON: registerResponse{
|
||||||
UserID: devRes.Device.UserID,
|
UserID: devRes.Device.UserID,
|
||||||
AccessToken: devRes.Device.AccessToken,
|
AccessToken: devRes.Device.AccessToken,
|
||||||
HomeServer: res.Account.ServerName,
|
|
||||||
DeviceID: devRes.Device.ID,
|
DeviceID: devRes.Device.ID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -761,9 +702,18 @@ func handleRegistrationFlow(
|
||||||
switch r.Auth.Type {
|
switch r.Auth.Type {
|
||||||
case authtypes.LoginTypeRecaptcha:
|
case authtypes.LoginTypeRecaptcha:
|
||||||
// Check given captcha response
|
// Check given captcha response
|
||||||
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
||||||
if resErr != nil {
|
switch err {
|
||||||
return *resErr
|
case ErrCaptchaDisabled:
|
||||||
|
return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())}
|
||||||
|
case ErrMissingResponse:
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())}
|
||||||
|
case ErrInvalidCaptcha:
|
||||||
|
return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())}
|
||||||
|
case nil:
|
||||||
|
default:
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
|
||||||
|
return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Recaptcha to the list of completed registration stages
|
// Add Recaptcha to the list of completed registration stages
|
||||||
|
@ -925,7 +875,6 @@ func completeRegistration(
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: registerResponse{
|
JSON: registerResponse{
|
||||||
UserID: userutil.MakeUserID(username, accRes.Account.ServerName),
|
UserID: userutil.MakeUserID(username, accRes.Account.ServerName),
|
||||||
HomeServer: accRes.Account.ServerName,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -958,7 +907,6 @@ func completeRegistration(
|
||||||
result := registerResponse{
|
result := registerResponse{
|
||||||
UserID: devRes.Device.UserID,
|
UserID: devRes.Device.UserID,
|
||||||
AccessToken: devRes.Device.AccessToken,
|
AccessToken: devRes.Device.AccessToken,
|
||||||
HomeServer: accRes.Account.ServerName,
|
|
||||||
DeviceID: devRes.Device.ID,
|
DeviceID: devRes.Device.ID,
|
||||||
}
|
}
|
||||||
sessions.addCompletedRegistration(sessionID, result)
|
sessions.addCompletedRegistration(sessionID, result)
|
||||||
|
@ -1054,8 +1002,8 @@ func RegisterAvailable(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateUsername(username, domain); err != nil {
|
if err := internal.ValidateUsername(username, domain); err != nil {
|
||||||
return *err
|
return *internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this username is reserved by an application service
|
// Check if this username is reserved by an application service
|
||||||
|
@ -1117,11 +1065,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
|
||||||
// downcase capitals
|
// downcase capitals
|
||||||
ssrr.User = strings.ToLower(ssrr.User)
|
ssrr.User = strings.ToLower(ssrr.User)
|
||||||
|
|
||||||
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
|
if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil {
|
||||||
return *resErr
|
return *internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil {
|
if err = internal.ValidatePassword(ssrr.Password); err != nil {
|
||||||
return *resErr
|
return *internal.PasswordResponse(err)
|
||||||
}
|
}
|
||||||
deviceID := "shared_secret_registration"
|
deviceID := "shared_secret_registration"
|
||||||
|
|
||||||
|
|
|
@ -15,12 +15,27 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -264,3 +279,294 @@ func TestSessionCleanUp(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_register(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
kind string
|
||||||
|
password string
|
||||||
|
username string
|
||||||
|
loginType string
|
||||||
|
forceEmpty bool
|
||||||
|
registrationDisabled bool
|
||||||
|
guestsDisabled bool
|
||||||
|
enableRecaptcha bool
|
||||||
|
captchaBody string
|
||||||
|
wantResponse util.JSONResponse
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "disallow guests",
|
||||||
|
kind: "guest",
|
||||||
|
guestsDisabled: true,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow guests",
|
||||||
|
kind: "guest",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown login type",
|
||||||
|
loginType: "im.not.known",
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusNotImplemented,
|
||||||
|
JSON: jsonerror.Unknown("unknown/unimplemented auth type"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled registration",
|
||||||
|
registrationDisabled: true,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden(`Registration is disabled on "test"`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "successful registration, numeric ID",
|
||||||
|
username: "",
|
||||||
|
password: "someRandomPassword",
|
||||||
|
forceEmpty: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "successful registration",
|
||||||
|
username: "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failing registration - user already exists",
|
||||||
|
username: "success",
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.UserInUse("Desired user ID is already taken."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "successful registration uppercase username",
|
||||||
|
username: "LOWERCASED", // this is going to be lower-cased
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid username",
|
||||||
|
username: "#totalyNotValid",
|
||||||
|
wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "numeric username is forbidden",
|
||||||
|
username: "1337",
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled recaptcha login",
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled recaptcha, no response defined",
|
||||||
|
enableRecaptcha: true,
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(ErrMissingResponse.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid captcha response",
|
||||||
|
enableRecaptcha: true,
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
captchaBody: `notvalid`,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid captcha response",
|
||||||
|
enableRecaptcha: true,
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
captchaBody: `success`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "captcha invalid from remote",
|
||||||
|
enableRecaptcha: true,
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
captchaBody: `i should fail for other reasons`,
|
||||||
|
wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer baseClose()
|
||||||
|
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||||
|
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if tc.enableRecaptcha {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
response := r.Form.Get("response")
|
||||||
|
|
||||||
|
// Respond with valid JSON or no JSON at all to test happy/error cases
|
||||||
|
switch response {
|
||||||
|
case "success":
|
||||||
|
json.NewEncoder(w).Encode(recaptchaResponse{Success: true})
|
||||||
|
case "notvalid":
|
||||||
|
json.NewEncoder(w).Encode(recaptchaResponse{Success: false})
|
||||||
|
default:
|
||||||
|
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := base.Cfg.Derive(); err != nil {
|
||||||
|
t.Fatalf("failed to derive config: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
base.Cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha
|
||||||
|
base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled
|
||||||
|
base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled
|
||||||
|
|
||||||
|
if tc.kind == "" {
|
||||||
|
tc.kind = "user"
|
||||||
|
}
|
||||||
|
if tc.password == "" && !tc.forceEmpty {
|
||||||
|
tc.password = "someRandomPassword"
|
||||||
|
}
|
||||||
|
if tc.username == "" && !tc.forceEmpty {
|
||||||
|
tc.username = "valid"
|
||||||
|
}
|
||||||
|
if tc.loginType == "" {
|
||||||
|
tc.loginType = "m.login.dummy"
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registerRequest{
|
||||||
|
Password: tc.password,
|
||||||
|
Username: tc.username,
|
||||||
|
}
|
||||||
|
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
err := json.NewEncoder(body).Encode(reg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body)
|
||||||
|
|
||||||
|
resp := Register(req, userAPI, &base.Cfg.ClientAPI)
|
||||||
|
t.Logf("Resp: %+v", resp)
|
||||||
|
|
||||||
|
// The first request should return a userInteractiveResponse
|
||||||
|
switch r := resp.JSON.(type) {
|
||||||
|
case userInteractiveResponse:
|
||||||
|
// Check that the flows are the ones we configured
|
||||||
|
if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) {
|
||||||
|
t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows)
|
||||||
|
}
|
||||||
|
case *jsonerror.MatrixError:
|
||||||
|
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||||
|
t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case registerResponse:
|
||||||
|
// this should only be possible on guest user registration, never for normal users
|
||||||
|
if tc.kind != "guest" {
|
||||||
|
t.Fatalf("got register response on first request: %+v", r)
|
||||||
|
}
|
||||||
|
// assert we've got a UserID, AccessToken and DeviceID
|
||||||
|
if r.UserID == "" {
|
||||||
|
t.Fatalf("missing userID in response")
|
||||||
|
}
|
||||||
|
if r.AccessToken == "" {
|
||||||
|
t.Fatalf("missing accessToken in response")
|
||||||
|
}
|
||||||
|
if r.DeviceID == "" {
|
||||||
|
t.Fatalf("missing deviceID in response")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
t.Logf("Got response: %T", resp.JSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we reached this, we should have received a UIA response
|
||||||
|
uia, ok := resp.JSON.(userInteractiveResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON)
|
||||||
|
}
|
||||||
|
t.Logf("%+v", uia)
|
||||||
|
|
||||||
|
// Register the user
|
||||||
|
reg.Auth = authDict{
|
||||||
|
Type: authtypes.LoginType(tc.loginType),
|
||||||
|
Session: uia.Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.captchaBody != "" {
|
||||||
|
reg.Auth.Response = tc.captchaBody
|
||||||
|
}
|
||||||
|
|
||||||
|
dummy := "dummy"
|
||||||
|
reg.DeviceID = &dummy
|
||||||
|
reg.InitialDisplayName = &dummy
|
||||||
|
reg.Type = authtypes.LoginType(tc.loginType)
|
||||||
|
|
||||||
|
err = json.NewEncoder(body).Encode(reg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/", body)
|
||||||
|
|
||||||
|
resp = Register(req, userAPI, &base.Cfg.ClientAPI)
|
||||||
|
|
||||||
|
switch resp.JSON.(type) {
|
||||||
|
case *jsonerror.MatrixError:
|
||||||
|
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||||
|
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case util.JSONResponse:
|
||||||
|
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||||
|
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rr, ok := resp.JSON.(registerResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected a registerresponse, got %T", resp.JSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate the response
|
||||||
|
if tc.forceEmpty {
|
||||||
|
// when not supplying a username, one will be generated. Given this _SHOULD_ be
|
||||||
|
// the second user, set the username accordingly
|
||||||
|
reg.Username = "2"
|
||||||
|
}
|
||||||
|
wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test"))
|
||||||
|
if wantUserID != rr.UserID {
|
||||||
|
t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID)
|
||||||
|
}
|
||||||
|
if rr.DeviceID != *reg.DeviceID {
|
||||||
|
t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID)
|
||||||
|
}
|
||||||
|
if rr.AccessToken == "" {
|
||||||
|
t.Fatalf("missing accessToken in response")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -639,9 +639,9 @@ func Setup(
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/auth/{authType}/fallback/web",
|
v3mux.Handle("/auth/{authType}/fallback/web",
|
||||||
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) {
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
return AuthFallback(w, req, vars["authType"], cfg)
|
AuthFallback(w, req, vars["authType"], cfg)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
|
|
@ -25,10 +25,10 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
@ -65,7 +65,6 @@ var (
|
||||||
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
||||||
resetPassword = flag.Bool("reset-password", false, "Deprecated")
|
resetPassword = flag.Bool("reset-password", false, "Deprecated")
|
||||||
serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.")
|
serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.")
|
||||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
|
||||||
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
|
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -95,20 +94,21 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !validUsernameRegex.MatchString(*username) {
|
if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil {
|
||||||
logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='")
|
logrus.WithError(err).Error("Specified username is invalid")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 {
|
|
||||||
logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName))
|
|
||||||
}
|
|
||||||
|
|
||||||
pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
|
pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalln(err)
|
logrus.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = internal.ValidatePassword(pass); err != nil {
|
||||||
|
logrus.WithError(err).Error("Specified password is invalid")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
cl.Timeout = *timeout
|
cl.Timeout = *timeout
|
||||||
|
|
||||||
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
|
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
|
||||||
|
|
|
@ -198,17 +198,12 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
|
||||||
|
|
||||||
// MakeHTMLAPI adds Span metrics to the HTML Handler function
|
// MakeHTMLAPI adds Span metrics to the HTML Handler function
|
||||||
// This is used to serve HTML alongside JSON error messages
|
// This is used to serve HTML alongside JSON error messages
|
||||||
func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
|
func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request)) http.Handler {
|
||||||
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
||||||
span := opentracing.StartSpan(metricsName)
|
span := opentracing.StartSpan(metricsName)
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
|
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
|
||||||
if err := f(w, req); err != nil {
|
f(w, req)
|
||||||
h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
|
|
||||||
return *err
|
|
||||||
}))
|
|
||||||
h.ServeHTTP(w, req)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !enableMetrics {
|
if !enableMetrics {
|
||||||
|
|
|
@ -15,30 +15,96 @@
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
const (
|
||||||
|
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
||||||
|
|
||||||
const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
||||||
|
maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||||
|
)
|
||||||
|
|
||||||
// ValidatePassword returns an error response if the password is invalid
|
var (
|
||||||
func ValidatePassword(password string) *util.JSONResponse {
|
ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength)
|
||||||
|
ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength)
|
||||||
|
ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength)
|
||||||
|
ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='")
|
||||||
|
ErrUsernameUnderscore = errors.New("username cannot start with a '_'")
|
||||||
|
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidatePassword returns an error if the password is invalid
|
||||||
|
func ValidatePassword(password string) error {
|
||||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||||
if len(password) > maxPasswordLength {
|
if len(password) > maxPasswordLength {
|
||||||
return &util.JSONResponse{
|
return ErrPasswordTooLong
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)),
|
|
||||||
}
|
|
||||||
} else if len(password) > 0 && len(password) < minPasswordLength {
|
} else if len(password) > 0 && len(password) < minPasswordLength {
|
||||||
|
return ErrPasswordWeak
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasswordResponse returns a util.JSONResponse for a given error, if any.
|
||||||
|
func PasswordResponse(err error) *util.JSONResponse {
|
||||||
|
switch err {
|
||||||
|
case ErrPasswordWeak:
|
||||||
return &util.JSONResponse{
|
return &util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
|
JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()),
|
||||||
|
}
|
||||||
|
case ErrPasswordTooLong:
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateUsername returns an error if the username is invalid
|
||||||
|
func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error {
|
||||||
|
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||||
|
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||||
|
return ErrUsernameTooLong
|
||||||
|
} else if !validUsernameRegex.MatchString(localpart) {
|
||||||
|
return ErrUsernameInvalid
|
||||||
|
} else if localpart[0] == '_' { // Regex checks its not a zero length string
|
||||||
|
return ErrUsernameUnderscore
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsernameResponse returns a util.JSONResponse for the given error, if any.
|
||||||
|
func UsernameResponse(err error) *util.JSONResponse {
|
||||||
|
switch err {
|
||||||
|
case ErrUsernameTooLong:
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(err.Error()),
|
||||||
|
}
|
||||||
|
case ErrUsernameInvalid, ErrUsernameUnderscore:
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service
|
||||||
|
func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error {
|
||||||
|
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||||
|
return ErrUsernameTooLong
|
||||||
|
} else if !validUsernameRegex.MatchString(localpart) {
|
||||||
|
return ErrUsernameInvalid
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
170
internal/validate_test.go
Normal file
170
internal/validate_test.go
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_validatePassword(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
password string
|
||||||
|
wantError error
|
||||||
|
wantJSON *util.JSONResponse
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "password too short",
|
||||||
|
password: "shortpw",
|
||||||
|
wantError: ErrPasswordWeak,
|
||||||
|
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "password too long",
|
||||||
|
password: strings.Repeat("a", maxPasswordLength+1),
|
||||||
|
wantError: ErrPasswordTooLong,
|
||||||
|
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "password OK",
|
||||||
|
password: util.RandomString(10),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotErr := ValidatePassword(tt.password)
|
||||||
|
if !reflect.DeepEqual(gotErr, tt.wantError) {
|
||||||
|
t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) {
|
||||||
|
t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validateUsername(t *testing.T) {
|
||||||
|
tooLongUsername := strings.Repeat("a", maxUsernameLength)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
localpart string
|
||||||
|
domain gomatrixserverlib.ServerName
|
||||||
|
wantErr error
|
||||||
|
wantJSON *util.JSONResponse
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty username",
|
||||||
|
localpart: "",
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameInvalid,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid username",
|
||||||
|
localpart: "INVALIDUSERNAME",
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameInvalid,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username too long",
|
||||||
|
localpart: tooLongUsername,
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameTooLong,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "localpart starting with an underscore",
|
||||||
|
localpart: "_notvalid",
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameUnderscore,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid username",
|
||||||
|
localpart: "valid",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex username",
|
||||||
|
localpart: "f00_bar-baz.=40/",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rejects emoji username 💥",
|
||||||
|
localpart: "💥",
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameInvalid,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters are allowed",
|
||||||
|
localpart: "/dev/null",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters are allowed 2",
|
||||||
|
localpart: "i_am_allowed=1",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not all special characters are allowed",
|
||||||
|
localpart: "notallowed#", // contains #
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameInvalid,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username containing numbers",
|
||||||
|
localpart: "hello1337",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotErr := ValidateUsername(tt.localpart, tt.domain)
|
||||||
|
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||||
|
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
|
||||||
|
}
|
||||||
|
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
|
||||||
|
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Application services are allowed usernames starting with an underscore
|
||||||
|
if tt.wantErr == ErrUsernameUnderscore {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain)
|
||||||
|
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||||
|
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
|
||||||
|
}
|
||||||
|
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
|
||||||
|
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -29,7 +29,7 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
yaml "gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
|
||||||
jaegerconfig "github.com/uber/jaeger-client-go/config"
|
jaegerconfig "github.com/uber/jaeger-client-go/config"
|
||||||
jaegermetrics "github.com/uber/jaeger-lib/metrics"
|
jaegermetrics "github.com/uber/jaeger-lib/metrics"
|
||||||
|
@ -314,11 +314,13 @@ func (config *Dendrite) Derive() error {
|
||||||
|
|
||||||
if config.ClientAPI.RecaptchaEnabled {
|
if config.ClientAPI.RecaptchaEnabled {
|
||||||
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
|
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
|
||||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
config.Derived.Registration.Flows = []authtypes.Flow{
|
||||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}})
|
{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}},
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
config.Derived.Registration.Flows = []authtypes.Flow{
|
||||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}})
|
{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load application service configuration files
|
// Load application service configuration files
|
||||||
|
|
|
@ -78,9 +78,6 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
c.TURN.Verify(configErrs)
|
c.TURN.Verify(configErrs)
|
||||||
c.RateLimiting.Verify(configErrs)
|
c.RateLimiting.Verify(configErrs)
|
||||||
if c.RecaptchaEnabled {
|
if c.RecaptchaEnabled {
|
||||||
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
|
|
||||||
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
|
|
||||||
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
|
|
||||||
if c.RecaptchaSiteVerifyAPI == "" {
|
if c.RecaptchaSiteVerifyAPI == "" {
|
||||||
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
|
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
|
||||||
}
|
}
|
||||||
|
@ -93,6 +90,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
if c.RecaptchaSitekeyClass == "" {
|
if c.RecaptchaSitekeyClass == "" {
|
||||||
c.RecaptchaSitekeyClass = "g-recaptcha-response"
|
c.RecaptchaSitekeyClass = "g-recaptcha-response"
|
||||||
}
|
}
|
||||||
|
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
|
||||||
|
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
|
||||||
|
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
|
||||||
|
checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass)
|
||||||
}
|
}
|
||||||
// Ensure there is any spam counter measure when enabling registration
|
// Ensure there is any spam counter measure when enabling registration
|
||||||
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled {
|
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled {
|
||||||
|
|
Loading…
Reference in a new issue