Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/helm

This commit is contained in:
Till Faelligen 2023-01-05 15:32:05 +01:00
commit c6994efe70
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
56 changed files with 1709 additions and 420 deletions

View file

@ -331,8 +331,7 @@ jobs:
postgres: postgres postgres: postgres
api: full-http api: full-http
container: container:
# Temporary for debugging to see if this image is working better. image: matrixdotorg/sytest-dendrite
image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73
volumes: volumes:
- ${{ github.workspace }}:/src - ${{ github.workspace }}:/src
- /root/.cache/go-build:/github/home/.cache/go-build - /root/.cache/go-build:/github/home/.cache/go-build

View file

@ -63,30 +63,3 @@ WORKDIR /etc/dendrite
ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] ENTRYPOINT ["/usr/bin/dendrite-monolith-server"]
EXPOSE 8008 8448 EXPOSE 8008 8448
#
# Builds the Complement image, used for integration tests
#
FROM base AS complement
LABEL org.opencontainers.image.title="Dendrite (Complement)"
RUN apk add --no-cache sqlite openssl ca-certificates
COPY --from=build /out/generate-config /usr/bin/generate-config
COPY --from=build /out/generate-keys /usr/bin/generate-keys
COPY --from=build /out/dendrite-monolith-server /usr/bin/dendrite-monolith-server
WORKDIR /dendrite
RUN /usr/bin/generate-keys --private-key matrix_key.pem && \
mkdir /ca && \
openssl genrsa -out /ca/ca.key 2048 && \
openssl req -new -x509 -key /ca/ca.key -days 3650 -subj "/C=GB/ST=London/O=matrix.org/CN=Complement CA" -out /ca/ca.crt
ENV SERVER_NAME=localhost
ENV API=0
EXPOSE 8008 8448
# At runtime, generate TLS cert based on the CA now mounted at /ca
# At runtime, replace the SERVER_NAME with what we are told
CMD /usr/bin/generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \
/usr/bin/generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
/usr/bin/dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}

View file

@ -16,13 +16,16 @@ RUN --mount=target=. \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \
CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \
cp build/scripts/complement-cmd.sh /complement-cmd.sh
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem
ENV SERVER_NAME=localhost ENV SERVER_NAME=localhost
ENV API=0 ENV API=0
ENV COVER=0
EXPOSE 8008 8448 EXPOSE 8008 8448
# At runtime, generate TLS cert based on the CA now mounted at /ca # At runtime, generate TLS cert based on the CA now mounted at /ca
@ -30,4 +33,4 @@ EXPOSE 8008 8448
CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} exec /complement-cmd.sh

View file

@ -12,18 +12,20 @@ FROM golang:1.18-stretch
RUN apt-get update && apt-get install -y sqlite3 RUN apt-get update && apt-get install -y sqlite3
ENV SERVER_NAME=localhost ENV SERVER_NAME=localhost
ENV COVER=0
EXPOSE 8008 8448 EXPOSE 8008 8448
WORKDIR /runtime WORKDIR /runtime
# This script compiles Dendrite for us. # This script compiles Dendrite for us.
RUN echo '\ RUN echo '\
#!/bin/bash -eux \n\ #!/bin/bash -eux \n\
if test -f "/runtime/dendrite-monolith-server"; then \n\ if test -f "/runtime/dendrite-monolith-server" && test -f "/runtime/dendrite-monolith-server-cover"; then \n\
echo "Skipping compilation; binaries exist" \n\ echo "Skipping compilation; binaries exist" \n\
exit 0 \n\ exit 0 \n\
fi \n\ fi \n\
cd /dendrite \n\ cd /dendrite \n\
go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\ go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\
go test -c -cover -covermode=atomic -o /runtime/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." /dendrite/cmd/dendrite-monolith-server \n\
' > compile.sh && chmod +x compile.sh ' > compile.sh && chmod +x compile.sh
# This script runs Dendrite for us. Must be run in the /runtime directory. # This script runs Dendrite for us. Must be run in the /runtime directory.
@ -33,6 +35,7 @@ RUN echo '\
./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\
./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\
[ ${COVER} -eq 1 ] && exec ./dendrite-monolith-server-cover --test.coverprofile=integrationcover.log --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\
' > run.sh && chmod +x run.sh ' > run.sh && chmod +x run.sh

View file

@ -34,13 +34,16 @@ RUN --mount=target=. \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \
CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \
cp build/scripts/complement-cmd.sh /complement-cmd.sh
WORKDIR /dendrite WORKDIR /dendrite
RUN ./generate-keys --private-key matrix_key.pem RUN ./generate-keys --private-key matrix_key.pem
ENV SERVER_NAME=localhost ENV SERVER_NAME=localhost
ENV API=0 ENV API=0
ENV COVER=0
EXPOSE 8008 8448 EXPOSE 8008 8448
@ -51,4 +54,4 @@ CMD /build/run_postgres.sh && ./generate-keys --keysize 1024 --server $SERVER_NA
# Bump max_open_conns up here in the global database config # Bump max_open_conns up here in the global database config
sed -i 's/max_open_conns:.*$/max_open_conns: 1990/g' dendrite.yaml && \ sed -i 's/max_open_conns:.*$/max_open_conns: 1990/g' dendrite.yaml && \
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} exec /complement-cmd.sh

22
build/scripts/complement-cmd.sh Executable file
View file

@ -0,0 +1,22 @@
#!/bin/bash -e
# This script is intended to be used inside a docker container for Complement
if [[ "${COVER}" -eq 1 ]]; then
echo "Running with coverage"
exec /dendrite/dendrite-monolith-server-cover \
--really-enable-open-registration \
--tls-cert server.crt \
--tls-key server.key \
--config dendrite.yaml \
-api=${API:-0} \
--test.coverprofile=integrationcover.log
else
echo "Not running with coverage"
exec /dendrite/dendrite-monolith-server \
--really-enable-open-registration \
--tls-cert server.crt \
--tls-key server.key \
--config dendrite.yaml \
-api=${API:-0}
fi

View file

@ -15,6 +15,8 @@
package clientapi package clientapi
import ( import (
"github.com/matrix-org/gomatrixserverlib"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api" appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
@ -26,7 +28,6 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.

View file

@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
request := struct { request := struct {
Password string `json:"password"` Password string `json:"password"`
}{} }{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()), JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
} }
} }
if resErr := internal.ValidatePassword(request.Password); resErr != nil { if err = internal.ValidatePassword(request.Password); err != nil {
return *resErr return *internal.PasswordResponse(err)
} }
updateReq := &userapi.PerformPasswordUpdateRequest{ updateReq := &userapi.PerformPasswordUpdateRequest{

View file

@ -15,11 +15,11 @@
package routing package routing
import ( import (
"fmt"
"html/template" "html/template"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
@ -101,14 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
func AuthFallback( func AuthFallback(
w http.ResponseWriter, req *http.Request, authType string, w http.ResponseWriter, req *http.Request, authType string,
cfg *config.ClientAPI, cfg *config.ClientAPI,
) *util.JSONResponse { ) {
sessionID := req.URL.Query().Get("session") // We currently only support "m.login.recaptcha", so fail early if that's not requested
if authType == authtypes.LoginTypeRecaptcha {
if !cfg.RecaptchaEnabled {
writeHTTPMessage(w, req,
"Recaptcha login is disabled on this Homeserver",
http.StatusBadRequest,
)
return
}
} else {
writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented)
return
}
sessionID := req.URL.Query().Get("session")
if sessionID == "" { if sessionID == "" {
return writeHTTPMessage(w, req, writeHTTPMessage(w, req,
"Session ID not provided", "Session ID not provided",
http.StatusBadRequest, http.StatusBadRequest,
) )
return
} }
serveRecaptcha := func() { serveRecaptcha := func() {
@ -130,70 +144,44 @@ func AuthFallback(
if req.Method == http.MethodGet { if req.Method == http.MethodGet {
// Handle Recaptcha // Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha { serveRecaptcha()
if err := checkRecaptchaEnabled(cfg, w, req); err != nil { return
return err
}
serveRecaptcha()
return nil
}
return &util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown auth stage type"),
}
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
// Handle Recaptcha // Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha { clientIP := req.RemoteAddr
if err := checkRecaptchaEnabled(cfg, w, req); err != nil { err := req.ParseForm()
return err if err != nil {
} util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
w.WriteHeader(http.StatusBadRequest)
clientIP := req.RemoteAddr serveRecaptcha()
err := req.ParseForm() return
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
res := jsonerror.InternalServerError()
return &res
}
response := req.Form.Get(cfg.RecaptchaFormField)
if err := validateRecaptcha(cfg, response, clientIP); err != nil {
util.GetLogger(req.Context()).Error(err)
return err
}
// Success. Add recaptcha as a completed login flow
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
serveSuccess()
return nil
} }
return &util.JSONResponse{ response := req.Form.Get(cfg.RecaptchaFormField)
Code: http.StatusNotFound, err = validateRecaptcha(cfg, response, clientIP)
JSON: jsonerror.NotFound("Unknown auth stage type"), switch err {
case ErrMissingResponse:
w.WriteHeader(http.StatusBadRequest)
serveRecaptcha() // serve the initial page again, instead of nothing
return
case ErrInvalidCaptcha:
w.WriteHeader(http.StatusUnauthorized)
serveRecaptcha()
return
case nil:
default: // something else failed
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
serveRecaptcha()
return
} }
}
return &util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),
}
}
// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver. // Success. Add recaptcha as a completed login flow
func checkRecaptchaEnabled( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
cfg *config.ClientAPI,
w http.ResponseWriter, serveSuccess()
req *http.Request, return
) *util.JSONResponse {
if !cfg.RecaptchaEnabled {
return writeHTTPMessage(w, req,
"Recaptcha login is disabled on this Homeserver",
http.StatusBadRequest,
)
} }
return nil writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed)
} }
// writeHTTPMessage writes the given header and message to the HTTP response writer. // writeHTTPMessage writes the given header and message to the HTTP response writer.
@ -201,13 +189,10 @@ func checkRecaptchaEnabled(
func writeHTTPMessage( func writeHTTPMessage(
w http.ResponseWriter, req *http.Request, w http.ResponseWriter, req *http.Request,
message string, header int, message string, header int,
) *util.JSONResponse { ) {
w.WriteHeader(header) w.WriteHeader(header)
_, err := w.Write([]byte(message)) _, err := w.Write([]byte(message))
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("w.Write failed") util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
res := jsonerror.InternalServerError()
return &res
} }
return nil
} }

View file

@ -0,0 +1,149 @@
package routing
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test/testrig"
)
func Test_AuthFallback(t *testing.T) {
base, _, _ := testrig.Base(nil)
defer base.Close()
for _, useHCaptcha := range []bool{false, true} {
for _, recaptchaEnabled := range []bool{false, true} {
for _, wantErr := range []bool{false, true} {
t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) {
// Set the defaults for each test
base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, Monolithic: true})
base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled
base.Cfg.ClientAPI.RecaptchaPublicKey = "pub"
base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv"
if useHCaptcha {
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify"
base.Cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js"
base.Cfg.ClientAPI.RecaptchaFormField = "h-captcha-response"
base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha"
}
cfgErrs := &config.ConfigErrors{}
base.Cfg.ClientAPI.Verify(cfgErrs, true)
if len(*cfgErrs) > 0 {
t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error())
}
req := httptest.NewRequest(http.MethodGet, "/?session=1337", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if !recaptchaEnabled {
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest)
}
if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" {
t.Fatalf("unexpected response body: %s", rec.Body.String())
}
} else {
if !strings.Contains(rec.Body.String(), base.Cfg.ClientAPI.RecaptchaSitekeyClass) {
t.Fatalf("body does not contain %s: %s", base.Cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String())
}
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if wantErr {
_, _ = w.Write([]byte(`{"success":false}`))
return
}
_, _ = w.Write([]byte(`{"success":true}`))
}))
defer srv.Close() // nolint: errcheck
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL
// check the result after sending the captcha
req = httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
req.Form = url.Values{}
req.Form.Add(base.Cfg.ClientAPI.RecaptchaFormField, "someRandomValue")
rec = httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if recaptchaEnabled {
if !wantErr {
if rec.Code != http.StatusOK {
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusOK)
}
if rec.Body.String() != successTemplate {
t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), successTemplate)
}
} else {
if rec.Code != http.StatusUnauthorized {
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusUnauthorized)
}
wantString := "Authentication"
if !strings.Contains(rec.Body.String(), wantString) {
t.Fatalf("expected response to contain '%s', but didn't: %s", wantString, rec.Body.String())
}
}
} else {
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest)
}
if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" {
t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), "successTemplate")
}
}
})
}
}
}
t.Run("unknown fallbacks are handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, "DoesNotExist", &base.Cfg.ClientAPI)
if rec.Code != http.StatusNotImplemented {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusNotImplemented)
}
})
t.Run("unknown methods are handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodDelete, "/?session=1337", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusMethodNotAllowed {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusMethodNotAllowed)
}
})
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
}
})
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
}
})
t.Run("missing 'response' is handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
}
})
}

View file

@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias(
joinReq := roomserverAPI.PerformJoinRequest{ joinReq := roomserverAPI.PerformJoinRequest{
RoomIDOrAlias: roomIDOrAlias, RoomIDOrAlias: roomIDOrAlias,
UserID: device.UserID, UserID: device.UserID,
IsGuest: device.AccountType == api.AccountTypeGuest,
Content: map[string]interface{}{}, Content: map[string]interface{}{},
} }
joinRes := roomserverAPI.PerformJoinResponse{} joinRes := roomserverAPI.PerformJoinResponse{}
@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias(
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
done <- jsonerror.InternalAPIError(req.Context(), err) done <- jsonerror.InternalAPIError(req.Context(), err)
} else if joinRes.Error != nil { } else if joinRes.Error != nil {
done <- joinRes.Error.JSONResponse() if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest {
done <- util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
}
} else {
done <- joinRes.Error.JSONResponse()
}
} else { } else {
done <- util.JSONResponse{ done <- util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -0,0 +1,158 @@
package routing
import (
"bytes"
"context"
"net/http"
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
func TestJoinRoomByIDOrAlias(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest))
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc
// Create the users in the userapi
for _, u := range []*test.User{alice, bob, charlie} {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
userRes := &uapi.PerformAccountCreationResponse{}
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: localpart,
ServerName: serverName,
Password: "someRandomPassword",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
aliceDev := &uapi.Device{UserID: alice.ID}
bobDev := &uapi.Device{UserID: bob.ID}
charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest}
// create a room with disabled guest access and invite Bob
resp := createRoom(ctx, createRoomRequest{
Name: "testing",
IsDirect: true,
Topic: "testing",
Visibility: "public",
Preset: presetPublicChat,
RoomAliasName: "alias",
Invite: []string{bob.ID},
GuestCanJoin: false,
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
crResp, ok := resp.JSON.(createRoomResponse)
if !ok {
t.Fatalf("response is not a createRoomResponse: %+v", resp)
}
// create a room with guest access enabled and invite Charlie
resp = createRoom(ctx, createRoomRequest{
Name: "testing",
IsDirect: true,
Topic: "testing",
Visibility: "public",
Preset: presetPublicChat,
Invite: []string{charlie.ID},
GuestCanJoin: true,
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
if !ok {
t.Fatalf("response is not a createRoomResponse: %+v", resp)
}
// Dummy request
body := &bytes.Buffer{}
req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body)
if err != nil {
t.Fatal(err)
}
testCases := []struct {
name string
device *uapi.Device
roomID string
wantHTTP200 bool
}{
{
name: "User can join successfully by alias",
device: bobDev,
roomID: crResp.RoomAlias,
wantHTTP200: true,
},
{
name: "User can join successfully by roomID",
device: bobDev,
roomID: crResp.RoomID,
wantHTTP200: true,
},
{
name: "join is forbidden if user is guest",
device: charlieDev,
roomID: crResp.RoomID,
},
{
name: "room does not exist",
device: aliceDev,
roomID: "!doesnotexist:test",
},
{
name: "user from different server",
device: &uapi.Device{UserID: "@wrong:server"},
roomID: crResp.RoomAlias,
},
{
name: "user doesn't exist locally",
device: &uapi.Device{UserID: "@doesnotexist:test"},
roomID: crResp.RoomAlias,
},
{
name: "invalid room ID",
device: aliceDev,
roomID: "invalidRoomID",
},
{
name: "roomAlias does not exist",
device: aliceDev,
roomID: "#doesnotexist:test",
},
{
name: "room with guest_access event",
device: charlieDev,
roomID: crRespWithGuestAccess.RoomID,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID)
if tc.wantHTTP200 && !joinResp.Is2xx() {
t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp)
}
})
}
})
}

View file

@ -23,15 +23,13 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type loginResponse struct { type loginResponse struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
HomeServer gomatrixserverlib.ServerName `json:"home_server"` DeviceID string `json:"device_id"`
DeviceID string `json:"device_id"`
} }
type flows struct { type flows struct {
@ -116,7 +114,6 @@ func completeAuth(
JSON: loginResponse{ JSON: loginResponse{
UserID: performRes.Device.UserID, UserID: performRes.Device.UserID,
AccessToken: performRes.Device.AccessToken, AccessToken: performRes.Device.AccessToken,
HomeServer: serverName,
DeviceID: performRes.Device.ID, DeviceID: performRes.Device.ID,
}, },
} }

View file

@ -82,8 +82,8 @@ func Password(
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength. // Check the new password strength.
if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil { if err := internal.ValidatePassword(r.NewPassword); err != nil {
return *resErr return *internal.PasswordResponse(err)
} }
// Get the local part. // Get the local part.

View file

@ -18,12 +18,12 @@ package routing
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"regexp"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -60,10 +60,7 @@ var (
) )
) )
const ( const sessionIDLength = 24
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
sessionIDLength = 24
)
// sessionsDict keeps track of completed auth stages for each session. // sessionsDict keeps track of completed auth stages for each session.
// It shouldn't be passed by value because it contains a mutex. // It shouldn't be passed by value because it contains a mutex.
@ -198,8 +195,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
} }
var ( var (
sessions = newSessionsDict() sessions = newSessionsDict()
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
) )
// registerRequest represents the submitted registration request. // registerRequest represents the submitted registration request.
@ -262,10 +258,9 @@ func newUserInteractiveResponse(
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
type registerResponse struct { type registerResponse struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
AccessToken string `json:"access_token,omitempty"` AccessToken string `json:"access_token,omitempty"`
HomeServer gomatrixserverlib.ServerName `json:"home_server"` DeviceID string `json:"device_id,omitempty"`
DeviceID string `json:"device_id,omitempty"`
} }
// recaptchaResponse represents the HTTP response from a Google Recaptcha server // recaptchaResponse represents the HTTP response from a Google Recaptcha server
@ -276,66 +271,28 @@ type recaptchaResponse struct {
ErrorCodes []int `json:"error-codes"` ErrorCodes []int `json:"error-codes"`
} }
// validateUsername returns an error response if the username is invalid var (
func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { ErrInvalidCaptcha = errors.New("invalid captcha response")
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 ErrMissingResponse = errors.New("captcha response is required")
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { ErrCaptchaDisabled = errors.New("captcha registration is disabled")
return &util.JSONResponse{ )
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
}
} else if !validUsernameRegex.MatchString(localpart) {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
}
} else if localpart[0] == '_' { // Regex checks its not a zero length string
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
}
}
return nil
}
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service
func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
}
} else if !validUsernameRegex.MatchString(localpart) {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
}
}
return nil
}
// validateRecaptcha returns an error response if the captcha response is invalid // validateRecaptcha returns an error response if the captcha response is invalid
func validateRecaptcha( func validateRecaptcha(
cfg *config.ClientAPI, cfg *config.ClientAPI,
response string, response string,
clientip string, clientip string,
) *util.JSONResponse { ) error {
ip, _, _ := net.SplitHostPort(clientip) ip, _, _ := net.SplitHostPort(clientip)
if !cfg.RecaptchaEnabled { if !cfg.RecaptchaEnabled {
return &util.JSONResponse{ return ErrCaptchaDisabled
Code: http.StatusConflict,
JSON: jsonerror.Unknown("Captcha registration is disabled"),
}
} }
if response == "" { if response == "" {
return &util.JSONResponse{ return ErrMissingResponse
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Captcha response is required"),
}
} }
// Make a POST request to Google's API to check the captcha response // Make a POST request to the captcha provider API to check the captcha response
resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI, resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI,
url.Values{ url.Values{
"secret": {cfg.RecaptchaPrivateKey}, "secret": {cfg.RecaptchaPrivateKey},
@ -345,10 +302,7 @@ func validateRecaptcha(
) )
if err != nil { if err != nil {
return &util.JSONResponse{ return err
Code: http.StatusInternalServerError,
JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"),
}
} }
// Close the request once we're finishing reading from it // Close the request once we're finishing reading from it
@ -358,25 +312,16 @@ func validateRecaptcha(
var r recaptchaResponse var r recaptchaResponse
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return &util.JSONResponse{ return err
Code: http.StatusGatewayTimeout,
JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()),
}
} }
err = json.Unmarshal(body, &r) err = json.Unmarshal(body, &r)
if err != nil { if err != nil {
return &util.JSONResponse{ return err
Code: http.StatusInternalServerError,
JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()),
}
} }
// Check that we received a "success" // Check that we received a "success"
if !r.Success { if !r.Success {
return &util.JSONResponse{ return ErrInvalidCaptcha
Code: http.StatusUnauthorized,
JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."),
}
} }
return nil return nil
} }
@ -508,8 +453,8 @@ func validateApplicationService(
} }
// Check username application service is trying to register is valid // Check username application service is trying to register is valid
if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
return "", err return "", internal.UsernameResponse(err)
} }
// No errors, registration valid // No errors, registration valid
@ -564,15 +509,12 @@ func Register(
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
return *resErr return *resErr
} }
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
r.Username, r.ServerName = l, d
}
if req.URL.Query().Get("kind") == "guest" { if req.URL.Query().Get("kind") == "guest" {
return handleGuestRegistration(req, r, cfg, userAPI) return handleGuestRegistration(req, r, cfg, userAPI)
} }
// Don't allow numeric usernames less than MAX_INT64. // Don't allow numeric usernames less than MAX_INT64.
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil { if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
@ -584,7 +526,7 @@ func Register(
ServerName: r.ServerName, ServerName: r.ServerName,
} }
nres := &userapi.QueryNumericLocalpartResponse{} nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -601,8 +543,8 @@ func Register(
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
// Spec-compliant case (the access_token is specified and the login type // Spec-compliant case (the access_token is specified and the login type
// is correctly set, so it's an appservice registration) // is correctly set, so it's an appservice registration)
if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil {
return *resErr return *internal.UsernameResponse(err)
} }
case accessTokenErr == nil: case accessTokenErr == nil:
// Non-spec-compliant case (the access_token is specified but the login // Non-spec-compliant case (the access_token is specified but the login
@ -614,12 +556,12 @@ func Register(
default: default:
// Spec-compliant case (neither the access_token nor the login type are // Spec-compliant case (neither the access_token nor the login type are
// specified, so it's a normal user registration) // specified, so it's a normal user registration)
if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil {
return *resErr return *internal.UsernameResponse(err)
} }
} }
if resErr := internal.ValidatePassword(r.Password); resErr != nil { if err = internal.ValidatePassword(r.Password); err != nil {
return *resErr return *internal.PasswordResponse(err)
} }
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
@ -697,7 +639,6 @@ func handleGuestRegistration(
JSON: registerResponse{ JSON: registerResponse{
UserID: devRes.Device.UserID, UserID: devRes.Device.UserID,
AccessToken: devRes.Device.AccessToken, AccessToken: devRes.Device.AccessToken,
HomeServer: res.Account.ServerName,
DeviceID: devRes.Device.ID, DeviceID: devRes.Device.ID,
}, },
} }
@ -761,9 +702,18 @@ func handleRegistrationFlow(
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeRecaptcha: case authtypes.LoginTypeRecaptcha:
// Check given captcha response // Check given captcha response
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
if resErr != nil { switch err {
return *resErr case ErrCaptchaDisabled:
return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())}
case ErrMissingResponse:
return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())}
case ErrInvalidCaptcha:
return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())}
case nil:
default:
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}
} }
// Add Recaptcha to the list of completed registration stages // Add Recaptcha to the list of completed registration stages
@ -924,8 +874,7 @@ func completeRegistration(
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: registerResponse{ JSON: registerResponse{
UserID: userutil.MakeUserID(username, accRes.Account.ServerName), UserID: userutil.MakeUserID(username, accRes.Account.ServerName),
HomeServer: accRes.Account.ServerName,
}, },
} }
} }
@ -958,7 +907,6 @@ func completeRegistration(
result := registerResponse{ result := registerResponse{
UserID: devRes.Device.UserID, UserID: devRes.Device.UserID,
AccessToken: devRes.Device.AccessToken, AccessToken: devRes.Device.AccessToken,
HomeServer: accRes.Account.ServerName,
DeviceID: devRes.Device.ID, DeviceID: devRes.Device.ID,
} }
sessions.addCompletedRegistration(sessionID, result) sessions.addCompletedRegistration(sessionID, result)
@ -1054,8 +1002,8 @@ func RegisterAvailable(
} }
} }
if err := validateUsername(username, domain); err != nil { if err := internal.ValidateUsername(username, domain); err != nil {
return *err return *internal.UsernameResponse(err)
} }
// Check if this username is reserved by an application service // Check if this username is reserved by an application service
@ -1117,11 +1065,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
// downcase capitals // downcase capitals
ssrr.User = strings.ToLower(ssrr.User) ssrr.User = strings.ToLower(ssrr.User)
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil {
return *resErr return *internal.UsernameResponse(err)
} }
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil { if err = internal.ValidatePassword(ssrr.Password); err != nil {
return *resErr return *internal.PasswordResponse(err)
} }
deviceID := "shared_secret_registration" deviceID := "shared_secret_registration"

View file

@ -15,12 +15,27 @@
package routing package routing
import ( import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"regexp" "regexp"
"strings"
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/util"
) )
var ( var (
@ -264,3 +279,294 @@ func TestSessionCleanUp(t *testing.T) {
} }
}) })
} }
func Test_register(t *testing.T) {
testCases := []struct {
name string
kind string
password string
username string
loginType string
forceEmpty bool
registrationDisabled bool
guestsDisabled bool
enableRecaptcha bool
captchaBody string
wantResponse util.JSONResponse
}{
{
name: "disallow guests",
kind: "guest",
guestsDisabled: true,
wantResponse: util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`),
},
},
{
name: "allow guests",
kind: "guest",
},
{
name: "unknown login type",
loginType: "im.not.known",
wantResponse: util.JSONResponse{
Code: http.StatusNotImplemented,
JSON: jsonerror.Unknown("unknown/unimplemented auth type"),
},
},
{
name: "disabled registration",
registrationDisabled: true,
wantResponse: util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(`Registration is disabled on "test"`),
},
},
{
name: "successful registration, numeric ID",
username: "",
password: "someRandomPassword",
forceEmpty: true,
},
{
name: "successful registration",
username: "success",
},
{
name: "failing registration - user already exists",
username: "success",
wantResponse: util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UserInUse("Desired user ID is already taken."),
},
},
{
name: "successful registration uppercase username",
username: "LOWERCASED", // this is going to be lower-cased
},
{
name: "invalid username",
username: "#totalyNotValid",
wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid),
},
{
name: "numeric username is forbidden",
username: "1337",
wantResponse: util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
},
},
{
name: "disabled recaptcha login",
loginType: authtypes.LoginTypeRecaptcha,
wantResponse: util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()),
},
},
{
name: "enabled recaptcha, no response defined",
enableRecaptcha: true,
loginType: authtypes.LoginTypeRecaptcha,
wantResponse: util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(ErrMissingResponse.Error()),
},
},
{
name: "invalid captcha response",
enableRecaptcha: true,
loginType: authtypes.LoginTypeRecaptcha,
captchaBody: `notvalid`,
wantResponse: util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()),
},
},
{
name: "valid captcha response",
enableRecaptcha: true,
loginType: authtypes.LoginTypeRecaptcha,
captchaBody: `success`,
},
{
name: "captcha invalid from remote",
enableRecaptcha: true,
loginType: authtypes.LoginTypeRecaptcha,
captchaBody: `i should fail for other reasons`,
wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()},
},
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.enableRecaptcha {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatal(err)
}
response := r.Form.Get("response")
// Respond with valid JSON or no JSON at all to test happy/error cases
switch response {
case "success":
json.NewEncoder(w).Encode(recaptchaResponse{Success: true})
case "notvalid":
json.NewEncoder(w).Encode(recaptchaResponse{Success: false})
default:
}
}))
defer srv.Close()
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL
}
if err := base.Cfg.Derive(); err != nil {
t.Fatalf("failed to derive config: %s", err)
}
base.Cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha
base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled
base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled
if tc.kind == "" {
tc.kind = "user"
}
if tc.password == "" && !tc.forceEmpty {
tc.password = "someRandomPassword"
}
if tc.username == "" && !tc.forceEmpty {
tc.username = "valid"
}
if tc.loginType == "" {
tc.loginType = "m.login.dummy"
}
reg := registerRequest{
Password: tc.password,
Username: tc.username,
}
body := &bytes.Buffer{}
err := json.NewEncoder(body).Encode(reg)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body)
resp := Register(req, userAPI, &base.Cfg.ClientAPI)
t.Logf("Resp: %+v", resp)
// The first request should return a userInteractiveResponse
switch r := resp.JSON.(type) {
case userInteractiveResponse:
// Check that the flows are the ones we configured
if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) {
t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows)
}
case *jsonerror.MatrixError:
if !reflect.DeepEqual(tc.wantResponse, resp) {
t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse)
}
return
case registerResponse:
// this should only be possible on guest user registration, never for normal users
if tc.kind != "guest" {
t.Fatalf("got register response on first request: %+v", r)
}
// assert we've got a UserID, AccessToken and DeviceID
if r.UserID == "" {
t.Fatalf("missing userID in response")
}
if r.AccessToken == "" {
t.Fatalf("missing accessToken in response")
}
if r.DeviceID == "" {
t.Fatalf("missing deviceID in response")
}
return
default:
t.Logf("Got response: %T", resp.JSON)
}
// If we reached this, we should have received a UIA response
uia, ok := resp.JSON.(userInteractiveResponse)
if !ok {
t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON)
}
t.Logf("%+v", uia)
// Register the user
reg.Auth = authDict{
Type: authtypes.LoginType(tc.loginType),
Session: uia.Session,
}
if tc.captchaBody != "" {
reg.Auth.Response = tc.captchaBody
}
dummy := "dummy"
reg.DeviceID = &dummy
reg.InitialDisplayName = &dummy
reg.Type = authtypes.LoginType(tc.loginType)
err = json.NewEncoder(body).Encode(reg)
if err != nil {
t.Fatal(err)
}
req = httptest.NewRequest(http.MethodPost, "/", body)
resp = Register(req, userAPI, &base.Cfg.ClientAPI)
switch resp.JSON.(type) {
case *jsonerror.MatrixError:
if !reflect.DeepEqual(tc.wantResponse, resp) {
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
}
return
case util.JSONResponse:
if !reflect.DeepEqual(tc.wantResponse, resp) {
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
}
return
}
rr, ok := resp.JSON.(registerResponse)
if !ok {
t.Fatalf("expected a registerresponse, got %T", resp.JSON)
}
// validate the response
if tc.forceEmpty {
// when not supplying a username, one will be generated. Given this _SHOULD_ be
// the second user, set the username accordingly
reg.Username = "2"
}
wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test"))
if wantUserID != rr.UserID {
t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID)
}
if rr.DeviceID != *reg.DeviceID {
t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID)
}
if rr.AccessToken == "" {
t.Fatalf("missing accessToken in response")
}
})
}
})
}

View file

@ -639,9 +639,9 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
v3mux.Handle("/auth/{authType}/fallback/web", v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req) vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg) AuthFallback(w, req, vars["authType"], cfg)
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)

View file

@ -25,10 +25,10 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"regexp"
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/internal"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -58,15 +58,14 @@ Arguments:
` `
var ( var (
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
password = flag.String("password", "", "The password to associate with the account") password = flag.String("password", "", "The password to associate with the account")
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
isAdmin = flag.Bool("admin", false, "Create an admin account") isAdmin = flag.Bool("admin", false, "Create an admin account")
resetPassword = flag.Bool("reset-password", false, "Deprecated") resetPassword = flag.Bool("reset-password", false, "Deprecated")
serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.")
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
) )
var cl = http.Client{ var cl = http.Client{
@ -95,20 +94,21 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if !validUsernameRegex.MatchString(*username) { if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil {
logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") logrus.WithError(err).Error("Specified username is invalid")
os.Exit(1) os.Exit(1)
} }
if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 {
logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName))
}
pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin) pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
if err != nil { if err != nil {
logrus.Fatalln(err) logrus.Fatalln(err)
} }
if err = internal.ValidatePassword(pass); err != nil {
logrus.WithError(err).Error("Specified password is invalid")
os.Exit(1)
}
cl.Timeout = *timeout cl.Timeout = *timeout
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)

View file

@ -157,11 +157,11 @@ func TestOutboundPeeking(t *testing.T) {
if len(outboundPeeks) != len(peekIDs) { if len(outboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
} }
for i := range outboundPeeks { gotPeekIDs := make([]string, 0, len(outboundPeeks))
if outboundPeeks[i].PeekID != peekIDs[i] { for _, p := range outboundPeeks {
t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i]) gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
} }
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
}) })
} }
@ -239,10 +239,10 @@ func TestInboundPeeking(t *testing.T) {
if len(inboundPeeks) != len(peekIDs) { if len(inboundPeeks) != len(peekIDs) {
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
} }
for i := range inboundPeeks { gotPeekIDs := make([]string, 0, len(inboundPeeks))
if inboundPeeks[i].PeekID != peekIDs[i] { for _, p := range inboundPeeks {
t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i]) gotPeekIDs = append(gotPeekIDs, p.PeekID)
}
} }
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
}) })
} }

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15

4
go.sum
View file

@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo= github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8=
github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM=
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=

View file

@ -198,17 +198,12 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
// MakeHTMLAPI adds Span metrics to the HTML Handler function // MakeHTMLAPI adds Span metrics to the HTML Handler function
// This is used to serve HTML alongside JSON error messages // This is used to serve HTML alongside JSON error messages
func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request)) http.Handler {
withSpan := func(w http.ResponseWriter, req *http.Request) { withSpan := func(w http.ResponseWriter, req *http.Request) {
span := opentracing.StartSpan(metricsName) span := opentracing.StartSpan(metricsName)
defer span.Finish() defer span.Finish()
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
if err := f(w, req); err != nil { f(w, req)
h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
return *err
}))
h.ServeHTTP(w, req)
}
} }
if !enableMetrics { if !enableMetrics {

View file

@ -14,7 +14,7 @@ type Condition struct {
// Pattern indicates the value pattern that must match. Required // Pattern indicates the value pattern that must match. Required
// for EventMatchCondition. // for EventMatchCondition.
Pattern string `json:"pattern,omitempty"` Pattern *string `json:"pattern,omitempty"`
// Is indicates the condition that must be fulfilled. Required for // Is indicates the condition that must be fulfilled. Required for
// RoomMemberCountCondition. // RoomMemberCountCondition.

View file

@ -15,13 +15,7 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule {
RuleID: MRuleContainsUserName, RuleID: MRuleContainsUserName,
Default: true, Default: true,
Enabled: true, Enabled: true,
Pattern: localpart, Pattern: &localpart,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "content.body",
},
},
Actions: []*Action{ Actions: []*Action{
{Kind: NotifyAction}, {Kind: NotifyAction},
{ {
@ -32,7 +26,6 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule {
{ {
Kind: SetTweakAction, Kind: SetTweakAction,
Tweak: HighlightTweak, Tweak: HighlightTweak,
Value: true,
}, },
}, },
} }

View file

@ -22,15 +22,15 @@ const (
MRuleTombstone = ".m.rule.tombstone" MRuleTombstone = ".m.rule.tombstone"
MRuleRoomNotif = ".m.rule.roomnotif" MRuleRoomNotif = ".m.rule.roomnotif"
MRuleReaction = ".m.rule.reaction" MRuleReaction = ".m.rule.reaction"
MRuleRoomACLs = ".m.rule.room.server_acl"
) )
var ( var (
mRuleMasterDefinition = Rule{ mRuleMasterDefinition = Rule{
RuleID: MRuleMaster, RuleID: MRuleMaster,
Default: true, Default: true,
Enabled: false, Enabled: false,
Conditions: []*Condition{}, Actions: []*Action{{Kind: DontNotifyAction}},
Actions: []*Action{{Kind: DontNotifyAction}},
} }
mRuleSuppressNoticesDefinition = Rule{ mRuleSuppressNoticesDefinition = Rule{
RuleID: MRuleSuppressNotices, RuleID: MRuleSuppressNotices,
@ -40,7 +40,7 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "content.msgtype", Key: "content.msgtype",
Pattern: "m.notice", Pattern: pointer("m.notice"),
}, },
}, },
Actions: []*Action{{Kind: DontNotifyAction}}, Actions: []*Action{{Kind: DontNotifyAction}},
@ -53,7 +53,7 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "type", Key: "type",
Pattern: "m.room.member", Pattern: pointer("m.room.member"),
}, },
}, },
Actions: []*Action{{Kind: DontNotifyAction}}, Actions: []*Action{{Kind: DontNotifyAction}},
@ -73,7 +73,6 @@ var (
{ {
Kind: SetTweakAction, Kind: SetTweakAction,
Tweak: HighlightTweak, Tweak: HighlightTweak,
Value: true,
}, },
}, },
} }
@ -85,12 +84,12 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "type", Key: "type",
Pattern: "m.room.tombstone", Pattern: pointer("m.room.tombstone"),
}, },
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "state_key", Key: "state_key",
Pattern: "", Pattern: pointer(""),
}, },
}, },
Actions: []*Action{ Actions: []*Action{
@ -98,10 +97,27 @@ var (
{ {
Kind: SetTweakAction, Kind: SetTweakAction,
Tweak: HighlightTweak, Tweak: HighlightTweak,
Value: true,
}, },
}, },
} }
mRuleACLsDefinition = Rule{
RuleID: MRuleRoomACLs,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: pointer("m.room.server_acl"),
},
{
Kind: EventMatchCondition,
Key: "state_key",
Pattern: pointer(""),
},
},
Actions: []*Action{},
}
mRuleRoomNotifDefinition = Rule{ mRuleRoomNotifDefinition = Rule{
RuleID: MRuleRoomNotif, RuleID: MRuleRoomNotif,
Default: true, Default: true,
@ -110,7 +126,7 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "content.body", Key: "content.body",
Pattern: "@room", Pattern: pointer("@room"),
}, },
{ {
Kind: SenderNotificationPermissionCondition, Kind: SenderNotificationPermissionCondition,
@ -122,7 +138,6 @@ var (
{ {
Kind: SetTweakAction, Kind: SetTweakAction,
Tweak: HighlightTweak, Tweak: HighlightTweak,
Value: true,
}, },
}, },
} }
@ -134,7 +149,7 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "type", Key: "type",
Pattern: "m.reaction", Pattern: pointer("m.reaction"),
}, },
}, },
Actions: []*Action{ Actions: []*Action{
@ -152,17 +167,17 @@ func mRuleInviteForMeDefinition(userID string) *Rule {
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "type", Key: "type",
Pattern: "m.room.member", Pattern: pointer("m.room.member"),
}, },
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "content.membership", Key: "content.membership",
Pattern: "invite", Pattern: pointer("invite"),
}, },
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "state_key", Key: "state_key",
Pattern: userID, Pattern: pointer(userID),
}, },
}, },
Actions: []*Action{ Actions: []*Action{
@ -172,11 +187,6 @@ func mRuleInviteForMeDefinition(userID string) *Rule {
Tweak: SoundTweak, Tweak: SoundTweak,
Value: "default", Value: "default",
}, },
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
}, },
} }
} }

View file

@ -0,0 +1,111 @@
package pushrules
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
)
// Tests that the pre-defined rules as of
// https://spec.matrix.org/v1.4/client-server-api/#predefined-rules
// are correct
func TestDefaultRules(t *testing.T) {
type testCase struct {
name string
inputBytes []byte
want Rule
}
testCases := []testCase{
// Default override rules
{
name: ".m.rule.master",
inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"actions":["dont_notify"]}`),
want: mRuleMasterDefinition,
},
{
name: ".m.rule.suppress_notices",
inputBytes: []byte(`{"rule_id":".m.rule.suppress_notices","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.msgtype","pattern":"m.notice"}],"actions":["dont_notify"]}`),
want: mRuleSuppressNoticesDefinition,
},
{
name: ".m.rule.invite_for_me",
inputBytes: []byte(`{"rule_id":".m.rule.invite_for_me","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"},{"kind":"event_match","key":"content.membership","pattern":"invite"},{"kind":"event_match","key":"state_key","pattern":"@test:localhost"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`),
want: *mRuleInviteForMeDefinition("@test:localhost"),
},
{
name: ".m.rule.member_event",
inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"}],"actions":["dont_notify"]}`),
want: mRuleMemberEventDefinition,
},
{
name: ".m.rule.contains_display_name",
inputBytes: []byte(`{"rule_id":".m.rule.contains_display_name","default":true,"enabled":true,"conditions":[{"kind":"contains_display_name"}],"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}]}`),
want: mRuleContainsDisplayNameDefinition,
},
{
name: ".m.rule.tombstone",
inputBytes: []byte(`{"rule_id":".m.rule.tombstone","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.tombstone"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":["notify",{"set_tweak":"highlight"}]}`),
want: mRuleTombstoneDefinition,
},
{
name: ".m.rule.room.server_acl",
inputBytes: []byte(`{"rule_id":".m.rule.room.server_acl","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.server_acl"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":[]}`),
want: mRuleACLsDefinition,
},
{
name: ".m.rule.roomnotif",
inputBytes: []byte(`{"rule_id":".m.rule.roomnotif","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.body","pattern":"@room"},{"kind":"sender_notification_permission","key":"room"}],"actions":["notify",{"set_tweak":"highlight"}]}`),
want: mRuleRoomNotifDefinition,
},
// Default content rules
{
name: ".m.rule.contains_user_name",
inputBytes: []byte(`{"rule_id":".m.rule.contains_user_name","default":true,"enabled":true,"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}],"pattern":"myLocalUser"}`),
want: *mRuleContainsUserNameDefinition("myLocalUser"),
},
// default underride rules
{
name: ".m.rule.call",
inputBytes: []byte(`{"rule_id":".m.rule.call","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.call.invite"}],"actions":["notify",{"set_tweak":"sound","value":"ring"}]}`),
want: mRuleCallDefinition,
},
{
name: ".m.rule.encrypted_room_one_to_one",
inputBytes: []byte(`{"rule_id":".m.rule.encrypted_room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`),
want: mRuleEncryptedRoomOneToOneDefinition,
},
{
name: ".m.rule.room_one_to_one",
inputBytes: []byte(`{"rule_id":".m.rule.room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`),
want: mRuleRoomOneToOneDefinition,
},
{
name: ".m.rule.message",
inputBytes: []byte(`{"rule_id":".m.rule.message","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify"]}`),
want: mRuleMessageDefinition,
},
{
name: ".m.rule.encrypted",
inputBytes: []byte(`{"rule_id":".m.rule.encrypted","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify"]}`),
want: mRuleEncryptedDefinition,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := Rule{}
// unmarshal predefined push rules
err := json.Unmarshal(tc.inputBytes, &r)
assert.NoError(t, err)
assert.Equal(t, tc.want, r)
// and reverse it to check we get the expected result
got, err := json.Marshal(r)
assert.NoError(t, err)
assert.Equal(t, string(got), string(tc.inputBytes))
})
}
}

View file

@ -25,7 +25,7 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "type", Key: "type",
Pattern: "m.call.invite", Pattern: pointer("m.call.invite"),
}, },
}, },
Actions: []*Action{ Actions: []*Action{
@ -35,11 +35,6 @@ var (
Tweak: SoundTweak, Tweak: SoundTweak,
Value: "ring", Value: "ring",
}, },
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
}, },
} }
mRuleEncryptedRoomOneToOneDefinition = Rule{ mRuleEncryptedRoomOneToOneDefinition = Rule{
@ -54,7 +49,7 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "type", Key: "type",
Pattern: "m.room.encrypted", Pattern: pointer("m.room.encrypted"),
}, },
}, },
Actions: []*Action{ Actions: []*Action{
@ -64,11 +59,6 @@ var (
Tweak: SoundTweak, Tweak: SoundTweak,
Value: "default", Value: "default",
}, },
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
}, },
} }
mRuleRoomOneToOneDefinition = Rule{ mRuleRoomOneToOneDefinition = Rule{
@ -83,20 +73,15 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "type", Key: "type",
Pattern: "m.room.message", Pattern: pointer("m.room.message"),
}, },
}, },
Actions: []*Action{ Actions: []*Action{
{Kind: NotifyAction}, {Kind: NotifyAction},
{ {
Kind: SetTweakAction, Kind: SetTweakAction,
Tweak: HighlightTweak, Tweak: SoundTweak,
Value: false, Value: "default",
},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
}, },
}, },
} }
@ -108,16 +93,11 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "type", Key: "type",
Pattern: "m.room.message", Pattern: pointer("m.room.message"),
}, },
}, },
Actions: []*Action{ Actions: []*Action{
{Kind: NotifyAction}, {Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
}, },
} }
mRuleEncryptedDefinition = Rule{ mRuleEncryptedDefinition = Rule{
@ -128,16 +108,11 @@ var (
{ {
Kind: EventMatchCondition, Kind: EventMatchCondition,
Key: "type", Key: "type",
Pattern: "m.room.encrypted", Pattern: pointer("m.room.encrypted"),
}, },
}, },
Actions: []*Action{ Actions: []*Action{
{Kind: NotifyAction}, {Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
}, },
} }
) )

View file

@ -104,7 +104,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu
case ContentKind: case ContentKind:
// TODO: "These configure behaviour for (unencrypted) messages // TODO: "These configure behaviour for (unencrypted) messages
// that match certain patterns." - Does that mean "content.body"? // that match certain patterns." - Does that mean "content.body"?
return patternMatches("content.body", rule.Pattern, event) if rule.Pattern == nil {
return false, nil
}
return patternMatches("content.body", *rule.Pattern, event)
case RoomKind: case RoomKind:
return rule.RuleID == event.RoomID(), nil return rule.RuleID == event.RoomID(), nil
@ -120,7 +123,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu
func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
switch cond.Kind { switch cond.Kind {
case EventMatchCondition: case EventMatchCondition:
return patternMatches(cond.Key, cond.Pattern, event) if cond.Pattern == nil {
return false, fmt.Errorf("missing condition pattern")
}
return patternMatches(cond.Key, *cond.Pattern, event)
case ContainsDisplayNameCondition: case ContainsDisplayNameCondition:
return patternMatches("content.body", ec.UserDisplayName(), event) return patternMatches("content.body", ec.UserDisplayName(), event)

View file

@ -79,8 +79,8 @@ func TestRuleMatches(t *testing.T) {
{"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
{"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true},
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false},
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
@ -106,41 +106,44 @@ func TestConditionMatches(t *testing.T) {
Name string Name string
Cond Condition Cond Condition
EventJSON string EventJSON string
Want bool WantMatch bool
WantErr bool
}{ }{
{"empty", Condition{}, `{}`, false}, {Name: "empty", Cond: Condition{}, EventJSON: `{}`, WantMatch: false, WantErr: false},
{"empty", Condition{Kind: "unknownstring"}, `{}`, false}, {Name: "empty", Cond: Condition{Kind: "unknownstring"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
// Neither of these should match because `content` is not a full string match, // Neither of these should match because `content` is not a full string match,
// and `content.body` is not a string value. // and `content.body` is not a string value.
{"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, false}, {Name: "eventMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content", Pattern: pointer("")}, EventJSON: `{"content":{}}`, WantMatch: false, WantErr: false},
{"eventBodyMatch", Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3"}, `{"content":{"body": 3}}`, false}, {Name: "eventBodyMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3", Pattern: pointer("")}, EventJSON: `{"content":{"body": "3"}}`, WantMatch: false, WantErr: false},
{Name: "eventBodyMatch matches", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Pattern: pointer("world")}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: true, WantErr: false},
{Name: "EventMatch missing pattern", Cond: Condition{Kind: EventMatchCondition, Key: "content.body"}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: false, WantErr: true},
{"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, {Name: "displayNameNoMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"something without displayname"}}`, WantMatch: false, WantErr: false},
{"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, {Name: "displayNameMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"hello Dear User, how are you?"}}`, WantMatch: true, WantErr: false},
{"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false}, {Name: "roomMemberCountLessNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<2"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
{"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true}, {Name: "roomMemberCountLessMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<3"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
{"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false}, {Name: "roomMemberCountLessEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
{"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true}, {Name: "roomMemberCountLessEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
{"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false}, {Name: "roomMemberCountEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==1"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
{"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true}, {Name: "roomMemberCountEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==2"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
{"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false}, {Name: "roomMemberCountGreaterEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
{"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true}, {Name: "roomMemberCountGreaterEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
{"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false}, {Name: "roomMemberCountGreaterNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">2"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
{"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true}, {Name: "roomMemberCountGreaterMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">1"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
{"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true}, {Name: "senderNotificationPermissionMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@poweruser:example.com"}`, WantMatch: true, WantErr: false},
{"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false}, {Name: "senderNotificationPermissionNoMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@nobody:example.com"}`, WantMatch: false, WantErr: false},
} }
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {
got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{2}) got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{2})
if err != nil { if err != nil && !tst.WantErr {
t.Fatalf("conditionMatches failed: %v", err) t.Fatalf("conditionMatches failed: %v", err)
} }
if got != tst.Want { if got != tst.WantMatch {
t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.Want, tst.Name) t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.WantMatch, tst.Name)
} }
}) })
} }

View file

@ -36,18 +36,18 @@ type Rule struct {
// around. Required. // around. Required.
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Conditions provide the rule's conditions for OverrideKind and
// UnderrideKind. Not allowed for other kinds.
Conditions []*Condition `json:"conditions,omitempty"`
// Actions describe the desired outcome, should the rule // Actions describe the desired outcome, should the rule
// match. Required. // match. Required.
Actions []*Action `json:"actions"` Actions []*Action `json:"actions"`
// Conditions provide the rule's conditions for OverrideKind and
// UnderrideKind. Not allowed for other kinds.
Conditions []*Condition `json:"conditions"`
// Pattern is the body pattern to match for ContentKind. Required // Pattern is the body pattern to match for ContentKind. Required
// for that kind. The interpretation is the same as that of // for that kind. The interpretation is the same as that of
// Condition.Pattern. // Condition.Pattern.
Pattern string `json:"pattern"` Pattern *string `json:"pattern,omitempty"`
} }
// Scope only has one valid value. See also AccountRuleSets. // Scope only has one valid value. See also AccountRuleSets.

View file

@ -128,3 +128,7 @@ func parseRoomMemberCountCondition(s string) (func(int) bool, error) {
b = int(v) b = int(v)
return cmp, nil return cmp, nil
} }
func pointer[t any](s t) *t {
return &s
}

View file

@ -34,7 +34,10 @@ func ValidateRule(kind Kind, rule *Rule) []error {
} }
case ContentKind: case ContentKind:
if rule.Pattern == "" { if rule.Pattern == nil {
errs = append(errs, fmt.Errorf("missing content rule pattern"))
}
if rule.Pattern != nil && *rule.Pattern == "" {
errs = append(errs, fmt.Errorf("missing content rule pattern")) errs = append(errs, fmt.Errorf("missing content rule pattern"))
} }

View file

@ -12,15 +12,16 @@ func TestValidateRuleNegatives(t *testing.T) {
Rule Rule Rule Rule
WantErrString string WantErrString string
}{ }{
{"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"}, {Name: "emptyRuleID", Kind: OverrideKind, Rule: Rule{}, WantErrString: "invalid rule ID"},
{"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"}, {Name: "invalidKind", Kind: Kind("something else"), Rule: Rule{}, WantErrString: "invalid rule kind"},
{"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"}, {Name: "ruleIDBackslash", Kind: OverrideKind, Rule: Rule{RuleID: "#foo\\:example.com"}, WantErrString: "invalid rule ID"},
{"noActions", OverrideKind, Rule{}, "missing actions"}, {Name: "noActions", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing actions"},
{"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"}, {Name: "invalidAction", Kind: OverrideKind, Rule: Rule{Actions: []*Action{{}}}, WantErrString: "invalid rule action kind"},
{"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"}, {Name: "invalidCondition", Kind: OverrideKind, Rule: Rule{Conditions: []*Condition{{}}}, WantErrString: "invalid rule condition kind"},
{"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"}, {Name: "overrideNoCondition", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"},
{"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"}, {Name: "underrideNoCondition", Kind: UnderrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"},
{"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"}, {Name: "contentNoPattern", Kind: ContentKind, Rule: Rule{}, WantErrString: "missing content rule pattern"},
{Name: "contentEmptyPattern", Kind: ContentKind, Rule: Rule{Pattern: pointer("")}, WantErrString: "missing content rule pattern"},
} }
for _, tst := range tsts { for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) { t.Run(tst.Name, func(t *testing.T) {

View file

@ -15,30 +15,96 @@
package internal package internal
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"regexp"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based const (
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
)
// ValidatePassword returns an error response if the password is invalid var (
func ValidatePassword(password string) *util.JSONResponse { ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength)
ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength)
ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength)
ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='")
ErrUsernameUnderscore = errors.New("username cannot start with a '_'")
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
)
// ValidatePassword returns an error if the password is invalid
func ValidatePassword(password string) error {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if len(password) > maxPasswordLength { if len(password) > maxPasswordLength {
return &util.JSONResponse{ return ErrPasswordTooLong
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)),
}
} else if len(password) > 0 && len(password) < minPasswordLength { } else if len(password) > 0 && len(password) < minPasswordLength {
return ErrPasswordWeak
}
return nil
}
// PasswordResponse returns a util.JSONResponse for a given error, if any.
func PasswordResponse(err error) *util.JSONResponse {
switch err {
case ErrPasswordWeak:
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()),
}
case ErrPasswordTooLong:
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()),
} }
} }
return nil return nil
} }
// ValidateUsername returns an error if the username is invalid
func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return ErrUsernameTooLong
} else if !validUsernameRegex.MatchString(localpart) {
return ErrUsernameInvalid
} else if localpart[0] == '_' { // Regex checks its not a zero length string
return ErrUsernameUnderscore
}
return nil
}
// UsernameResponse returns a util.JSONResponse for the given error, if any.
func UsernameResponse(err error) *util.JSONResponse {
switch err {
case ErrUsernameTooLong:
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
case ErrUsernameInvalid, ErrUsernameUnderscore:
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(err.Error()),
}
}
return nil
}
// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service
func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error {
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return ErrUsernameTooLong
} else if !validUsernameRegex.MatchString(localpart) {
return ErrUsernameInvalid
}
return nil
}

170
internal/validate_test.go Normal file
View file

@ -0,0 +1,170 @@
package internal
import (
"net/http"
"reflect"
"strings"
"testing"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
func Test_validatePassword(t *testing.T) {
tests := []struct {
name string
password string
wantError error
wantJSON *util.JSONResponse
}{
{
name: "password too short",
password: "shortpw",
wantError: ErrPasswordWeak,
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())},
},
{
name: "password too long",
password: strings.Repeat("a", maxPasswordLength+1),
wantError: ErrPasswordTooLong,
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())},
},
{
name: "password OK",
password: util.RandomString(10),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotErr := ValidatePassword(tt.password)
if !reflect.DeepEqual(gotErr, tt.wantError) {
t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError)
}
if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) {
t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON)
}
})
}
}
func Test_validateUsername(t *testing.T) {
tooLongUsername := strings.Repeat("a", maxUsernameLength)
tests := []struct {
name string
localpart string
domain gomatrixserverlib.ServerName
wantErr error
wantJSON *util.JSONResponse
}{
{
name: "empty username",
localpart: "",
domain: "localhost",
wantErr: ErrUsernameInvalid,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
},
},
{
name: "invalid username",
localpart: "INVALIDUSERNAME",
domain: "localhost",
wantErr: ErrUsernameInvalid,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
},
},
{
name: "username too long",
localpart: tooLongUsername,
domain: "localhost",
wantErr: ErrUsernameTooLong,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()),
},
},
{
name: "localpart starting with an underscore",
localpart: "_notvalid",
domain: "localhost",
wantErr: ErrUsernameUnderscore,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()),
},
},
{
name: "valid username",
localpart: "valid",
domain: "localhost",
},
{
name: "complex username",
localpart: "f00_bar-baz.=40/",
domain: "localhost",
},
{
name: "rejects emoji username 💥",
localpart: "💥",
domain: "localhost",
wantErr: ErrUsernameInvalid,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
},
},
{
name: "special characters are allowed",
localpart: "/dev/null",
domain: "localhost",
},
{
name: "special characters are allowed 2",
localpart: "i_am_allowed=1",
domain: "localhost",
},
{
name: "not all special characters are allowed",
localpart: "notallowed#", // contains #
domain: "localhost",
wantErr: ErrUsernameInvalid,
wantJSON: &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
},
},
{
name: "username containing numbers",
localpart: "hello1337",
domain: "localhost",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotErr := ValidateUsername(tt.localpart, tt.domain)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
}
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
}
// Application services are allowed usernames starting with an underscore
if tt.wantErr == ErrUsernameUnderscore {
return
}
gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
}
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
}
})
}
}

View file

@ -78,6 +78,7 @@ const (
type PerformJoinRequest struct { type PerformJoinRequest struct {
RoomIDOrAlias string `json:"room_id_or_alias"` RoomIDOrAlias string `json:"room_id_or_alias"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
IsGuest bool `json:"is_guest"`
Content map[string]interface{} `json:"content"` Content map[string]interface{} `json:"content"`
ServerNames []gomatrixserverlib.ServerName `json:"server_names"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
Unsigned map[string]interface{} `json:"unsigned"` Unsigned map[string]interface{} `json:"unsigned"`

View file

@ -4,6 +4,10 @@ import (
"context" "context"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
asAPI "github.com/matrix-org/dendrite/appservice/api" asAPI "github.com/matrix-org/dendrite/appservice/api"
fsAPI "github.com/matrix-org/dendrite/federationapi/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
@ -19,9 +23,6 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
) )
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.fsAPI = fsAPI r.fsAPI = fsAPI
r.KeyRing = keyRing r.KeyRing = keyRing
identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName)
if err != nil {
logrus.Panic(err)
}
r.Inputer = &input.Inputer{ r.Inputer = &input.Inputer{
Cfg: &r.Base.Cfg.RoomServer, Cfg: &r.Base.Cfg.RoomServer,
Base: r.Base, Base: r.Base,
@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
JetStream: r.JetStream, JetStream: r.JetStream,
NATSClient: r.NATSClient, NATSClient: r.NATSClient,
Durable: nats.Durable(r.Durable), Durable: nats.Durable(r.Durable),
ServerName: r.Cfg.Matrix.ServerName, ServerName: r.ServerName,
SigningIdentity: identity,
FSAPI: fsAPI, FSAPI: fsAPI,
KeyRing: keyRing, KeyRing: keyRing,
ACLs: r.ServerACLs, ACLs: r.ServerACLs,
@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Queryer: r.Queryer, Queryer: r.Queryer,
} }
r.Peeker = &perform.Peeker{ r.Peeker = &perform.Peeker{
ServerName: r.Cfg.Matrix.ServerName, ServerName: r.ServerName,
Cfg: r.Cfg, Cfg: r.Cfg,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
Inputer: r.Inputer, Inputer: r.Inputer,
} }
r.Unpeeker = &perform.Unpeeker{ r.Unpeeker = &perform.Unpeeker{
ServerName: r.Cfg.Matrix.ServerName, ServerName: r.ServerName,
Cfg: r.Cfg, Cfg: r.Cfg,
DB: r.DB, DB: r.DB,
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
r.Leaver.UserAPI = userAPI r.Leaver.UserAPI = userAPI
r.Inputer.UserAPI = userAPI
} }
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {

View file

@ -23,6 +23,8 @@ import (
"sync" "sync"
"time" "time"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/Arceliar/phony" "github.com/Arceliar/phony"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -79,6 +81,7 @@ type Inputer struct {
JetStream nats.JetStreamContext JetStream nats.JetStreamContext
Durable nats.SubOpt Durable nats.SubOpt
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
SigningIdentity *gomatrixserverlib.SigningIdentity
FSAPI fedapi.RoomserverFederationAPI FSAPI fedapi.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier KeyRing gomatrixserverlib.JSONVerifier
ACLs *acls.ServerACLs ACLs *acls.ServerACLs
@ -87,6 +90,7 @@ type Inputer struct {
workers sync.Map // room ID -> *worker workers sync.Map // room ID -> *worker
Queryer *query.Queryer Queryer *query.Queryer
UserAPI userapi.RoomserverUserAPI
} }
// If a room consumer is inactive for a while then we will allow NATS // If a room consumer is inactive for a while then we will allow NATS

View file

@ -19,6 +19,7 @@ package input
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"time" "time"
@ -31,6 +32,8 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
userAPI "github.com/matrix-org/dendrite/userapi/api"
fedapi "github.com/matrix-org/dendrite/federationapi/api" fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent(
} }
} }
// If guest_access changed and is not can_join, kick all guest users.
if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
if err = r.kickGuests(ctx, event, roomInfo); err != nil {
logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
}
}
// Everything was OK — the latest events updater didn't error and // Everything was OK — the latest events updater didn't error and
// we've sent output events. Finally, generate a hook call. // we've sent output events. Finally, generate a hook call.
hooks.Run(hooks.KindNewEventPersisted, headered) hooks.Run(hooks.KindNewEventPersisted, headered)
@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState(
succeeded = true succeeded = true
return nil return nil
} }
// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited.
func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error {
membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
if err != nil {
return err
}
memberEvents, err := r.DB.Events(ctx, membershipNIDs)
if err != nil {
return err
}
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
latestReq := &api.QueryLatestEventsAndStateRequest{
RoomID: event.RoomID(),
}
latestRes := &api.QueryLatestEventsAndStateResponse{}
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
return err
}
prevEvents := latestRes.LatestEvents
for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil {
continue
}
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
if err != nil {
continue
}
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
Localpart: localpart,
ServerName: senderDomain,
}, accountRes); err != nil {
return err
}
if accountRes.Account == nil {
continue
}
if accountRes.Account.AccountType != userAPI.AccountTypeGuest {
continue
}
var memberContent gomatrixserverlib.MemberContent
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
return err
}
memberContent.Membership = gomatrixserverlib.Leave
stateKey := *memberEvent.StateKey()
fledglingEvent := &gomatrixserverlib.EventBuilder{
RoomID: event.RoomID(),
Type: gomatrixserverlib.MRoomMember,
StateKey: &stateKey,
Sender: stateKey,
PrevEvents: prevEvents,
}
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
return err
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
if err != nil {
return err
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes)
if err != nil {
return err
}
inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindNew,
Event: event,
Origin: senderDomain,
SendAsServer: string(senderDomain),
})
prevEvents = []gomatrixserverlib.EventReference{
event.EventReference(),
}
}
inputReq := &api.InputRoomEventsRequest{
InputRoomEvents: inputEvents,
Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
}
inputRes := &api.InputRoomEventsResponse{}
return r.InputRoomEvents(ctx, inputReq, inputRes)
}

View file

@ -122,11 +122,14 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
ctx, req.VirtualHost, requester, ctx, req.VirtualHost, requester,
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100,
) )
if err != nil { // Only return an error if we really couldn't get any events.
if err != nil && len(events) == 0 {
logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed") logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed")
return err return err
} }
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) // If we got an error but still got events, that's fine, because a server might have returned a 404 (or something)
// but other servers could provide the missing event.
logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
// persist these new events - auth checks have already been done // persist these new events - auth checks have already been done
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
@ -319,6 +322,7 @@ FederationHit:
FedClient: b.fsAPI, FedClient: b.fsAPI,
RememberAuthEvents: false, RememberAuthEvents: false,
Server: srv, Server: srv,
Origin: b.virtualHost,
} }
res, err := c.StateIDsBeforeEvent(ctx, targetEvent) res, err := c.StateIDsBeforeEvent(ctx, targetEvent)
if err != nil { if err != nil {
@ -394,6 +398,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
FedClient: b.fsAPI, FedClient: b.fsAPI,
RememberAuthEvents: false, RememberAuthEvents: false,
Server: srv, Server: srv,
Origin: b.virtualHost,
} }
result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs) result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs)
if err != nil { if err != nil {

View file

@ -16,6 +16,7 @@ package perform
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID(
} }
} }
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
if req.IsGuest {
var guestAccessEvent *gomatrixserverlib.HeaderedEvent
guestAccess := "forbidden"
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "")
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
}
if guestAccessEvent != nil {
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
}
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
// is present on the room and has the guest_access value can_join.
if guestAccess != "can_join" {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNotAllowed,
Msg: "Guest access is forbidden",
}
}
}
// If we should do a forced federated join then do that. // If we should do a forced federated join then do that.
var joinedVia gomatrixserverlib.ServerName var joinedVia gomatrixserverlib.ServerName
if forceFederatedJoin { if forceFederatedJoin {

View file

@ -3,18 +3,23 @@ package roomserver_test
import ( import (
"context" "context"
"net/http" "net/http"
"reflect"
"testing" "testing"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/userapi"
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/inthttp"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
) )
@ -29,7 +34,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s
return base, db, close return base, db, close
} }
func Test_SharedUsers(t *testing.T) { func TestUsers(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
rsAPI := roomserver.NewInternalAPI(base)
// SetFederationAPI starts the room event input consumer
rsAPI.SetFederationAPI(nil, nil)
t.Run("shared users", func(t *testing.T) {
testSharedUsers(t, rsAPI)
})
t.Run("kick users", func(t *testing.T) {
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
rsAPI.SetUserAPI(usrAPI)
testKickUsers(t, rsAPI, usrAPI)
})
})
}
func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) {
alice := test.NewUser(t) alice := test.NewUser(t)
bob := test.NewUser(t) bob := test.NewUser(t)
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
@ -43,36 +69,93 @@ func Test_SharedUsers(t *testing.T) {
}, test.WithStateKey(bob.ID)) }, test.WithStateKey(bob.ID))
ctx := context.Background() ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, _, close := mustCreateDatabase(t, dbType)
defer close()
rsAPI := roomserver.NewInternalAPI(base) // Create the room
// SetFederationAPI starts the room event input consumer if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
rsAPI.SetFederationAPI(nil, nil) t.Errorf("failed to send events: %v", err)
// Create the room }
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Fatalf("failed to send events: %v", err) // Query the shared users for Alice, there should only be Bob.
// This is used by the SyncAPI keychange consumer.
res := &api.QuerySharedUsersResponse{}
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
t.Errorf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
// Also verify that we get the expected result when specifying OtherUserIDs.
// This is used by the SyncAPI when getting device list changes.
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
t.Errorf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
}
func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) {
// Create users and room; Bob is going to be the guest and kicked on revocation of guest access
alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser))
bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest))
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true))
// Join with the guest user
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, test.WithStateKey(bob.ID))
ctx := context.Background()
// Create the users in the userapi, so the RSAPI can query the account type later
for _, u := range []*test.User{alice, bob} {
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
userRes := &userAPI.PerformAccountCreationResponse{}
if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{
AccountType: u.AccountType,
Localpart: localpart,
ServerName: serverName,
Password: "someRandomPassword",
}, userRes); err != nil {
t.Errorf("failed to create account: %s", err)
}
}
// Create the room in the database
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
// Get the membership events BEFORE revoking guest access
membershipRes := &api.QueryMembershipsForRoomResponse{}
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil {
t.Errorf("failed to query membership for room: %s", err)
}
// revoke guest access
revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey(""))
if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil {
t.Errorf("failed to send events: %v", err)
}
// TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need
// to loop and wait for the events to be processed by the roomserver.
for i := 0; i <= 20; i++ {
// Get the membership events AFTER revoking guest access
membershipRes2 := &api.QueryMembershipsForRoomResponse{}
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil {
t.Errorf("failed to query membership for room: %s", err)
} }
// Query the shared users for Alice, there should only be Bob. // The membership events should NOT match, as Bob (guest user) should now be kicked from the room
// This is used by the SyncAPI keychange consumer. if !reflect.DeepEqual(membershipRes, membershipRes2) {
res := &api.QuerySharedUsersResponse{} return
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
} }
if _, ok := res.UserIDsToCount[bob.ID]; !ok { time.Sleep(time.Millisecond * 10)
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) }
}
// Also verify that we get the expected result when specifying OtherUserIDs. t.Errorf("memberships didn't change in time")
// This is used by the SyncAPI when getting device list changes.
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
t.Fatalf("unable to query known users: %v", err)
}
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
}
})
} }
func Test_QueryLeftUsers(t *testing.T) { func Test_QueryLeftUsers(t *testing.T) {

View file

@ -29,7 +29,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
yaml "gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
jaegerconfig "github.com/uber/jaeger-client-go/config" jaegerconfig "github.com/uber/jaeger-client-go/config"
jaegermetrics "github.com/uber/jaeger-lib/metrics" jaegermetrics "github.com/uber/jaeger-lib/metrics"
@ -314,11 +314,13 @@ func (config *Dendrite) Derive() error {
if config.ClientAPI.RecaptchaEnabled { if config.ClientAPI.RecaptchaEnabled {
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey} config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, config.Derived.Registration.Flows = []authtypes.Flow{
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}) {Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}},
}
} else { } else {
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, config.Derived.Registration.Flows = []authtypes.Flow{
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) {Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}},
}
} }
// Load application service configuration files // Load application service configuration files

View file

@ -78,9 +78,6 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
c.TURN.Verify(configErrs) c.TURN.Verify(configErrs)
c.RateLimiting.Verify(configErrs) c.RateLimiting.Verify(configErrs)
if c.RecaptchaEnabled { if c.RecaptchaEnabled {
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
if c.RecaptchaSiteVerifyAPI == "" { if c.RecaptchaSiteVerifyAPI == "" {
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify" c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
} }
@ -93,6 +90,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
if c.RecaptchaSitekeyClass == "" { if c.RecaptchaSitekeyClass == "" {
c.RecaptchaSitekeyClass = "g-recaptcha-response" c.RecaptchaSitekeyClass = "g-recaptcha-response"
} }
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass)
} }
// Ensure there is any spam counter measure when enabling registration // Ensure there is any spam counter measure when enabling registration
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled { if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled {

View file

@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g
return id, nil return id, nil
} }
} }
return nil, fmt.Errorf("no signing identity %q", serverName) return nil, fmt.Errorf("no signing identity for %q", serverName)
} }
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {

View file

@ -16,8 +16,10 @@ package config
import ( import (
"fmt" "fmt"
"reflect"
"testing" "testing"
"github.com/matrix-org/gomatrixserverlib"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) {
} }
} }
} }
func Test_SigningIdentityFor(t *testing.T) {
tests := []struct {
name string
virtualHosts []*VirtualHost
serverName gomatrixserverlib.ServerName
want *gomatrixserverlib.SigningIdentity
wantErr bool
}{
{
name: "no virtual hosts defined",
wantErr: true,
},
{
name: "no identity found",
serverName: gomatrixserverlib.ServerName("doesnotexist"),
wantErr: true,
},
{
name: "found identity",
serverName: gomatrixserverlib.ServerName("main"),
want: &gomatrixserverlib.SigningIdentity{ServerName: "main"},
},
{
name: "identity found on virtual hosts",
serverName: gomatrixserverlib.ServerName("vh2"),
virtualHosts: []*VirtualHost{
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}},
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}},
},
want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Global{
VirtualHosts: tt.virtualHosts,
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "main",
},
}
got, err := c.SigningIdentityFor(tt.serverName)
if (err != nil) != tt.wantErr {
t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one
Leaves are present in non-gapped incremental syncs Leaves are present in non-gapped incremental syncs
# Below test was passing for the wrong reason, failing correctly since #2858 # Below test was passing for the wrong reason, failing correctly since #2858
New federated private chats get full presence information (SYN-115) New federated private chats get full presence information (SYN-115)
# We don't have any state to calculate m.room.guest_access when accepting invites
Guest users can accept invites to private rooms over federation

View file

@ -763,4 +763,7 @@ AS and main public room lists are separate
local user has tags copied to the new room local user has tags copied to the new room
remote user has tags copied to the new room remote user has tags copied to the new room
/upgrade moves remote aliases to the new room /upgrade moves remote aliases to the new room
Local and remote users' homeservers remove a room from their public directory on upgrade Local and remote users' homeservers remove a room from their public directory on upgrade
Guest users denied access over federation if guest access prohibited
Guest users are kicked from guest_access rooms on revocation of guest_access
Guest users are kicked from guest_access rooms on revocation of guest_access over federation

View file

@ -38,11 +38,12 @@ var (
) )
type Room struct { type Room struct {
ID string ID string
Version gomatrixserverlib.RoomVersion Version gomatrixserverlib.RoomVersion
preset Preset preset Preset
visibility gomatrixserverlib.HistoryVisibility guestCanJoin bool
creator *User visibility gomatrixserverlib.HistoryVisibility
creator *User
authEvents gomatrixserverlib.AuthEvents authEvents gomatrixserverlib.AuthEvents
currentState map[string]*gomatrixserverlib.HeaderedEvent currentState map[string]*gomatrixserverlib.HeaderedEvent
@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) {
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
if r.guestCanJoin {
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{
"guest_access": "can_join",
}, WithStateKey(""))
}
} }
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. // Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
r.Version = ver r.Version = ver
} }
} }
func GuestsCanJoin(canJoin bool) roomModifier {
return func(t *testing.T, r *Room) {
r.guestCanJoin = canJoin
}
}

View file

@ -50,6 +50,7 @@ type KeyserverUserAPI interface {
type RoomserverUserAPI interface { type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
} }
// api functions required by the media api // api functions required by the media api
@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct {
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
Medium string Medium string
} }
type QueryAccountByLocalpartRequest struct {
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryAccountByLocalpartResponse struct {
Account *Account
}

View file

@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
return err return err
} }
func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error {
err := t.Impl.QueryAccountByLocalpart(ctx, req, res)
util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res))
return err
}
func js(thing interface{}) string { func js(thing interface{}) string {
b, err := json.Marshal(thing) b, err := json.Marshal(thing)
if err != nil { if err != nil {

View file

@ -81,11 +81,6 @@ func Test_evaluatePushRules(t *testing.T) {
wantAction: pushrules.NotifyAction, wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{ wantActions: []*pushrules.Action{
{Kind: pushrules.NotifyAction}, {Kind: pushrules.NotifyAction},
{
Kind: pushrules.SetTweakAction,
Tweak: pushrules.HighlightTweak,
Value: false,
},
}, },
}, },
{ {
@ -103,7 +98,6 @@ func Test_evaluatePushRules(t *testing.T) {
{ {
Kind: pushrules.SetTweakAction, Kind: pushrules.SetTweakAction,
Tweak: pushrules.HighlightTweak, Tweak: pushrules.HighlightTweak,
Value: true,
}, },
}, },
}, },

View file

@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil return nil
} }
func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) {
res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName)
return
}
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
// creating a 'device'. // creating a 'device'.
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {

View file

@ -60,6 +60,7 @@ const (
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
QueryAccountByLocalpartPath = "/userapi/queryAccountType"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
h.httpClient, ctx, request, response, h.httpClient, ctx, request, response,
) )
} }
func (h *httpUserInternalAPI) QueryAccountByLocalpart(
ctx context.Context,
req *api.QueryAccountByLocalpartRequest,
res *api.QueryAccountByLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath,
h.httpClient, ctx, req, res,
)
}

View file

@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics
PerformSaveThreePIDAssociationPath, PerformSaveThreePIDAssociationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
) )
internalAPIMux.Handle(
QueryAccountByLocalpartPath,
httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart),
)
} }

View file

@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) {
}) })
}) })
} }
func TestQueryAccountByLocalpart(t *testing.T) {
alice := test.NewUser(t)
localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close()
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
if err != nil {
t.Error(err)
}
testCases := func(t *testing.T, internalAPI api.UserInternalAPI) {
// Query existing account
queryAccResp := &api.QueryAccountByLocalpartResponse{}
if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
Localpart: localpart,
ServerName: userServername,
}, queryAccResp); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(createdAcc, queryAccResp.Account) {
t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account)
}
// Query non-existent account, this should result in an error
err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
Localpart: "doesnotexist",
ServerName: userServername,
}, queryAccResp)
if err == nil {
t.Fatalf("expected an error, but got none: %+v", queryAccResp)
}
}
t.Run("Monolith", func(t *testing.T) {
testCases(t, intAPI)
// also test tracing
testCases(t, &api.UserInternalAPITrace{Impl: intAPI})
})
t.Run("HTTP API", func(t *testing.T) {
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
userapi.AddInternalRoutes(router, intAPI, false)
apiURL, cancel := test.ListenAndServe(t, router, false)
defer cancel()
userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5})
if err != nil {
t.Fatalf("failed to create HTTP client: %s", err)
}
testCases(t, userHTTPApi)
})
})
}