mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-20 20:43:09 -06:00
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/helm
This commit is contained in:
commit
c6994efe70
3
.github/workflows/dendrite.yml
vendored
3
.github/workflows/dendrite.yml
vendored
|
|
@ -331,8 +331,7 @@ jobs:
|
||||||
postgres: postgres
|
postgres: postgres
|
||||||
api: full-http
|
api: full-http
|
||||||
container:
|
container:
|
||||||
# Temporary for debugging to see if this image is working better.
|
image: matrixdotorg/sytest-dendrite
|
||||||
image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73
|
|
||||||
volumes:
|
volumes:
|
||||||
- ${{ github.workspace }}:/src
|
- ${{ github.workspace }}:/src
|
||||||
- /root/.cache/go-build:/github/home/.cache/go-build
|
- /root/.cache/go-build:/github/home/.cache/go-build
|
||||||
|
|
|
||||||
27
Dockerfile
27
Dockerfile
|
|
@ -63,30 +63,3 @@ WORKDIR /etc/dendrite
|
||||||
ENTRYPOINT ["/usr/bin/dendrite-monolith-server"]
|
ENTRYPOINT ["/usr/bin/dendrite-monolith-server"]
|
||||||
EXPOSE 8008 8448
|
EXPOSE 8008 8448
|
||||||
|
|
||||||
#
|
|
||||||
# Builds the Complement image, used for integration tests
|
|
||||||
#
|
|
||||||
FROM base AS complement
|
|
||||||
LABEL org.opencontainers.image.title="Dendrite (Complement)"
|
|
||||||
RUN apk add --no-cache sqlite openssl ca-certificates
|
|
||||||
|
|
||||||
COPY --from=build /out/generate-config /usr/bin/generate-config
|
|
||||||
COPY --from=build /out/generate-keys /usr/bin/generate-keys
|
|
||||||
COPY --from=build /out/dendrite-monolith-server /usr/bin/dendrite-monolith-server
|
|
||||||
|
|
||||||
WORKDIR /dendrite
|
|
||||||
RUN /usr/bin/generate-keys --private-key matrix_key.pem && \
|
|
||||||
mkdir /ca && \
|
|
||||||
openssl genrsa -out /ca/ca.key 2048 && \
|
|
||||||
openssl req -new -x509 -key /ca/ca.key -days 3650 -subj "/C=GB/ST=London/O=matrix.org/CN=Complement CA" -out /ca/ca.crt
|
|
||||||
|
|
||||||
ENV SERVER_NAME=localhost
|
|
||||||
ENV API=0
|
|
||||||
EXPOSE 8008 8448
|
|
||||||
|
|
||||||
# At runtime, generate TLS cert based on the CA now mounted at /ca
|
|
||||||
# At runtime, replace the SERVER_NAME with what we are told
|
|
||||||
CMD /usr/bin/generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \
|
|
||||||
/usr/bin/generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
|
|
||||||
cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
|
||||||
/usr/bin/dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}
|
|
||||||
|
|
|
||||||
|
|
@ -16,13 +16,16 @@ RUN --mount=target=. \
|
||||||
--mount=type=cache,target=/root/.cache/go-build \
|
--mount=type=cache,target=/root/.cache/go-build \
|
||||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
|
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
|
||||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
|
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
|
||||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
|
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \
|
||||||
|
CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \
|
||||||
|
cp build/scripts/complement-cmd.sh /complement-cmd.sh
|
||||||
|
|
||||||
WORKDIR /dendrite
|
WORKDIR /dendrite
|
||||||
RUN ./generate-keys --private-key matrix_key.pem
|
RUN ./generate-keys --private-key matrix_key.pem
|
||||||
|
|
||||||
ENV SERVER_NAME=localhost
|
ENV SERVER_NAME=localhost
|
||||||
ENV API=0
|
ENV API=0
|
||||||
|
ENV COVER=0
|
||||||
EXPOSE 8008 8448
|
EXPOSE 8008 8448
|
||||||
|
|
||||||
# At runtime, generate TLS cert based on the CA now mounted at /ca
|
# At runtime, generate TLS cert based on the CA now mounted at /ca
|
||||||
|
|
@ -30,4 +33,4 @@ EXPOSE 8008 8448
|
||||||
CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
|
CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
|
||||||
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
|
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
|
||||||
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
||||||
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}
|
exec /complement-cmd.sh
|
||||||
|
|
|
||||||
|
|
@ -12,18 +12,20 @@ FROM golang:1.18-stretch
|
||||||
RUN apt-get update && apt-get install -y sqlite3
|
RUN apt-get update && apt-get install -y sqlite3
|
||||||
|
|
||||||
ENV SERVER_NAME=localhost
|
ENV SERVER_NAME=localhost
|
||||||
|
ENV COVER=0
|
||||||
EXPOSE 8008 8448
|
EXPOSE 8008 8448
|
||||||
|
|
||||||
WORKDIR /runtime
|
WORKDIR /runtime
|
||||||
# This script compiles Dendrite for us.
|
# This script compiles Dendrite for us.
|
||||||
RUN echo '\
|
RUN echo '\
|
||||||
#!/bin/bash -eux \n\
|
#!/bin/bash -eux \n\
|
||||||
if test -f "/runtime/dendrite-monolith-server"; then \n\
|
if test -f "/runtime/dendrite-monolith-server" && test -f "/runtime/dendrite-monolith-server-cover"; then \n\
|
||||||
echo "Skipping compilation; binaries exist" \n\
|
echo "Skipping compilation; binaries exist" \n\
|
||||||
exit 0 \n\
|
exit 0 \n\
|
||||||
fi \n\
|
fi \n\
|
||||||
cd /dendrite \n\
|
cd /dendrite \n\
|
||||||
go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\
|
go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\
|
||||||
|
go test -c -cover -covermode=atomic -o /runtime/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." /dendrite/cmd/dendrite-monolith-server \n\
|
||||||
' > compile.sh && chmod +x compile.sh
|
' > compile.sh && chmod +x compile.sh
|
||||||
|
|
||||||
# This script runs Dendrite for us. Must be run in the /runtime directory.
|
# This script runs Dendrite for us. Must be run in the /runtime directory.
|
||||||
|
|
@ -33,6 +35,7 @@ RUN echo '\
|
||||||
./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\
|
./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\
|
||||||
./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\
|
./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\
|
||||||
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\
|
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\
|
||||||
|
[ ${COVER} -eq 1 ] && exec ./dendrite-monolith-server-cover --test.coverprofile=integrationcover.log --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\
|
||||||
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\
|
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\
|
||||||
' > run.sh && chmod +x run.sh
|
' > run.sh && chmod +x run.sh
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,13 +34,16 @@ RUN --mount=target=. \
|
||||||
--mount=type=cache,target=/root/.cache/go-build \
|
--mount=type=cache,target=/root/.cache/go-build \
|
||||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
|
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \
|
||||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
|
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \
|
||||||
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server
|
CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \
|
||||||
|
CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \
|
||||||
|
cp build/scripts/complement-cmd.sh /complement-cmd.sh
|
||||||
|
|
||||||
WORKDIR /dendrite
|
WORKDIR /dendrite
|
||||||
RUN ./generate-keys --private-key matrix_key.pem
|
RUN ./generate-keys --private-key matrix_key.pem
|
||||||
|
|
||||||
ENV SERVER_NAME=localhost
|
ENV SERVER_NAME=localhost
|
||||||
ENV API=0
|
ENV API=0
|
||||||
|
ENV COVER=0
|
||||||
EXPOSE 8008 8448
|
EXPOSE 8008 8448
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -51,4 +54,4 @@ CMD /build/run_postgres.sh && ./generate-keys --keysize 1024 --server $SERVER_NA
|
||||||
# Bump max_open_conns up here in the global database config
|
# Bump max_open_conns up here in the global database config
|
||||||
sed -i 's/max_open_conns:.*$/max_open_conns: 1990/g' dendrite.yaml && \
|
sed -i 's/max_open_conns:.*$/max_open_conns: 1990/g' dendrite.yaml && \
|
||||||
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
||||||
exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}
|
exec /complement-cmd.sh
|
||||||
22
build/scripts/complement-cmd.sh
Executable file
22
build/scripts/complement-cmd.sh
Executable file
|
|
@ -0,0 +1,22 @@
|
||||||
|
#!/bin/bash -e
|
||||||
|
|
||||||
|
# This script is intended to be used inside a docker container for Complement
|
||||||
|
|
||||||
|
if [[ "${COVER}" -eq 1 ]]; then
|
||||||
|
echo "Running with coverage"
|
||||||
|
exec /dendrite/dendrite-monolith-server-cover \
|
||||||
|
--really-enable-open-registration \
|
||||||
|
--tls-cert server.crt \
|
||||||
|
--tls-key server.key \
|
||||||
|
--config dendrite.yaml \
|
||||||
|
-api=${API:-0} \
|
||||||
|
--test.coverprofile=integrationcover.log
|
||||||
|
else
|
||||||
|
echo "Not running with coverage"
|
||||||
|
exec /dendrite/dendrite-monolith-server \
|
||||||
|
--really-enable-open-registration \
|
||||||
|
--tls-cert server.crt \
|
||||||
|
--tls-key server.key \
|
||||||
|
--config dendrite.yaml \
|
||||||
|
-api=${API:-0}
|
||||||
|
fi
|
||||||
|
|
@ -15,6 +15,8 @@
|
||||||
package clientapi
|
package clientapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/api"
|
"github.com/matrix-org/dendrite/clientapi/api"
|
||||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||||
|
|
@ -26,7 +28,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
|
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||||
request := struct {
|
request := struct {
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
}{}
|
}{}
|
||||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
|
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
|
||||||
|
|
@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if resErr := internal.ValidatePassword(request.Password); resErr != nil {
|
if err = internal.ValidatePassword(request.Password); err != nil {
|
||||||
return *resErr
|
return *internal.PasswordResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
updateReq := &userapi.PerformPasswordUpdateRequest{
|
updateReq := &userapi.PerformPasswordUpdateRequest{
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,11 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
@ -101,14 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
|
||||||
func AuthFallback(
|
func AuthFallback(
|
||||||
w http.ResponseWriter, req *http.Request, authType string,
|
w http.ResponseWriter, req *http.Request, authType string,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) *util.JSONResponse {
|
) {
|
||||||
sessionID := req.URL.Query().Get("session")
|
// We currently only support "m.login.recaptcha", so fail early if that's not requested
|
||||||
|
if authType == authtypes.LoginTypeRecaptcha {
|
||||||
|
if !cfg.RecaptchaEnabled {
|
||||||
|
writeHTTPMessage(w, req,
|
||||||
|
"Recaptcha login is disabled on this Homeserver",
|
||||||
|
http.StatusBadRequest,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := req.URL.Query().Get("session")
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return writeHTTPMessage(w, req,
|
writeHTTPMessage(w, req,
|
||||||
"Session ID not provided",
|
"Session ID not provided",
|
||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
serveRecaptcha := func() {
|
serveRecaptcha := func() {
|
||||||
|
|
@ -130,70 +144,44 @@ func AuthFallback(
|
||||||
|
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
// Handle Recaptcha
|
// Handle Recaptcha
|
||||||
if authType == authtypes.LoginTypeRecaptcha {
|
serveRecaptcha()
|
||||||
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
|
return
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
serveRecaptcha()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusNotFound,
|
|
||||||
JSON: jsonerror.NotFound("Unknown auth stage type"),
|
|
||||||
}
|
|
||||||
} else if req.Method == http.MethodPost {
|
} else if req.Method == http.MethodPost {
|
||||||
// Handle Recaptcha
|
// Handle Recaptcha
|
||||||
if authType == authtypes.LoginTypeRecaptcha {
|
clientIP := req.RemoteAddr
|
||||||
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
|
err := req.ParseForm()
|
||||||
return err
|
if err != nil {
|
||||||
}
|
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
clientIP := req.RemoteAddr
|
serveRecaptcha()
|
||||||
err := req.ParseForm()
|
return
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
|
|
||||||
res := jsonerror.InternalServerError()
|
|
||||||
return &res
|
|
||||||
}
|
|
||||||
|
|
||||||
response := req.Form.Get(cfg.RecaptchaFormField)
|
|
||||||
if err := validateRecaptcha(cfg, response, clientIP); err != nil {
|
|
||||||
util.GetLogger(req.Context()).Error(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Success. Add recaptcha as a completed login flow
|
|
||||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
|
||||||
|
|
||||||
serveSuccess()
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &util.JSONResponse{
|
response := req.Form.Get(cfg.RecaptchaFormField)
|
||||||
Code: http.StatusNotFound,
|
err = validateRecaptcha(cfg, response, clientIP)
|
||||||
JSON: jsonerror.NotFound("Unknown auth stage type"),
|
switch err {
|
||||||
|
case ErrMissingResponse:
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
serveRecaptcha() // serve the initial page again, instead of nothing
|
||||||
|
return
|
||||||
|
case ErrInvalidCaptcha:
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
serveRecaptcha()
|
||||||
|
return
|
||||||
|
case nil:
|
||||||
|
default: // something else failed
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
|
||||||
|
serveRecaptcha()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusMethodNotAllowed,
|
|
||||||
JSON: jsonerror.NotFound("Bad method"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver.
|
// Success. Add recaptcha as a completed login flow
|
||||||
func checkRecaptchaEnabled(
|
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||||
cfg *config.ClientAPI,
|
|
||||||
w http.ResponseWriter,
|
serveSuccess()
|
||||||
req *http.Request,
|
return
|
||||||
) *util.JSONResponse {
|
|
||||||
if !cfg.RecaptchaEnabled {
|
|
||||||
return writeHTTPMessage(w, req,
|
|
||||||
"Recaptcha login is disabled on this Homeserver",
|
|
||||||
http.StatusBadRequest,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
return nil
|
writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeHTTPMessage writes the given header and message to the HTTP response writer.
|
// writeHTTPMessage writes the given header and message to the HTTP response writer.
|
||||||
|
|
@ -201,13 +189,10 @@ func checkRecaptchaEnabled(
|
||||||
func writeHTTPMessage(
|
func writeHTTPMessage(
|
||||||
w http.ResponseWriter, req *http.Request,
|
w http.ResponseWriter, req *http.Request,
|
||||||
message string, header int,
|
message string, header int,
|
||||||
) *util.JSONResponse {
|
) {
|
||||||
w.WriteHeader(header)
|
w.WriteHeader(header)
|
||||||
_, err := w.Write([]byte(message))
|
_, err := w.Write([]byte(message))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
|
util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
|
||||||
res := jsonerror.InternalServerError()
|
|
||||||
return &res
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
149
clientapi/routing/auth_fallback_test.go
Normal file
149
clientapi/routing/auth_fallback_test.go
Normal file
|
|
@ -0,0 +1,149 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_AuthFallback(t *testing.T) {
|
||||||
|
base, _, _ := testrig.Base(nil)
|
||||||
|
defer base.Close()
|
||||||
|
|
||||||
|
for _, useHCaptcha := range []bool{false, true} {
|
||||||
|
for _, recaptchaEnabled := range []bool{false, true} {
|
||||||
|
for _, wantErr := range []bool{false, true} {
|
||||||
|
t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) {
|
||||||
|
// Set the defaults for each test
|
||||||
|
base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, Monolithic: true})
|
||||||
|
base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled
|
||||||
|
base.Cfg.ClientAPI.RecaptchaPublicKey = "pub"
|
||||||
|
base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv"
|
||||||
|
if useHCaptcha {
|
||||||
|
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify"
|
||||||
|
base.Cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js"
|
||||||
|
base.Cfg.ClientAPI.RecaptchaFormField = "h-captcha-response"
|
||||||
|
base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha"
|
||||||
|
}
|
||||||
|
cfgErrs := &config.ConfigErrors{}
|
||||||
|
base.Cfg.ClientAPI.Verify(cfgErrs, true)
|
||||||
|
if len(*cfgErrs) > 0 {
|
||||||
|
t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/?session=1337", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if !recaptchaEnabled {
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" {
|
||||||
|
t.Fatalf("unexpected response body: %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if !strings.Contains(rec.Body.String(), base.Cfg.ClientAPI.RecaptchaSitekeyClass) {
|
||||||
|
t.Fatalf("body does not contain %s: %s", base.Cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if wantErr {
|
||||||
|
_, _ = w.Write([]byte(`{"success":false}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte(`{"success":true}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL
|
||||||
|
|
||||||
|
// check the result after sending the captcha
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
|
||||||
|
req.Form = url.Values{}
|
||||||
|
req.Form.Add(base.Cfg.ClientAPI.RecaptchaFormField, "someRandomValue")
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if recaptchaEnabled {
|
||||||
|
if !wantErr {
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if rec.Body.String() != successTemplate {
|
||||||
|
t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), successTemplate)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
wantString := "Authentication"
|
||||||
|
if !strings.Contains(rec.Body.String(), wantString) {
|
||||||
|
t.Fatalf("expected response to contain '%s', but didn't: %s", wantString, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" {
|
||||||
|
t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), "successTemplate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("unknown fallbacks are handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, "DoesNotExist", &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusNotImplemented {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown methods are handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodDelete, "/?session=1337", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing 'response' is handled correctly", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias(
|
||||||
joinReq := roomserverAPI.PerformJoinRequest{
|
joinReq := roomserverAPI.PerformJoinRequest{
|
||||||
RoomIDOrAlias: roomIDOrAlias,
|
RoomIDOrAlias: roomIDOrAlias,
|
||||||
UserID: device.UserID,
|
UserID: device.UserID,
|
||||||
|
IsGuest: device.AccountType == api.AccountTypeGuest,
|
||||||
Content: map[string]interface{}{},
|
Content: map[string]interface{}{},
|
||||||
}
|
}
|
||||||
joinRes := roomserverAPI.PerformJoinResponse{}
|
joinRes := roomserverAPI.PerformJoinResponse{}
|
||||||
|
|
@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias(
|
||||||
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
|
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
|
||||||
done <- jsonerror.InternalAPIError(req.Context(), err)
|
done <- jsonerror.InternalAPIError(req.Context(), err)
|
||||||
} else if joinRes.Error != nil {
|
} else if joinRes.Error != nil {
|
||||||
done <- joinRes.Error.JSONResponse()
|
if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest {
|
||||||
|
done <- util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
done <- joinRes.Error.JSONResponse()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
done <- util.JSONResponse{
|
done <- util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
|
|
||||||
158
clientapi/routing/joinroom_test.go
Normal file
158
clientapi/routing/joinroom_test.go
Normal file
|
|
@ -0,0 +1,158 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/appservice"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJoinRoomByIDOrAlias(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
bob := test.NewUser(t)
|
||||||
|
charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest))
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer baseClose()
|
||||||
|
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||||
|
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
|
||||||
|
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
|
||||||
|
rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc
|
||||||
|
|
||||||
|
// Create the users in the userapi
|
||||||
|
for _, u := range []*test.User{alice, bob, charlie} {
|
||||||
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||||
|
userRes := &uapi.PerformAccountCreationResponse{}
|
||||||
|
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
||||||
|
AccountType: u.AccountType,
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
|
Password: "someRandomPassword",
|
||||||
|
}, userRes); err != nil {
|
||||||
|
t.Errorf("failed to create account: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
aliceDev := &uapi.Device{UserID: alice.ID}
|
||||||
|
bobDev := &uapi.Device{UserID: bob.ID}
|
||||||
|
charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest}
|
||||||
|
|
||||||
|
// create a room with disabled guest access and invite Bob
|
||||||
|
resp := createRoom(ctx, createRoomRequest{
|
||||||
|
Name: "testing",
|
||||||
|
IsDirect: true,
|
||||||
|
Topic: "testing",
|
||||||
|
Visibility: "public",
|
||||||
|
Preset: presetPublicChat,
|
||||||
|
RoomAliasName: "alias",
|
||||||
|
Invite: []string{bob.ID},
|
||||||
|
GuestCanJoin: false,
|
||||||
|
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||||
|
crResp, ok := resp.JSON.(createRoomResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a room with guest access enabled and invite Charlie
|
||||||
|
resp = createRoom(ctx, createRoomRequest{
|
||||||
|
Name: "testing",
|
||||||
|
IsDirect: true,
|
||||||
|
Topic: "testing",
|
||||||
|
Visibility: "public",
|
||||||
|
Preset: presetPublicChat,
|
||||||
|
Invite: []string{charlie.ID},
|
||||||
|
GuestCanJoin: true,
|
||||||
|
}, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now())
|
||||||
|
crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("response is not a createRoomResponse: %+v", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dummy request
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
device *uapi.Device
|
||||||
|
roomID string
|
||||||
|
wantHTTP200 bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "User can join successfully by alias",
|
||||||
|
device: bobDev,
|
||||||
|
roomID: crResp.RoomAlias,
|
||||||
|
wantHTTP200: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "User can join successfully by roomID",
|
||||||
|
device: bobDev,
|
||||||
|
roomID: crResp.RoomID,
|
||||||
|
wantHTTP200: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "join is forbidden if user is guest",
|
||||||
|
device: charlieDev,
|
||||||
|
roomID: crResp.RoomID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "room does not exist",
|
||||||
|
device: aliceDev,
|
||||||
|
roomID: "!doesnotexist:test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user from different server",
|
||||||
|
device: &uapi.Device{UserID: "@wrong:server"},
|
||||||
|
roomID: crResp.RoomAlias,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user doesn't exist locally",
|
||||||
|
device: &uapi.Device{UserID: "@doesnotexist:test"},
|
||||||
|
roomID: crResp.RoomAlias,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid room ID",
|
||||||
|
device: aliceDev,
|
||||||
|
roomID: "invalidRoomID",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "roomAlias does not exist",
|
||||||
|
device: aliceDev,
|
||||||
|
roomID: "#doesnotexist:test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "room with guest_access event",
|
||||||
|
device: charlieDev,
|
||||||
|
roomID: crRespWithGuestAccess.RoomID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID)
|
||||||
|
if tc.wantHTTP200 && !joinResp.Is2xx() {
|
||||||
|
t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -23,15 +23,13 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type loginResponse struct {
|
type loginResponse struct {
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
HomeServer gomatrixserverlib.ServerName `json:"home_server"`
|
DeviceID string `json:"device_id"`
|
||||||
DeviceID string `json:"device_id"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type flows struct {
|
type flows struct {
|
||||||
|
|
@ -116,7 +114,6 @@ func completeAuth(
|
||||||
JSON: loginResponse{
|
JSON: loginResponse{
|
||||||
UserID: performRes.Device.UserID,
|
UserID: performRes.Device.UserID,
|
||||||
AccessToken: performRes.Device.AccessToken,
|
AccessToken: performRes.Device.AccessToken,
|
||||||
HomeServer: serverName,
|
|
||||||
DeviceID: performRes.Device.ID,
|
DeviceID: performRes.Device.ID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -82,8 +82,8 @@ func Password(
|
||||||
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
|
||||||
|
|
||||||
// Check the new password strength.
|
// Check the new password strength.
|
||||||
if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil {
|
if err := internal.ValidatePassword(r.NewPassword); err != nil {
|
||||||
return *resErr
|
return *internal.PasswordResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the local part.
|
// Get the local part.
|
||||||
|
|
|
||||||
|
|
@ -18,12 +18,12 @@ package routing
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -60,10 +60,7 @@ var (
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const sessionIDLength = 24
|
||||||
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
|
||||||
sessionIDLength = 24
|
|
||||||
)
|
|
||||||
|
|
||||||
// sessionsDict keeps track of completed auth stages for each session.
|
// sessionsDict keeps track of completed auth stages for each session.
|
||||||
// It shouldn't be passed by value because it contains a mutex.
|
// It shouldn't be passed by value because it contains a mutex.
|
||||||
|
|
@ -198,8 +195,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
sessions = newSessionsDict()
|
sessions = newSessionsDict()
|
||||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// registerRequest represents the submitted registration request.
|
// registerRequest represents the submitted registration request.
|
||||||
|
|
@ -262,10 +258,9 @@ func newUserInteractiveResponse(
|
||||||
|
|
||||||
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
|
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
|
||||||
type registerResponse struct {
|
type registerResponse struct {
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
AccessToken string `json:"access_token,omitempty"`
|
AccessToken string `json:"access_token,omitempty"`
|
||||||
HomeServer gomatrixserverlib.ServerName `json:"home_server"`
|
DeviceID string `json:"device_id,omitempty"`
|
||||||
DeviceID string `json:"device_id,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// recaptchaResponse represents the HTTP response from a Google Recaptcha server
|
// recaptchaResponse represents the HTTP response from a Google Recaptcha server
|
||||||
|
|
@ -276,66 +271,28 @@ type recaptchaResponse struct {
|
||||||
ErrorCodes []int `json:"error-codes"`
|
ErrorCodes []int `json:"error-codes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateUsername returns an error response if the username is invalid
|
var (
|
||||||
func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
ErrInvalidCaptcha = errors.New("invalid captcha response")
|
||||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
ErrMissingResponse = errors.New("captcha response is required")
|
||||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
ErrCaptchaDisabled = errors.New("captcha registration is disabled")
|
||||||
return &util.JSONResponse{
|
)
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
|
||||||
}
|
|
||||||
} else if !validUsernameRegex.MatchString(localpart) {
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
|
||||||
}
|
|
||||||
} else if localpart[0] == '_' { // Regex checks its not a zero length string
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service
|
|
||||||
func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
|
|
||||||
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
|
|
||||||
}
|
|
||||||
} else if !validUsernameRegex.MatchString(localpart) {
|
|
||||||
return &util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateRecaptcha returns an error response if the captcha response is invalid
|
// validateRecaptcha returns an error response if the captcha response is invalid
|
||||||
func validateRecaptcha(
|
func validateRecaptcha(
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
response string,
|
response string,
|
||||||
clientip string,
|
clientip string,
|
||||||
) *util.JSONResponse {
|
) error {
|
||||||
ip, _, _ := net.SplitHostPort(clientip)
|
ip, _, _ := net.SplitHostPort(clientip)
|
||||||
if !cfg.RecaptchaEnabled {
|
if !cfg.RecaptchaEnabled {
|
||||||
return &util.JSONResponse{
|
return ErrCaptchaDisabled
|
||||||
Code: http.StatusConflict,
|
|
||||||
JSON: jsonerror.Unknown("Captcha registration is disabled"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if response == "" {
|
if response == "" {
|
||||||
return &util.JSONResponse{
|
return ErrMissingResponse
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON("Captcha response is required"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a POST request to Google's API to check the captcha response
|
// Make a POST request to the captcha provider API to check the captcha response
|
||||||
resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI,
|
resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI,
|
||||||
url.Values{
|
url.Values{
|
||||||
"secret": {cfg.RecaptchaPrivateKey},
|
"secret": {cfg.RecaptchaPrivateKey},
|
||||||
|
|
@ -345,10 +302,7 @@ func validateRecaptcha(
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &util.JSONResponse{
|
return err
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the request once we're finishing reading from it
|
// Close the request once we're finishing reading from it
|
||||||
|
|
@ -358,25 +312,16 @@ func validateRecaptcha(
|
||||||
var r recaptchaResponse
|
var r recaptchaResponse
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &util.JSONResponse{
|
return err
|
||||||
Code: http.StatusGatewayTimeout,
|
|
||||||
JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &r)
|
err = json.Unmarshal(body, &r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &util.JSONResponse{
|
return err
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that we received a "success"
|
// Check that we received a "success"
|
||||||
if !r.Success {
|
if !r.Success {
|
||||||
return &util.JSONResponse{
|
return ErrInvalidCaptcha
|
||||||
Code: http.StatusUnauthorized,
|
|
||||||
JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -508,8 +453,8 @@ func validateApplicationService(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check username application service is trying to register is valid
|
// Check username application service is trying to register is valid
|
||||||
if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
|
if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
|
||||||
return "", err
|
return "", internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// No errors, registration valid
|
// No errors, registration valid
|
||||||
|
|
@ -564,15 +509,12 @@ func Register(
|
||||||
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
|
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
|
|
||||||
r.Username, r.ServerName = l, d
|
|
||||||
}
|
|
||||||
if req.URL.Query().Get("kind") == "guest" {
|
if req.URL.Query().Get("kind") == "guest" {
|
||||||
return handleGuestRegistration(req, r, cfg, userAPI)
|
return handleGuestRegistration(req, r, cfg, userAPI)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't allow numeric usernames less than MAX_INT64.
|
// Don't allow numeric usernames less than MAX_INT64.
|
||||||
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil {
|
if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
|
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
|
||||||
|
|
@ -584,7 +526,7 @@ func Register(
|
||||||
ServerName: r.ServerName,
|
ServerName: r.ServerName,
|
||||||
}
|
}
|
||||||
nres := &userapi.QueryNumericLocalpartResponse{}
|
nres := &userapi.QueryNumericLocalpartResponse{}
|
||||||
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
@ -601,8 +543,8 @@ func Register(
|
||||||
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
|
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
|
||||||
// Spec-compliant case (the access_token is specified and the login type
|
// Spec-compliant case (the access_token is specified and the login type
|
||||||
// is correctly set, so it's an appservice registration)
|
// is correctly set, so it's an appservice registration)
|
||||||
if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
|
if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil {
|
||||||
return *resErr
|
return *internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
case accessTokenErr == nil:
|
case accessTokenErr == nil:
|
||||||
// Non-spec-compliant case (the access_token is specified but the login
|
// Non-spec-compliant case (the access_token is specified but the login
|
||||||
|
|
@ -614,12 +556,12 @@ func Register(
|
||||||
default:
|
default:
|
||||||
// Spec-compliant case (neither the access_token nor the login type are
|
// Spec-compliant case (neither the access_token nor the login type are
|
||||||
// specified, so it's a normal user registration)
|
// specified, so it's a normal user registration)
|
||||||
if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
|
if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil {
|
||||||
return *resErr
|
return *internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if resErr := internal.ValidatePassword(r.Password); resErr != nil {
|
if err = internal.ValidatePassword(r.Password); err != nil {
|
||||||
return *resErr
|
return *internal.PasswordResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := util.GetLogger(req.Context())
|
logger := util.GetLogger(req.Context())
|
||||||
|
|
@ -697,7 +639,6 @@ func handleGuestRegistration(
|
||||||
JSON: registerResponse{
|
JSON: registerResponse{
|
||||||
UserID: devRes.Device.UserID,
|
UserID: devRes.Device.UserID,
|
||||||
AccessToken: devRes.Device.AccessToken,
|
AccessToken: devRes.Device.AccessToken,
|
||||||
HomeServer: res.Account.ServerName,
|
|
||||||
DeviceID: devRes.Device.ID,
|
DeviceID: devRes.Device.ID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -761,9 +702,18 @@ func handleRegistrationFlow(
|
||||||
switch r.Auth.Type {
|
switch r.Auth.Type {
|
||||||
case authtypes.LoginTypeRecaptcha:
|
case authtypes.LoginTypeRecaptcha:
|
||||||
// Check given captcha response
|
// Check given captcha response
|
||||||
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
|
||||||
if resErr != nil {
|
switch err {
|
||||||
return *resErr
|
case ErrCaptchaDisabled:
|
||||||
|
return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())}
|
||||||
|
case ErrMissingResponse:
|
||||||
|
return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())}
|
||||||
|
case ErrInvalidCaptcha:
|
||||||
|
return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())}
|
||||||
|
case nil:
|
||||||
|
default:
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
|
||||||
|
return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Recaptcha to the list of completed registration stages
|
// Add Recaptcha to the list of completed registration stages
|
||||||
|
|
@ -924,8 +874,7 @@ func completeRegistration(
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: registerResponse{
|
JSON: registerResponse{
|
||||||
UserID: userutil.MakeUserID(username, accRes.Account.ServerName),
|
UserID: userutil.MakeUserID(username, accRes.Account.ServerName),
|
||||||
HomeServer: accRes.Account.ServerName,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -958,7 +907,6 @@ func completeRegistration(
|
||||||
result := registerResponse{
|
result := registerResponse{
|
||||||
UserID: devRes.Device.UserID,
|
UserID: devRes.Device.UserID,
|
||||||
AccessToken: devRes.Device.AccessToken,
|
AccessToken: devRes.Device.AccessToken,
|
||||||
HomeServer: accRes.Account.ServerName,
|
|
||||||
DeviceID: devRes.Device.ID,
|
DeviceID: devRes.Device.ID,
|
||||||
}
|
}
|
||||||
sessions.addCompletedRegistration(sessionID, result)
|
sessions.addCompletedRegistration(sessionID, result)
|
||||||
|
|
@ -1054,8 +1002,8 @@ func RegisterAvailable(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateUsername(username, domain); err != nil {
|
if err := internal.ValidateUsername(username, domain); err != nil {
|
||||||
return *err
|
return *internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this username is reserved by an application service
|
// Check if this username is reserved by an application service
|
||||||
|
|
@ -1117,11 +1065,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
|
||||||
// downcase capitals
|
// downcase capitals
|
||||||
ssrr.User = strings.ToLower(ssrr.User)
|
ssrr.User = strings.ToLower(ssrr.User)
|
||||||
|
|
||||||
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
|
if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil {
|
||||||
return *resErr
|
return *internal.UsernameResponse(err)
|
||||||
}
|
}
|
||||||
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil {
|
if err = internal.ValidatePassword(ssrr.Password); err != nil {
|
||||||
return *resErr
|
return *internal.PasswordResponse(err)
|
||||||
}
|
}
|
||||||
deviceID := "shared_secret_registration"
|
deviceID := "shared_secret_registration"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,12 +15,27 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
@ -264,3 +279,294 @@ func TestSessionCleanUp(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_register(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
kind string
|
||||||
|
password string
|
||||||
|
username string
|
||||||
|
loginType string
|
||||||
|
forceEmpty bool
|
||||||
|
registrationDisabled bool
|
||||||
|
guestsDisabled bool
|
||||||
|
enableRecaptcha bool
|
||||||
|
captchaBody string
|
||||||
|
wantResponse util.JSONResponse
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "disallow guests",
|
||||||
|
kind: "guest",
|
||||||
|
guestsDisabled: true,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow guests",
|
||||||
|
kind: "guest",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown login type",
|
||||||
|
loginType: "im.not.known",
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusNotImplemented,
|
||||||
|
JSON: jsonerror.Unknown("unknown/unimplemented auth type"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled registration",
|
||||||
|
registrationDisabled: true,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden(`Registration is disabled on "test"`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "successful registration, numeric ID",
|
||||||
|
username: "",
|
||||||
|
password: "someRandomPassword",
|
||||||
|
forceEmpty: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "successful registration",
|
||||||
|
username: "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failing registration - user already exists",
|
||||||
|
username: "success",
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.UserInUse("Desired user ID is already taken."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "successful registration uppercase username",
|
||||||
|
username: "LOWERCASED", // this is going to be lower-cased
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid username",
|
||||||
|
username: "#totalyNotValid",
|
||||||
|
wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "numeric username is forbidden",
|
||||||
|
username: "1337",
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled recaptcha login",
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled recaptcha, no response defined",
|
||||||
|
enableRecaptcha: true,
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(ErrMissingResponse.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid captcha response",
|
||||||
|
enableRecaptcha: true,
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
captchaBody: `notvalid`,
|
||||||
|
wantResponse: util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid captcha response",
|
||||||
|
enableRecaptcha: true,
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
captchaBody: `success`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "captcha invalid from remote",
|
||||||
|
enableRecaptcha: true,
|
||||||
|
loginType: authtypes.LoginTypeRecaptcha,
|
||||||
|
captchaBody: `i should fail for other reasons`,
|
||||||
|
wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer baseClose()
|
||||||
|
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||||
|
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
|
||||||
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if tc.enableRecaptcha {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
response := r.Form.Get("response")
|
||||||
|
|
||||||
|
// Respond with valid JSON or no JSON at all to test happy/error cases
|
||||||
|
switch response {
|
||||||
|
case "success":
|
||||||
|
json.NewEncoder(w).Encode(recaptchaResponse{Success: true})
|
||||||
|
case "notvalid":
|
||||||
|
json.NewEncoder(w).Encode(recaptchaResponse{Success: false})
|
||||||
|
default:
|
||||||
|
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := base.Cfg.Derive(); err != nil {
|
||||||
|
t.Fatalf("failed to derive config: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
base.Cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha
|
||||||
|
base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled
|
||||||
|
base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled
|
||||||
|
|
||||||
|
if tc.kind == "" {
|
||||||
|
tc.kind = "user"
|
||||||
|
}
|
||||||
|
if tc.password == "" && !tc.forceEmpty {
|
||||||
|
tc.password = "someRandomPassword"
|
||||||
|
}
|
||||||
|
if tc.username == "" && !tc.forceEmpty {
|
||||||
|
tc.username = "valid"
|
||||||
|
}
|
||||||
|
if tc.loginType == "" {
|
||||||
|
tc.loginType = "m.login.dummy"
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registerRequest{
|
||||||
|
Password: tc.password,
|
||||||
|
Username: tc.username,
|
||||||
|
}
|
||||||
|
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
err := json.NewEncoder(body).Encode(reg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body)
|
||||||
|
|
||||||
|
resp := Register(req, userAPI, &base.Cfg.ClientAPI)
|
||||||
|
t.Logf("Resp: %+v", resp)
|
||||||
|
|
||||||
|
// The first request should return a userInteractiveResponse
|
||||||
|
switch r := resp.JSON.(type) {
|
||||||
|
case userInteractiveResponse:
|
||||||
|
// Check that the flows are the ones we configured
|
||||||
|
if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) {
|
||||||
|
t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows)
|
||||||
|
}
|
||||||
|
case *jsonerror.MatrixError:
|
||||||
|
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||||
|
t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case registerResponse:
|
||||||
|
// this should only be possible on guest user registration, never for normal users
|
||||||
|
if tc.kind != "guest" {
|
||||||
|
t.Fatalf("got register response on first request: %+v", r)
|
||||||
|
}
|
||||||
|
// assert we've got a UserID, AccessToken and DeviceID
|
||||||
|
if r.UserID == "" {
|
||||||
|
t.Fatalf("missing userID in response")
|
||||||
|
}
|
||||||
|
if r.AccessToken == "" {
|
||||||
|
t.Fatalf("missing accessToken in response")
|
||||||
|
}
|
||||||
|
if r.DeviceID == "" {
|
||||||
|
t.Fatalf("missing deviceID in response")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
t.Logf("Got response: %T", resp.JSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we reached this, we should have received a UIA response
|
||||||
|
uia, ok := resp.JSON.(userInteractiveResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON)
|
||||||
|
}
|
||||||
|
t.Logf("%+v", uia)
|
||||||
|
|
||||||
|
// Register the user
|
||||||
|
reg.Auth = authDict{
|
||||||
|
Type: authtypes.LoginType(tc.loginType),
|
||||||
|
Session: uia.Session,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.captchaBody != "" {
|
||||||
|
reg.Auth.Response = tc.captchaBody
|
||||||
|
}
|
||||||
|
|
||||||
|
dummy := "dummy"
|
||||||
|
reg.DeviceID = &dummy
|
||||||
|
reg.InitialDisplayName = &dummy
|
||||||
|
reg.Type = authtypes.LoginType(tc.loginType)
|
||||||
|
|
||||||
|
err = json.NewEncoder(body).Encode(reg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/", body)
|
||||||
|
|
||||||
|
resp = Register(req, userAPI, &base.Cfg.ClientAPI)
|
||||||
|
|
||||||
|
switch resp.JSON.(type) {
|
||||||
|
case *jsonerror.MatrixError:
|
||||||
|
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||||
|
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case util.JSONResponse:
|
||||||
|
if !reflect.DeepEqual(tc.wantResponse, resp) {
|
||||||
|
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rr, ok := resp.JSON.(registerResponse)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected a registerresponse, got %T", resp.JSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate the response
|
||||||
|
if tc.forceEmpty {
|
||||||
|
// when not supplying a username, one will be generated. Given this _SHOULD_ be
|
||||||
|
// the second user, set the username accordingly
|
||||||
|
reg.Username = "2"
|
||||||
|
}
|
||||||
|
wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test"))
|
||||||
|
if wantUserID != rr.UserID {
|
||||||
|
t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID)
|
||||||
|
}
|
||||||
|
if rr.DeviceID != *reg.DeviceID {
|
||||||
|
t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID)
|
||||||
|
}
|
||||||
|
if rr.AccessToken == "" {
|
||||||
|
t.Fatalf("missing accessToken in response")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -639,9 +639,9 @@ func Setup(
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v3mux.Handle("/auth/{authType}/fallback/web",
|
v3mux.Handle("/auth/{authType}/fallback/web",
|
||||||
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
|
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) {
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
return AuthFallback(w, req, vars["authType"], cfg)
|
AuthFallback(w, req, vars["authType"], cfg)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,10 +25,10 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
@ -58,15 +58,14 @@ Arguments:
|
||||||
`
|
`
|
||||||
|
|
||||||
var (
|
var (
|
||||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||||
password = flag.String("password", "", "The password to associate with the account")
|
password = flag.String("password", "", "The password to associate with the account")
|
||||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||||
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
isAdmin = flag.Bool("admin", false, "Create an admin account")
|
||||||
resetPassword = flag.Bool("reset-password", false, "Deprecated")
|
resetPassword = flag.Bool("reset-password", false, "Deprecated")
|
||||||
serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.")
|
serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.")
|
||||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
|
||||||
timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var cl = http.Client{
|
var cl = http.Client{
|
||||||
|
|
@ -95,20 +94,21 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !validUsernameRegex.MatchString(*username) {
|
if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil {
|
||||||
logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='")
|
logrus.WithError(err).Error("Specified username is invalid")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 {
|
|
||||||
logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName))
|
|
||||||
}
|
|
||||||
|
|
||||||
pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
|
pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalln(err)
|
logrus.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = internal.ValidatePassword(pass); err != nil {
|
||||||
|
logrus.WithError(err).Error("Specified password is invalid")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
cl.Timeout = *timeout
|
cl.Timeout = *timeout
|
||||||
|
|
||||||
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
|
accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin)
|
||||||
|
|
|
||||||
|
|
@ -157,11 +157,11 @@ func TestOutboundPeeking(t *testing.T) {
|
||||||
if len(outboundPeeks) != len(peekIDs) {
|
if len(outboundPeeks) != len(peekIDs) {
|
||||||
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
|
||||||
}
|
}
|
||||||
for i := range outboundPeeks {
|
gotPeekIDs := make([]string, 0, len(outboundPeeks))
|
||||||
if outboundPeeks[i].PeekID != peekIDs[i] {
|
for _, p := range outboundPeeks {
|
||||||
t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i])
|
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -239,10 +239,10 @@ func TestInboundPeeking(t *testing.T) {
|
||||||
if len(inboundPeeks) != len(peekIDs) {
|
if len(inboundPeeks) != len(peekIDs) {
|
||||||
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
|
||||||
}
|
}
|
||||||
for i := range inboundPeeks {
|
gotPeekIDs := make([]string, 0, len(inboundPeeks))
|
||||||
if inboundPeeks[i].PeekID != peekIDs[i] {
|
for _, p := range inboundPeeks {
|
||||||
t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i])
|
gotPeekIDs = append(gotPeekIDs, p.PeekID)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
2
go.mod
2
go.mod
|
|
@ -22,7 +22,7 @@ require (
|
||||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab
|
||||||
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847
|
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
github.com/mattn/go-sqlite3 v1.14.15
|
github.com/mattn/go-sqlite3 v1.14.15
|
||||||
|
|
|
||||||
4
go.sum
4
go.sum
|
|
@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
|
||||||
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM=
|
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM=
|
||||||
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
|
github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc=
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||||
|
|
|
||||||
|
|
@ -198,17 +198,12 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
|
||||||
|
|
||||||
// MakeHTMLAPI adds Span metrics to the HTML Handler function
|
// MakeHTMLAPI adds Span metrics to the HTML Handler function
|
||||||
// This is used to serve HTML alongside JSON error messages
|
// This is used to serve HTML alongside JSON error messages
|
||||||
func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler {
|
func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request)) http.Handler {
|
||||||
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
withSpan := func(w http.ResponseWriter, req *http.Request) {
|
||||||
span := opentracing.StartSpan(metricsName)
|
span := opentracing.StartSpan(metricsName)
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
|
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
|
||||||
if err := f(w, req); err != nil {
|
f(w, req)
|
||||||
h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
|
|
||||||
return *err
|
|
||||||
}))
|
|
||||||
h.ServeHTTP(w, req)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !enableMetrics {
|
if !enableMetrics {
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ type Condition struct {
|
||||||
|
|
||||||
// Pattern indicates the value pattern that must match. Required
|
// Pattern indicates the value pattern that must match. Required
|
||||||
// for EventMatchCondition.
|
// for EventMatchCondition.
|
||||||
Pattern string `json:"pattern,omitempty"`
|
Pattern *string `json:"pattern,omitempty"`
|
||||||
|
|
||||||
// Is indicates the condition that must be fulfilled. Required for
|
// Is indicates the condition that must be fulfilled. Required for
|
||||||
// RoomMemberCountCondition.
|
// RoomMemberCountCondition.
|
||||||
|
|
|
||||||
|
|
@ -15,13 +15,7 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule {
|
||||||
RuleID: MRuleContainsUserName,
|
RuleID: MRuleContainsUserName,
|
||||||
Default: true,
|
Default: true,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Pattern: localpart,
|
Pattern: &localpart,
|
||||||
Conditions: []*Condition{
|
|
||||||
{
|
|
||||||
Kind: EventMatchCondition,
|
|
||||||
Key: "content.body",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Actions: []*Action{
|
Actions: []*Action{
|
||||||
{Kind: NotifyAction},
|
{Kind: NotifyAction},
|
||||||
{
|
{
|
||||||
|
|
@ -32,7 +26,6 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule {
|
||||||
{
|
{
|
||||||
Kind: SetTweakAction,
|
Kind: SetTweakAction,
|
||||||
Tweak: HighlightTweak,
|
Tweak: HighlightTweak,
|
||||||
Value: true,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -22,15 +22,15 @@ const (
|
||||||
MRuleTombstone = ".m.rule.tombstone"
|
MRuleTombstone = ".m.rule.tombstone"
|
||||||
MRuleRoomNotif = ".m.rule.roomnotif"
|
MRuleRoomNotif = ".m.rule.roomnotif"
|
||||||
MRuleReaction = ".m.rule.reaction"
|
MRuleReaction = ".m.rule.reaction"
|
||||||
|
MRuleRoomACLs = ".m.rule.room.server_acl"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
mRuleMasterDefinition = Rule{
|
mRuleMasterDefinition = Rule{
|
||||||
RuleID: MRuleMaster,
|
RuleID: MRuleMaster,
|
||||||
Default: true,
|
Default: true,
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
Conditions: []*Condition{},
|
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||||
Actions: []*Action{{Kind: DontNotifyAction}},
|
|
||||||
}
|
}
|
||||||
mRuleSuppressNoticesDefinition = Rule{
|
mRuleSuppressNoticesDefinition = Rule{
|
||||||
RuleID: MRuleSuppressNotices,
|
RuleID: MRuleSuppressNotices,
|
||||||
|
|
@ -40,7 +40,7 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "content.msgtype",
|
Key: "content.msgtype",
|
||||||
Pattern: "m.notice",
|
Pattern: pointer("m.notice"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{{Kind: DontNotifyAction}},
|
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||||
|
|
@ -53,7 +53,7 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "type",
|
Key: "type",
|
||||||
Pattern: "m.room.member",
|
Pattern: pointer("m.room.member"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{{Kind: DontNotifyAction}},
|
Actions: []*Action{{Kind: DontNotifyAction}},
|
||||||
|
|
@ -73,7 +73,6 @@ var (
|
||||||
{
|
{
|
||||||
Kind: SetTweakAction,
|
Kind: SetTweakAction,
|
||||||
Tweak: HighlightTweak,
|
Tweak: HighlightTweak,
|
||||||
Value: true,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -85,12 +84,12 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "type",
|
Key: "type",
|
||||||
Pattern: "m.room.tombstone",
|
Pattern: pointer("m.room.tombstone"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "state_key",
|
Key: "state_key",
|
||||||
Pattern: "",
|
Pattern: pointer(""),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{
|
Actions: []*Action{
|
||||||
|
|
@ -98,10 +97,27 @@ var (
|
||||||
{
|
{
|
||||||
Kind: SetTweakAction,
|
Kind: SetTweakAction,
|
||||||
Tweak: HighlightTweak,
|
Tweak: HighlightTweak,
|
||||||
Value: true,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
mRuleACLsDefinition = Rule{
|
||||||
|
RuleID: MRuleRoomACLs,
|
||||||
|
Default: true,
|
||||||
|
Enabled: true,
|
||||||
|
Conditions: []*Condition{
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "type",
|
||||||
|
Pattern: pointer("m.room.server_acl"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kind: EventMatchCondition,
|
||||||
|
Key: "state_key",
|
||||||
|
Pattern: pointer(""),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Actions: []*Action{},
|
||||||
|
}
|
||||||
mRuleRoomNotifDefinition = Rule{
|
mRuleRoomNotifDefinition = Rule{
|
||||||
RuleID: MRuleRoomNotif,
|
RuleID: MRuleRoomNotif,
|
||||||
Default: true,
|
Default: true,
|
||||||
|
|
@ -110,7 +126,7 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "content.body",
|
Key: "content.body",
|
||||||
Pattern: "@room",
|
Pattern: pointer("@room"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Kind: SenderNotificationPermissionCondition,
|
Kind: SenderNotificationPermissionCondition,
|
||||||
|
|
@ -122,7 +138,6 @@ var (
|
||||||
{
|
{
|
||||||
Kind: SetTweakAction,
|
Kind: SetTweakAction,
|
||||||
Tweak: HighlightTweak,
|
Tweak: HighlightTweak,
|
||||||
Value: true,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -134,7 +149,7 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "type",
|
Key: "type",
|
||||||
Pattern: "m.reaction",
|
Pattern: pointer("m.reaction"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{
|
Actions: []*Action{
|
||||||
|
|
@ -152,17 +167,17 @@ func mRuleInviteForMeDefinition(userID string) *Rule {
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "type",
|
Key: "type",
|
||||||
Pattern: "m.room.member",
|
Pattern: pointer("m.room.member"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "content.membership",
|
Key: "content.membership",
|
||||||
Pattern: "invite",
|
Pattern: pointer("invite"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "state_key",
|
Key: "state_key",
|
||||||
Pattern: userID,
|
Pattern: pointer(userID),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{
|
Actions: []*Action{
|
||||||
|
|
@ -172,11 +187,6 @@ func mRuleInviteForMeDefinition(userID string) *Rule {
|
||||||
Tweak: SoundTweak,
|
Tweak: SoundTweak,
|
||||||
Value: "default",
|
Value: "default",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Kind: SetTweakAction,
|
|
||||||
Tweak: HighlightTweak,
|
|
||||||
Value: false,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
111
internal/pushrules/default_pushrules_test.go
Normal file
111
internal/pushrules/default_pushrules_test.go
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
package pushrules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tests that the pre-defined rules as of
|
||||||
|
// https://spec.matrix.org/v1.4/client-server-api/#predefined-rules
|
||||||
|
// are correct
|
||||||
|
func TestDefaultRules(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
inputBytes []byte
|
||||||
|
want Rule
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
// Default override rules
|
||||||
|
{
|
||||||
|
name: ".m.rule.master",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"actions":["dont_notify"]}`),
|
||||||
|
want: mRuleMasterDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.suppress_notices",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.suppress_notices","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.msgtype","pattern":"m.notice"}],"actions":["dont_notify"]}`),
|
||||||
|
want: mRuleSuppressNoticesDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.invite_for_me",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.invite_for_me","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"},{"kind":"event_match","key":"content.membership","pattern":"invite"},{"kind":"event_match","key":"state_key","pattern":"@test:localhost"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`),
|
||||||
|
want: *mRuleInviteForMeDefinition("@test:localhost"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.member_event",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"}],"actions":["dont_notify"]}`),
|
||||||
|
want: mRuleMemberEventDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.contains_display_name",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.contains_display_name","default":true,"enabled":true,"conditions":[{"kind":"contains_display_name"}],"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}]}`),
|
||||||
|
want: mRuleContainsDisplayNameDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.tombstone",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.tombstone","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.tombstone"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":["notify",{"set_tweak":"highlight"}]}`),
|
||||||
|
want: mRuleTombstoneDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.room.server_acl",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.room.server_acl","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.server_acl"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":[]}`),
|
||||||
|
want: mRuleACLsDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.roomnotif",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.roomnotif","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.body","pattern":"@room"},{"kind":"sender_notification_permission","key":"room"}],"actions":["notify",{"set_tweak":"highlight"}]}`),
|
||||||
|
want: mRuleRoomNotifDefinition,
|
||||||
|
},
|
||||||
|
// Default content rules
|
||||||
|
{
|
||||||
|
name: ".m.rule.contains_user_name",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.contains_user_name","default":true,"enabled":true,"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}],"pattern":"myLocalUser"}`),
|
||||||
|
want: *mRuleContainsUserNameDefinition("myLocalUser"),
|
||||||
|
},
|
||||||
|
// default underride rules
|
||||||
|
{
|
||||||
|
name: ".m.rule.call",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.call","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.call.invite"}],"actions":["notify",{"set_tweak":"sound","value":"ring"}]}`),
|
||||||
|
want: mRuleCallDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.encrypted_room_one_to_one",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.encrypted_room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`),
|
||||||
|
want: mRuleEncryptedRoomOneToOneDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.room_one_to_one",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`),
|
||||||
|
want: mRuleRoomOneToOneDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.message",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.message","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify"]}`),
|
||||||
|
want: mRuleMessageDefinition,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: ".m.rule.encrypted",
|
||||||
|
inputBytes: []byte(`{"rule_id":".m.rule.encrypted","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify"]}`),
|
||||||
|
want: mRuleEncryptedDefinition,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
r := Rule{}
|
||||||
|
// unmarshal predefined push rules
|
||||||
|
err := json.Unmarshal(tc.inputBytes, &r)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.want, r)
|
||||||
|
|
||||||
|
// and reverse it to check we get the expected result
|
||||||
|
got, err := json.Marshal(r)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, string(got), string(tc.inputBytes))
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -25,7 +25,7 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "type",
|
Key: "type",
|
||||||
Pattern: "m.call.invite",
|
Pattern: pointer("m.call.invite"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{
|
Actions: []*Action{
|
||||||
|
|
@ -35,11 +35,6 @@ var (
|
||||||
Tweak: SoundTweak,
|
Tweak: SoundTweak,
|
||||||
Value: "ring",
|
Value: "ring",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Kind: SetTweakAction,
|
|
||||||
Tweak: HighlightTweak,
|
|
||||||
Value: false,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
mRuleEncryptedRoomOneToOneDefinition = Rule{
|
mRuleEncryptedRoomOneToOneDefinition = Rule{
|
||||||
|
|
@ -54,7 +49,7 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "type",
|
Key: "type",
|
||||||
Pattern: "m.room.encrypted",
|
Pattern: pointer("m.room.encrypted"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{
|
Actions: []*Action{
|
||||||
|
|
@ -64,11 +59,6 @@ var (
|
||||||
Tweak: SoundTweak,
|
Tweak: SoundTweak,
|
||||||
Value: "default",
|
Value: "default",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Kind: SetTweakAction,
|
|
||||||
Tweak: HighlightTweak,
|
|
||||||
Value: false,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
mRuleRoomOneToOneDefinition = Rule{
|
mRuleRoomOneToOneDefinition = Rule{
|
||||||
|
|
@ -83,20 +73,15 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "type",
|
Key: "type",
|
||||||
Pattern: "m.room.message",
|
Pattern: pointer("m.room.message"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{
|
Actions: []*Action{
|
||||||
{Kind: NotifyAction},
|
{Kind: NotifyAction},
|
||||||
{
|
{
|
||||||
Kind: SetTweakAction,
|
Kind: SetTweakAction,
|
||||||
Tweak: HighlightTweak,
|
Tweak: SoundTweak,
|
||||||
Value: false,
|
Value: "default",
|
||||||
},
|
|
||||||
{
|
|
||||||
Kind: SetTweakAction,
|
|
||||||
Tweak: HighlightTweak,
|
|
||||||
Value: false,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -108,16 +93,11 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "type",
|
Key: "type",
|
||||||
Pattern: "m.room.message",
|
Pattern: pointer("m.room.message"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{
|
Actions: []*Action{
|
||||||
{Kind: NotifyAction},
|
{Kind: NotifyAction},
|
||||||
{
|
|
||||||
Kind: SetTweakAction,
|
|
||||||
Tweak: HighlightTweak,
|
|
||||||
Value: false,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
mRuleEncryptedDefinition = Rule{
|
mRuleEncryptedDefinition = Rule{
|
||||||
|
|
@ -128,16 +108,11 @@ var (
|
||||||
{
|
{
|
||||||
Kind: EventMatchCondition,
|
Kind: EventMatchCondition,
|
||||||
Key: "type",
|
Key: "type",
|
||||||
Pattern: "m.room.encrypted",
|
Pattern: pointer("m.room.encrypted"),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Actions: []*Action{
|
Actions: []*Action{
|
||||||
{Kind: NotifyAction},
|
{Kind: NotifyAction},
|
||||||
{
|
|
||||||
Kind: SetTweakAction,
|
|
||||||
Tweak: HighlightTweak,
|
|
||||||
Value: false,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu
|
||||||
case ContentKind:
|
case ContentKind:
|
||||||
// TODO: "These configure behaviour for (unencrypted) messages
|
// TODO: "These configure behaviour for (unencrypted) messages
|
||||||
// that match certain patterns." - Does that mean "content.body"?
|
// that match certain patterns." - Does that mean "content.body"?
|
||||||
return patternMatches("content.body", rule.Pattern, event)
|
if rule.Pattern == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return patternMatches("content.body", *rule.Pattern, event)
|
||||||
|
|
||||||
case RoomKind:
|
case RoomKind:
|
||||||
return rule.RuleID == event.RoomID(), nil
|
return rule.RuleID == event.RoomID(), nil
|
||||||
|
|
@ -120,7 +123,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu
|
||||||
func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
|
func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
|
||||||
switch cond.Kind {
|
switch cond.Kind {
|
||||||
case EventMatchCondition:
|
case EventMatchCondition:
|
||||||
return patternMatches(cond.Key, cond.Pattern, event)
|
if cond.Pattern == nil {
|
||||||
|
return false, fmt.Errorf("missing condition pattern")
|
||||||
|
}
|
||||||
|
return patternMatches(cond.Key, *cond.Pattern, event)
|
||||||
|
|
||||||
case ContainsDisplayNameCondition:
|
case ContainsDisplayNameCondition:
|
||||||
return patternMatches("content.body", ec.UserDisplayName(), event)
|
return patternMatches("content.body", ec.UserDisplayName(), event)
|
||||||
|
|
|
||||||
|
|
@ -79,8 +79,8 @@ func TestRuleMatches(t *testing.T) {
|
||||||
{"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
|
{"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
|
||||||
{"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
|
{"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
|
||||||
|
|
||||||
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true},
|
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true},
|
||||||
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false},
|
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false},
|
||||||
|
|
||||||
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
|
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
|
||||||
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
|
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
|
||||||
|
|
@ -106,41 +106,44 @@ func TestConditionMatches(t *testing.T) {
|
||||||
Name string
|
Name string
|
||||||
Cond Condition
|
Cond Condition
|
||||||
EventJSON string
|
EventJSON string
|
||||||
Want bool
|
WantMatch bool
|
||||||
|
WantErr bool
|
||||||
}{
|
}{
|
||||||
{"empty", Condition{}, `{}`, false},
|
{Name: "empty", Cond: Condition{}, EventJSON: `{}`, WantMatch: false, WantErr: false},
|
||||||
{"empty", Condition{Kind: "unknownstring"}, `{}`, false},
|
{Name: "empty", Cond: Condition{Kind: "unknownstring"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
|
||||||
|
|
||||||
// Neither of these should match because `content` is not a full string match,
|
// Neither of these should match because `content` is not a full string match,
|
||||||
// and `content.body` is not a string value.
|
// and `content.body` is not a string value.
|
||||||
{"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, false},
|
{Name: "eventMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content", Pattern: pointer("")}, EventJSON: `{"content":{}}`, WantMatch: false, WantErr: false},
|
||||||
{"eventBodyMatch", Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3"}, `{"content":{"body": 3}}`, false},
|
{Name: "eventBodyMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3", Pattern: pointer("")}, EventJSON: `{"content":{"body": "3"}}`, WantMatch: false, WantErr: false},
|
||||||
|
{Name: "eventBodyMatch matches", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Pattern: pointer("world")}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: true, WantErr: false},
|
||||||
|
{Name: "EventMatch missing pattern", Cond: Condition{Kind: EventMatchCondition, Key: "content.body"}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: false, WantErr: true},
|
||||||
|
|
||||||
{"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false},
|
{Name: "displayNameNoMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"something without displayname"}}`, WantMatch: false, WantErr: false},
|
||||||
{"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true},
|
{Name: "displayNameMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"hello Dear User, how are you?"}}`, WantMatch: true, WantErr: false},
|
||||||
|
|
||||||
{"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false},
|
{Name: "roomMemberCountLessNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<2"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
|
||||||
{"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true},
|
{Name: "roomMemberCountLessMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<3"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
|
||||||
{"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false},
|
{Name: "roomMemberCountLessEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
|
||||||
{"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true},
|
{Name: "roomMemberCountLessEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
|
||||||
{"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false},
|
{Name: "roomMemberCountEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==1"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
|
||||||
{"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true},
|
{Name: "roomMemberCountEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==2"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
|
||||||
{"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false},
|
{Name: "roomMemberCountGreaterEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
|
||||||
{"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true},
|
{Name: "roomMemberCountGreaterEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
|
||||||
{"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false},
|
{Name: "roomMemberCountGreaterNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">2"}, EventJSON: `{}`, WantMatch: false, WantErr: false},
|
||||||
{"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true},
|
{Name: "roomMemberCountGreaterMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">1"}, EventJSON: `{}`, WantMatch: true, WantErr: false},
|
||||||
|
|
||||||
{"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true},
|
{Name: "senderNotificationPermissionMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@poweruser:example.com"}`, WantMatch: true, WantErr: false},
|
||||||
{"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false},
|
{Name: "senderNotificationPermissionNoMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@nobody:example.com"}`, WantMatch: false, WantErr: false},
|
||||||
}
|
}
|
||||||
for _, tst := range tsts {
|
for _, tst := range tsts {
|
||||||
t.Run(tst.Name, func(t *testing.T) {
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{2})
|
got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{2})
|
||||||
if err != nil {
|
if err != nil && !tst.WantErr {
|
||||||
t.Fatalf("conditionMatches failed: %v", err)
|
t.Fatalf("conditionMatches failed: %v", err)
|
||||||
}
|
}
|
||||||
if got != tst.Want {
|
if got != tst.WantMatch {
|
||||||
t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.Want, tst.Name)
|
t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.WantMatch, tst.Name)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,18 +36,18 @@ type Rule struct {
|
||||||
// around. Required.
|
// around. Required.
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|
||||||
|
// Conditions provide the rule's conditions for OverrideKind and
|
||||||
|
// UnderrideKind. Not allowed for other kinds.
|
||||||
|
Conditions []*Condition `json:"conditions,omitempty"`
|
||||||
|
|
||||||
// Actions describe the desired outcome, should the rule
|
// Actions describe the desired outcome, should the rule
|
||||||
// match. Required.
|
// match. Required.
|
||||||
Actions []*Action `json:"actions"`
|
Actions []*Action `json:"actions"`
|
||||||
|
|
||||||
// Conditions provide the rule's conditions for OverrideKind and
|
|
||||||
// UnderrideKind. Not allowed for other kinds.
|
|
||||||
Conditions []*Condition `json:"conditions"`
|
|
||||||
|
|
||||||
// Pattern is the body pattern to match for ContentKind. Required
|
// Pattern is the body pattern to match for ContentKind. Required
|
||||||
// for that kind. The interpretation is the same as that of
|
// for that kind. The interpretation is the same as that of
|
||||||
// Condition.Pattern.
|
// Condition.Pattern.
|
||||||
Pattern string `json:"pattern"`
|
Pattern *string `json:"pattern,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scope only has one valid value. See also AccountRuleSets.
|
// Scope only has one valid value. See also AccountRuleSets.
|
||||||
|
|
|
||||||
|
|
@ -128,3 +128,7 @@ func parseRoomMemberCountCondition(s string) (func(int) bool, error) {
|
||||||
b = int(v)
|
b = int(v)
|
||||||
return cmp, nil
|
return cmp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func pointer[t any](s t) *t {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,10 @@ func ValidateRule(kind Kind, rule *Rule) []error {
|
||||||
}
|
}
|
||||||
|
|
||||||
case ContentKind:
|
case ContentKind:
|
||||||
if rule.Pattern == "" {
|
if rule.Pattern == nil {
|
||||||
|
errs = append(errs, fmt.Errorf("missing content rule pattern"))
|
||||||
|
}
|
||||||
|
if rule.Pattern != nil && *rule.Pattern == "" {
|
||||||
errs = append(errs, fmt.Errorf("missing content rule pattern"))
|
errs = append(errs, fmt.Errorf("missing content rule pattern"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,15 +12,16 @@ func TestValidateRuleNegatives(t *testing.T) {
|
||||||
Rule Rule
|
Rule Rule
|
||||||
WantErrString string
|
WantErrString string
|
||||||
}{
|
}{
|
||||||
{"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"},
|
{Name: "emptyRuleID", Kind: OverrideKind, Rule: Rule{}, WantErrString: "invalid rule ID"},
|
||||||
{"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"},
|
{Name: "invalidKind", Kind: Kind("something else"), Rule: Rule{}, WantErrString: "invalid rule kind"},
|
||||||
{"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"},
|
{Name: "ruleIDBackslash", Kind: OverrideKind, Rule: Rule{RuleID: "#foo\\:example.com"}, WantErrString: "invalid rule ID"},
|
||||||
{"noActions", OverrideKind, Rule{}, "missing actions"},
|
{Name: "noActions", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing actions"},
|
||||||
{"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"},
|
{Name: "invalidAction", Kind: OverrideKind, Rule: Rule{Actions: []*Action{{}}}, WantErrString: "invalid rule action kind"},
|
||||||
{"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"},
|
{Name: "invalidCondition", Kind: OverrideKind, Rule: Rule{Conditions: []*Condition{{}}}, WantErrString: "invalid rule condition kind"},
|
||||||
{"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"},
|
{Name: "overrideNoCondition", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"},
|
||||||
{"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"},
|
{Name: "underrideNoCondition", Kind: UnderrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"},
|
||||||
{"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"},
|
{Name: "contentNoPattern", Kind: ContentKind, Rule: Rule{}, WantErrString: "missing content rule pattern"},
|
||||||
|
{Name: "contentEmptyPattern", Kind: ContentKind, Rule: Rule{Pattern: pointer("")}, WantErrString: "missing content rule pattern"},
|
||||||
}
|
}
|
||||||
for _, tst := range tsts {
|
for _, tst := range tsts {
|
||||||
t.Run(tst.Name, func(t *testing.T) {
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -15,30 +15,96 @@
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
const (
|
||||||
|
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
|
||||||
|
|
||||||
const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based
|
||||||
|
maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||||
|
)
|
||||||
|
|
||||||
// ValidatePassword returns an error response if the password is invalid
|
var (
|
||||||
func ValidatePassword(password string) *util.JSONResponse {
|
ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength)
|
||||||
|
ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength)
|
||||||
|
ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength)
|
||||||
|
ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='")
|
||||||
|
ErrUsernameUnderscore = errors.New("username cannot start with a '_'")
|
||||||
|
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidatePassword returns an error if the password is invalid
|
||||||
|
func ValidatePassword(password string) error {
|
||||||
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||||
if len(password) > maxPasswordLength {
|
if len(password) > maxPasswordLength {
|
||||||
return &util.JSONResponse{
|
return ErrPasswordTooLong
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)),
|
|
||||||
}
|
|
||||||
} else if len(password) > 0 && len(password) < minPasswordLength {
|
} else if len(password) > 0 && len(password) < minPasswordLength {
|
||||||
|
return ErrPasswordWeak
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasswordResponse returns a util.JSONResponse for a given error, if any.
|
||||||
|
func PasswordResponse(err error) *util.JSONResponse {
|
||||||
|
switch err {
|
||||||
|
case ErrPasswordWeak:
|
||||||
return &util.JSONResponse{
|
return &util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)),
|
JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()),
|
||||||
|
}
|
||||||
|
case ErrPasswordTooLong:
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateUsername returns an error if the username is invalid
|
||||||
|
func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error {
|
||||||
|
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
|
||||||
|
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||||
|
return ErrUsernameTooLong
|
||||||
|
} else if !validUsernameRegex.MatchString(localpart) {
|
||||||
|
return ErrUsernameInvalid
|
||||||
|
} else if localpart[0] == '_' { // Regex checks its not a zero length string
|
||||||
|
return ErrUsernameUnderscore
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsernameResponse returns a util.JSONResponse for the given error, if any.
|
||||||
|
func UsernameResponse(err error) *util.JSONResponse {
|
||||||
|
switch err {
|
||||||
|
case ErrUsernameTooLong:
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(err.Error()),
|
||||||
|
}
|
||||||
|
case ErrUsernameInvalid, ErrUsernameUnderscore:
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(err.Error()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service
|
||||||
|
func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error {
|
||||||
|
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
|
||||||
|
return ErrUsernameTooLong
|
||||||
|
} else if !validUsernameRegex.MatchString(localpart) {
|
||||||
|
return ErrUsernameInvalid
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
170
internal/validate_test.go
Normal file
170
internal/validate_test.go
Normal file
|
|
@ -0,0 +1,170 @@
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_validatePassword(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
password string
|
||||||
|
wantError error
|
||||||
|
wantJSON *util.JSONResponse
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "password too short",
|
||||||
|
password: "shortpw",
|
||||||
|
wantError: ErrPasswordWeak,
|
||||||
|
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "password too long",
|
||||||
|
password: strings.Repeat("a", maxPasswordLength+1),
|
||||||
|
wantError: ErrPasswordTooLong,
|
||||||
|
wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "password OK",
|
||||||
|
password: util.RandomString(10),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotErr := ValidatePassword(tt.password)
|
||||||
|
if !reflect.DeepEqual(gotErr, tt.wantError) {
|
||||||
|
t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) {
|
||||||
|
t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_validateUsername(t *testing.T) {
|
||||||
|
tooLongUsername := strings.Repeat("a", maxUsernameLength)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
localpart string
|
||||||
|
domain gomatrixserverlib.ServerName
|
||||||
|
wantErr error
|
||||||
|
wantJSON *util.JSONResponse
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty username",
|
||||||
|
localpart: "",
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameInvalid,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid username",
|
||||||
|
localpart: "INVALIDUSERNAME",
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameInvalid,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username too long",
|
||||||
|
localpart: tooLongUsername,
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameTooLong,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "localpart starting with an underscore",
|
||||||
|
localpart: "_notvalid",
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameUnderscore,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid username",
|
||||||
|
localpart: "valid",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex username",
|
||||||
|
localpart: "f00_bar-baz.=40/",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rejects emoji username 💥",
|
||||||
|
localpart: "💥",
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameInvalid,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters are allowed",
|
||||||
|
localpart: "/dev/null",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "special characters are allowed 2",
|
||||||
|
localpart: "i_am_allowed=1",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not all special characters are allowed",
|
||||||
|
localpart: "notallowed#", // contains #
|
||||||
|
domain: "localhost",
|
||||||
|
wantErr: ErrUsernameInvalid,
|
||||||
|
wantJSON: &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username containing numbers",
|
||||||
|
localpart: "hello1337",
|
||||||
|
domain: "localhost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotErr := ValidateUsername(tt.localpart, tt.domain)
|
||||||
|
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||||
|
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
|
||||||
|
}
|
||||||
|
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
|
||||||
|
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Application services are allowed usernames starting with an underscore
|
||||||
|
if tt.wantErr == ErrUsernameUnderscore {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain)
|
||||||
|
if !reflect.DeepEqual(gotErr, tt.wantErr) {
|
||||||
|
t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr)
|
||||||
|
}
|
||||||
|
if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) {
|
||||||
|
t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -78,6 +78,7 @@ const (
|
||||||
type PerformJoinRequest struct {
|
type PerformJoinRequest struct {
|
||||||
RoomIDOrAlias string `json:"room_id_or_alias"`
|
RoomIDOrAlias string `json:"room_id_or_alias"`
|
||||||
UserID string `json:"user_id"`
|
UserID string `json:"user_id"`
|
||||||
|
IsGuest bool `json:"is_guest"`
|
||||||
Content map[string]interface{} `json:"content"`
|
Content map[string]interface{} `json:"content"`
|
||||||
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
||||||
Unsigned map[string]interface{} `json:"unsigned"`
|
Unsigned map[string]interface{} `json:"unsigned"`
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,10 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
asAPI "github.com/matrix-org/dendrite/appservice/api"
|
asAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
|
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
|
@ -19,9 +23,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/nats-io/nats.go"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
|
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
|
||||||
|
|
@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
r.fsAPI = fsAPI
|
r.fsAPI = fsAPI
|
||||||
r.KeyRing = keyRing
|
r.KeyRing = keyRing
|
||||||
|
|
||||||
|
identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
r.Inputer = &input.Inputer{
|
r.Inputer = &input.Inputer{
|
||||||
Cfg: &r.Base.Cfg.RoomServer,
|
Cfg: &r.Base.Cfg.RoomServer,
|
||||||
Base: r.Base,
|
Base: r.Base,
|
||||||
|
|
@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
JetStream: r.JetStream,
|
JetStream: r.JetStream,
|
||||||
NATSClient: r.NATSClient,
|
NATSClient: r.NATSClient,
|
||||||
Durable: nats.Durable(r.Durable),
|
Durable: nats.Durable(r.Durable),
|
||||||
ServerName: r.Cfg.Matrix.ServerName,
|
ServerName: r.ServerName,
|
||||||
|
SigningIdentity: identity,
|
||||||
FSAPI: fsAPI,
|
FSAPI: fsAPI,
|
||||||
KeyRing: keyRing,
|
KeyRing: keyRing,
|
||||||
ACLs: r.ServerACLs,
|
ACLs: r.ServerACLs,
|
||||||
|
|
@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
Queryer: r.Queryer,
|
Queryer: r.Queryer,
|
||||||
}
|
}
|
||||||
r.Peeker = &perform.Peeker{
|
r.Peeker = &perform.Peeker{
|
||||||
ServerName: r.Cfg.Matrix.ServerName,
|
ServerName: r.ServerName,
|
||||||
Cfg: r.Cfg,
|
Cfg: r.Cfg,
|
||||||
DB: r.DB,
|
DB: r.DB,
|
||||||
FSAPI: r.fsAPI,
|
FSAPI: r.fsAPI,
|
||||||
|
|
@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
Inputer: r.Inputer,
|
Inputer: r.Inputer,
|
||||||
}
|
}
|
||||||
r.Unpeeker = &perform.Unpeeker{
|
r.Unpeeker = &perform.Unpeeker{
|
||||||
ServerName: r.Cfg.Matrix.ServerName,
|
ServerName: r.ServerName,
|
||||||
Cfg: r.Cfg,
|
Cfg: r.Cfg,
|
||||||
DB: r.DB,
|
DB: r.DB,
|
||||||
FSAPI: r.fsAPI,
|
FSAPI: r.fsAPI,
|
||||||
|
|
@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
||||||
r.Leaver.UserAPI = userAPI
|
r.Leaver.UserAPI = userAPI
|
||||||
|
r.Inputer.UserAPI = userAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
"github.com/Arceliar/phony"
|
"github.com/Arceliar/phony"
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
@ -79,6 +81,7 @@ type Inputer struct {
|
||||||
JetStream nats.JetStreamContext
|
JetStream nats.JetStreamContext
|
||||||
Durable nats.SubOpt
|
Durable nats.SubOpt
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
SigningIdentity *gomatrixserverlib.SigningIdentity
|
||||||
FSAPI fedapi.RoomserverFederationAPI
|
FSAPI fedapi.RoomserverFederationAPI
|
||||||
KeyRing gomatrixserverlib.JSONVerifier
|
KeyRing gomatrixserverlib.JSONVerifier
|
||||||
ACLs *acls.ServerACLs
|
ACLs *acls.ServerACLs
|
||||||
|
|
@ -87,6 +90,7 @@ type Inputer struct {
|
||||||
workers sync.Map // room ID -> *worker
|
workers sync.Map // room ID -> *worker
|
||||||
|
|
||||||
Queryer *query.Queryer
|
Queryer *query.Queryer
|
||||||
|
UserAPI userapi.RoomserverUserAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
// If a room consumer is inactive for a while then we will allow NATS
|
// If a room consumer is inactive for a while then we will allow NATS
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ package input
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -31,6 +32,8 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
userAPI "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
fedapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
|
@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If guest_access changed and is not can_join, kick all guest users.
|
||||||
|
if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
|
||||||
|
if err = r.kickGuests(ctx, event, roomInfo); err != nil {
|
||||||
|
logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Everything was OK — the latest events updater didn't error and
|
// Everything was OK — the latest events updater didn't error and
|
||||||
// we've sent output events. Finally, generate a hook call.
|
// we've sent output events. Finally, generate a hook call.
|
||||||
hooks.Run(hooks.KindNewEventPersisted, headered)
|
hooks.Run(hooks.KindNewEventPersisted, headered)
|
||||||
|
|
@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState(
|
||||||
succeeded = true
|
succeeded = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited.
|
||||||
|
func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error {
|
||||||
|
membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
memberEvents, err := r.DB.Events(ctx, membershipNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
|
||||||
|
latestReq := &api.QueryLatestEventsAndStateRequest{
|
||||||
|
RoomID: event.RoomID(),
|
||||||
|
}
|
||||||
|
latestRes := &api.QueryLatestEventsAndStateResponse{}
|
||||||
|
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
prevEvents := latestRes.LatestEvents
|
||||||
|
for _, memberEvent := range memberEvents {
|
||||||
|
if memberEvent.StateKey() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
|
||||||
|
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: senderDomain,
|
||||||
|
}, accountRes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if accountRes.Account == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountRes.Account.AccountType != userAPI.AccountTypeGuest {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var memberContent gomatrixserverlib.MemberContent
|
||||||
|
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
memberContent.Membership = gomatrixserverlib.Leave
|
||||||
|
|
||||||
|
stateKey := *memberEvent.StateKey()
|
||||||
|
fledglingEvent := &gomatrixserverlib.EventBuilder{
|
||||||
|
RoomID: event.RoomID(),
|
||||||
|
Type: gomatrixserverlib.MRoomMember,
|
||||||
|
StateKey: &stateKey,
|
||||||
|
Sender: stateKey,
|
||||||
|
PrevEvents: prevEvents,
|
||||||
|
}
|
||||||
|
|
||||||
|
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputEvents = append(inputEvents, api.InputRoomEvent{
|
||||||
|
Kind: api.KindNew,
|
||||||
|
Event: event,
|
||||||
|
Origin: senderDomain,
|
||||||
|
SendAsServer: string(senderDomain),
|
||||||
|
})
|
||||||
|
prevEvents = []gomatrixserverlib.EventReference{
|
||||||
|
event.EventReference(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputReq := &api.InputRoomEventsRequest{
|
||||||
|
InputRoomEvents: inputEvents,
|
||||||
|
Asynchronous: true, // Needs to be async, as we otherwise create a deadlock
|
||||||
|
}
|
||||||
|
inputRes := &api.InputRoomEventsResponse{}
|
||||||
|
return r.InputRoomEvents(ctx, inputReq, inputRes)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -122,11 +122,14 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
||||||
ctx, req.VirtualHost, requester,
|
ctx, req.VirtualHost, requester,
|
||||||
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100,
|
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100,
|
||||||
)
|
)
|
||||||
if err != nil {
|
// Only return an error if we really couldn't get any events.
|
||||||
|
if err != nil && len(events) == 0 {
|
||||||
logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed")
|
logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
|
// If we got an error but still got events, that's fine, because a server might have returned a 404 (or something)
|
||||||
|
// but other servers could provide the missing event.
|
||||||
|
logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
|
||||||
|
|
||||||
// persist these new events - auth checks have already been done
|
// persist these new events - auth checks have already been done
|
||||||
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
|
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
|
||||||
|
|
@ -319,6 +322,7 @@ FederationHit:
|
||||||
FedClient: b.fsAPI,
|
FedClient: b.fsAPI,
|
||||||
RememberAuthEvents: false,
|
RememberAuthEvents: false,
|
||||||
Server: srv,
|
Server: srv,
|
||||||
|
Origin: b.virtualHost,
|
||||||
}
|
}
|
||||||
res, err := c.StateIDsBeforeEvent(ctx, targetEvent)
|
res, err := c.StateIDsBeforeEvent(ctx, targetEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -394,6 +398,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
|
||||||
FedClient: b.fsAPI,
|
FedClient: b.fsAPI,
|
||||||
RememberAuthEvents: false,
|
RememberAuthEvents: false,
|
||||||
Server: srv,
|
Server: srv,
|
||||||
|
Origin: b.virtualHost,
|
||||||
}
|
}
|
||||||
result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs)
|
result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ package perform
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If a guest is trying to join a room, check that the room has a m.room.guest_access event
|
||||||
|
if req.IsGuest {
|
||||||
|
var guestAccessEvent *gomatrixserverlib.HeaderedEvent
|
||||||
|
guestAccess := "forbidden"
|
||||||
|
guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "")
|
||||||
|
if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil {
|
||||||
|
logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'")
|
||||||
|
}
|
||||||
|
if guestAccessEvent != nil {
|
||||||
|
guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Servers MUST only allow guest users to join rooms if the m.room.guest_access state event
|
||||||
|
// is present on the room and has the guest_access value can_join.
|
||||||
|
if guestAccess != "can_join" {
|
||||||
|
return "", "", &rsAPI.PerformError{
|
||||||
|
Code: rsAPI.PerformErrorNotAllowed,
|
||||||
|
Msg: "Guest access is forbidden",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If we should do a forced federated join then do that.
|
// If we should do a forced federated join then do that.
|
||||||
var joinedVia gomatrixserverlib.ServerName
|
var joinedVia gomatrixserverlib.ServerName
|
||||||
if forceFederatedJoin {
|
if forceFederatedJoin {
|
||||||
|
|
|
||||||
|
|
@ -3,18 +3,23 @@ package roomserver_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
|
"github.com/matrix-org/dendrite/userapi"
|
||||||
|
|
||||||
|
userAPI "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver"
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/test/testrig"
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
)
|
)
|
||||||
|
|
@ -29,7 +34,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s
|
||||||
return base, db, close
|
return base, db, close
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_SharedUsers(t *testing.T) {
|
func TestUsers(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer close()
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
// SetFederationAPI starts the room event input consumer
|
||||||
|
rsAPI.SetFederationAPI(nil, nil)
|
||||||
|
|
||||||
|
t.Run("shared users", func(t *testing.T) {
|
||||||
|
testSharedUsers(t, rsAPI)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("kick users", func(t *testing.T) {
|
||||||
|
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
|
||||||
|
rsAPI.SetUserAPI(usrAPI)
|
||||||
|
testKickUsers(t, rsAPI, usrAPI)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
bob := test.NewUser(t)
|
bob := test.NewUser(t)
|
||||||
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
||||||
|
|
@ -43,36 +69,93 @@ func Test_SharedUsers(t *testing.T) {
|
||||||
}, test.WithStateKey(bob.ID))
|
}, test.WithStateKey(bob.ID))
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
||||||
base, _, close := mustCreateDatabase(t, dbType)
|
|
||||||
defer close()
|
|
||||||
|
|
||||||
rsAPI := roomserver.NewInternalAPI(base)
|
// Create the room
|
||||||
// SetFederationAPI starts the room event input consumer
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||||
rsAPI.SetFederationAPI(nil, nil)
|
t.Errorf("failed to send events: %v", err)
|
||||||
// Create the room
|
}
|
||||||
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
|
||||||
t.Fatalf("failed to send events: %v", err)
|
// Query the shared users for Alice, there should only be Bob.
|
||||||
|
// This is used by the SyncAPI keychange consumer.
|
||||||
|
res := &api.QuerySharedUsersResponse{}
|
||||||
|
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
|
||||||
|
t.Errorf("unable to query known users: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||||
|
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||||
|
}
|
||||||
|
// Also verify that we get the expected result when specifying OtherUserIDs.
|
||||||
|
// This is used by the SyncAPI when getting device list changes.
|
||||||
|
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
|
||||||
|
t.Errorf("unable to query known users: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
||||||
|
t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) {
|
||||||
|
// Create users and room; Bob is going to be the guest and kicked on revocation of guest access
|
||||||
|
alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser))
|
||||||
|
bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest))
|
||||||
|
|
||||||
|
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true))
|
||||||
|
|
||||||
|
// Join with the guest user
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create the users in the userapi, so the RSAPI can query the account type later
|
||||||
|
for _, u := range []*test.User{alice, bob} {
|
||||||
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
||||||
|
userRes := &userAPI.PerformAccountCreationResponse{}
|
||||||
|
if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{
|
||||||
|
AccountType: u.AccountType,
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
|
Password: "someRandomPassword",
|
||||||
|
}, userRes); err != nil {
|
||||||
|
t.Errorf("failed to create account: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the room in the database
|
||||||
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||||
|
t.Errorf("failed to send events: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the membership events BEFORE revoking guest access
|
||||||
|
membershipRes := &api.QueryMembershipsForRoomResponse{}
|
||||||
|
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil {
|
||||||
|
t.Errorf("failed to query membership for room: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// revoke guest access
|
||||||
|
revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey(""))
|
||||||
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil {
|
||||||
|
t.Errorf("failed to send events: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need
|
||||||
|
// to loop and wait for the events to be processed by the roomserver.
|
||||||
|
for i := 0; i <= 20; i++ {
|
||||||
|
// Get the membership events AFTER revoking guest access
|
||||||
|
membershipRes2 := &api.QueryMembershipsForRoomResponse{}
|
||||||
|
if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil {
|
||||||
|
t.Errorf("failed to query membership for room: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query the shared users for Alice, there should only be Bob.
|
// The membership events should NOT match, as Bob (guest user) should now be kicked from the room
|
||||||
// This is used by the SyncAPI keychange consumer.
|
if !reflect.DeepEqual(membershipRes, membershipRes2) {
|
||||||
res := &api.QuerySharedUsersResponse{}
|
return
|
||||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil {
|
|
||||||
t.Fatalf("unable to query known users: %v", err)
|
|
||||||
}
|
}
|
||||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
time.Sleep(time.Millisecond * 10)
|
||||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
}
|
||||||
}
|
|
||||||
// Also verify that we get the expected result when specifying OtherUserIDs.
|
t.Errorf("memberships didn't change in time")
|
||||||
// This is used by the SyncAPI when getting device list changes.
|
|
||||||
if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil {
|
|
||||||
t.Fatalf("unable to query known users: %v", err)
|
|
||||||
}
|
|
||||||
if _, ok := res.UserIDsToCount[bob.ID]; !ok {
|
|
||||||
t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_QueryLeftUsers(t *testing.T) {
|
func Test_QueryLeftUsers(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
yaml "gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
|
||||||
jaegerconfig "github.com/uber/jaeger-client-go/config"
|
jaegerconfig "github.com/uber/jaeger-client-go/config"
|
||||||
jaegermetrics "github.com/uber/jaeger-lib/metrics"
|
jaegermetrics "github.com/uber/jaeger-lib/metrics"
|
||||||
|
|
@ -314,11 +314,13 @@ func (config *Dendrite) Derive() error {
|
||||||
|
|
||||||
if config.ClientAPI.RecaptchaEnabled {
|
if config.ClientAPI.RecaptchaEnabled {
|
||||||
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
|
config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey}
|
||||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
config.Derived.Registration.Flows = []authtypes.Flow{
|
||||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}})
|
{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}},
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
config.Derived.Registration.Flows = append(config.Derived.Registration.Flows,
|
config.Derived.Registration.Flows = []authtypes.Flow{
|
||||||
authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}})
|
{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load application service configuration files
|
// Load application service configuration files
|
||||||
|
|
|
||||||
|
|
@ -78,9 +78,6 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
c.TURN.Verify(configErrs)
|
c.TURN.Verify(configErrs)
|
||||||
c.RateLimiting.Verify(configErrs)
|
c.RateLimiting.Verify(configErrs)
|
||||||
if c.RecaptchaEnabled {
|
if c.RecaptchaEnabled {
|
||||||
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
|
|
||||||
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
|
|
||||||
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
|
|
||||||
if c.RecaptchaSiteVerifyAPI == "" {
|
if c.RecaptchaSiteVerifyAPI == "" {
|
||||||
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
|
c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify"
|
||||||
}
|
}
|
||||||
|
|
@ -93,6 +90,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
if c.RecaptchaSitekeyClass == "" {
|
if c.RecaptchaSitekeyClass == "" {
|
||||||
c.RecaptchaSitekeyClass = "g-recaptcha-response"
|
c.RecaptchaSitekeyClass = "g-recaptcha-response"
|
||||||
}
|
}
|
||||||
|
checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey)
|
||||||
|
checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey)
|
||||||
|
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI)
|
||||||
|
checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass)
|
||||||
}
|
}
|
||||||
// Ensure there is any spam counter measure when enabling registration
|
// Ensure there is any spam counter measure when enabling registration
|
||||||
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled {
|
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled {
|
||||||
|
|
|
||||||
|
|
@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("no signing identity %q", serverName)
|
return nil, fmt.Errorf("no signing identity for %q", serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
|
func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity {
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,10 @@ package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_SigningIdentityFor(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
virtualHosts []*VirtualHost
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
want *gomatrixserverlib.SigningIdentity
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no virtual hosts defined",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no identity found",
|
||||||
|
serverName: gomatrixserverlib.ServerName("doesnotexist"),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "found identity",
|
||||||
|
serverName: gomatrixserverlib.ServerName("main"),
|
||||||
|
want: &gomatrixserverlib.SigningIdentity{ServerName: "main"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "identity found on virtual hosts",
|
||||||
|
serverName: gomatrixserverlib.ServerName("vh2"),
|
||||||
|
virtualHosts: []*VirtualHost{
|
||||||
|
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}},
|
||||||
|
{SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}},
|
||||||
|
},
|
||||||
|
want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Global{
|
||||||
|
VirtualHosts: tt.virtualHosts,
|
||||||
|
SigningIdentity: gomatrixserverlib.SigningIdentity{
|
||||||
|
ServerName: "main",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
got, err := c.SigningIdentityFor(tt.serverName)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -49,3 +49,6 @@ Leaves are present in non-gapped incremental syncs
|
||||||
|
|
||||||
# Below test was passing for the wrong reason, failing correctly since #2858
|
# Below test was passing for the wrong reason, failing correctly since #2858
|
||||||
New federated private chats get full presence information (SYN-115)
|
New federated private chats get full presence information (SYN-115)
|
||||||
|
|
||||||
|
# We don't have any state to calculate m.room.guest_access when accepting invites
|
||||||
|
Guest users can accept invites to private rooms over federation
|
||||||
|
|
@ -764,3 +764,6 @@ local user has tags copied to the new room
|
||||||
remote user has tags copied to the new room
|
remote user has tags copied to the new room
|
||||||
/upgrade moves remote aliases to the new room
|
/upgrade moves remote aliases to the new room
|
||||||
Local and remote users' homeservers remove a room from their public directory on upgrade
|
Local and remote users' homeservers remove a room from their public directory on upgrade
|
||||||
|
Guest users denied access over federation if guest access prohibited
|
||||||
|
Guest users are kicked from guest_access rooms on revocation of guest_access
|
||||||
|
Guest users are kicked from guest_access rooms on revocation of guest_access over federation
|
||||||
22
test/room.go
22
test/room.go
|
|
@ -38,11 +38,12 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Room struct {
|
type Room struct {
|
||||||
ID string
|
ID string
|
||||||
Version gomatrixserverlib.RoomVersion
|
Version gomatrixserverlib.RoomVersion
|
||||||
preset Preset
|
preset Preset
|
||||||
visibility gomatrixserverlib.HistoryVisibility
|
guestCanJoin bool
|
||||||
creator *User
|
visibility gomatrixserverlib.HistoryVisibility
|
||||||
|
creator *User
|
||||||
|
|
||||||
authEvents gomatrixserverlib.AuthEvents
|
authEvents gomatrixserverlib.AuthEvents
|
||||||
currentState map[string]*gomatrixserverlib.HeaderedEvent
|
currentState map[string]*gomatrixserverlib.HeaderedEvent
|
||||||
|
|
@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) {
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
|
||||||
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
|
||||||
|
if r.guestCanJoin {
|
||||||
|
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{
|
||||||
|
"guest_access": "can_join",
|
||||||
|
}, WithStateKey(""))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
|
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
|
||||||
|
|
@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
|
||||||
r.Version = ver
|
r.Version = ver
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GuestsCanJoin(canJoin bool) roomModifier {
|
||||||
|
return func(t *testing.T, r *Room) {
|
||||||
|
r.guestCanJoin = canJoin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ type KeyserverUserAPI interface {
|
||||||
|
|
||||||
type RoomserverUserAPI interface {
|
type RoomserverUserAPI interface {
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
|
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// api functions required by the media api
|
// api functions required by the media api
|
||||||
|
|
@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct {
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
Medium string
|
Medium string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryAccountByLocalpartRequest struct {
|
||||||
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryAccountByLocalpartResponse struct {
|
||||||
|
Account *Account
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error {
|
||||||
|
err := t.Impl.QueryAccountByLocalpart(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func js(thing interface{}) string {
|
func js(thing interface{}) string {
|
||||||
b, err := json.Marshal(thing)
|
b, err := json.Marshal(thing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -81,11 +81,6 @@ func Test_evaluatePushRules(t *testing.T) {
|
||||||
wantAction: pushrules.NotifyAction,
|
wantAction: pushrules.NotifyAction,
|
||||||
wantActions: []*pushrules.Action{
|
wantActions: []*pushrules.Action{
|
||||||
{Kind: pushrules.NotifyAction},
|
{Kind: pushrules.NotifyAction},
|
||||||
{
|
|
||||||
Kind: pushrules.SetTweakAction,
|
|
||||||
Tweak: pushrules.HighlightTweak,
|
|
||||||
Value: false,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -103,7 +98,6 @@ func Test_evaluatePushRules(t *testing.T) {
|
||||||
{
|
{
|
||||||
Kind: pushrules.SetTweakAction,
|
Kind: pushrules.SetTweakAction,
|
||||||
Tweak: pushrules.HighlightTweak,
|
Tweak: pushrules.HighlightTweak,
|
||||||
Value: true,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) {
|
||||||
|
res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
|
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
|
||||||
// creating a 'device'.
|
// creating a 'device'.
|
||||||
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
|
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,7 @@ const (
|
||||||
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
|
||||||
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
|
||||||
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
|
||||||
|
QueryAccountByLocalpartPath = "/userapi/queryAccountType"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
|
@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
|
||||||
h.httpClient, ctx, request, response,
|
h.httpClient, ctx, request, response,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) QueryAccountByLocalpart(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.QueryAccountByLocalpartRequest,
|
||||||
|
res *api.QueryAccountByLocalpartResponse,
|
||||||
|
) error {
|
||||||
|
return httputil.CallInternalRPCAPI(
|
||||||
|
"QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath,
|
||||||
|
h.httpClient, ctx, req, res,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics
|
||||||
PerformSaveThreePIDAssociationPath,
|
PerformSaveThreePIDAssociationPath,
|
||||||
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
|
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
QueryAccountByLocalpartPath,
|
||||||
|
httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryAccountByLocalpart(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
|
||||||
|
localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := func(t *testing.T, internalAPI api.UserInternalAPI) {
|
||||||
|
// Query existing account
|
||||||
|
queryAccResp := &api.QueryAccountByLocalpartResponse{}
|
||||||
|
if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
|
||||||
|
Localpart: localpart,
|
||||||
|
ServerName: userServername,
|
||||||
|
}, queryAccResp); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(createdAcc, queryAccResp.Account) {
|
||||||
|
t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query non-existent account, this should result in an error
|
||||||
|
err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
|
||||||
|
Localpart: "doesnotexist",
|
||||||
|
ServerName: userServername,
|
||||||
|
}, queryAccResp)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected an error, but got none: %+v", queryAccResp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Monolith", func(t *testing.T) {
|
||||||
|
testCases(t, intAPI)
|
||||||
|
// also test tracing
|
||||||
|
testCases(t, &api.UserInternalAPITrace{Impl: intAPI})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("HTTP API", func(t *testing.T) {
|
||||||
|
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
||||||
|
userapi.AddInternalRoutes(router, intAPI, false)
|
||||||
|
apiURL, cancel := test.ListenAndServe(t, router, false)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create HTTP client: %s", err)
|
||||||
|
}
|
||||||
|
testCases(t, userHTTPApi)
|
||||||
|
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue