diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 8448d8e23..8d3a8d674 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -32,10 +32,6 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV - BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) - [ ${BRANCH} == "main" ] && BRANCH="" - echo "BRANCH=${BRANCH}" >> $GITHUB_ENV - name: Set up QEMU uses: docker/setup-qemu-action@v1 - name: Set up Docker Buildx @@ -60,7 +56,6 @@ jobs: cache-from: type=registry,ref=ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:buildcache cache-to: type=registry,ref=ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:buildcache,mode=max context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} platforms: ${{ env.PLATFORMS }} push: true tags: | @@ -75,7 +70,6 @@ jobs: cache-from: type=gha cache-to: type=gha,mode=max context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} platforms: ${{ env.PLATFORMS }} push: true tags: | @@ -109,10 +103,6 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV - BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) - [ ${BRANCH} == "main" ] && BRANCH="" - echo "BRANCH=${BRANCH}" >> $GITHUB_ENV - name: Set up QEMU uses: docker/setup-qemu-action@v1 - name: Set up Docker Buildx @@ -137,7 +127,6 @@ jobs: cache-from: type=gha cache-to: type=gha,mode=max context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} file: ./build/docker/Dockerfile.demo-pinecone platforms: ${{ env.PLATFORMS }} push: true @@ -153,7 +142,6 @@ jobs: cache-from: type=gha cache-to: type=gha,mode=max context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} file: ./build/docker/Dockerfile.demo-pinecone platforms: ${{ env.PLATFORMS }} push: true @@ -176,10 +164,6 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV - BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) - [ ${BRANCH} == "main" ] && BRANCH="" - echo "BRANCH=${BRANCH}" >> $GITHUB_ENV - name: Set up QEMU uses: docker/setup-qemu-action@v1 - name: Set up Docker Buildx @@ -204,7 +188,6 @@ jobs: cache-from: type=gha cache-to: type=gha,mode=max context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} file: ./build/docker/Dockerfile.demo-yggdrasil platforms: ${{ env.PLATFORMS }} push: true @@ -220,7 +203,6 @@ jobs: cache-from: type=gha cache-to: type=gha,mode=max context: . - build-args: FLAGS=-X github.com/matrix-org/dendrite/internal.branch=${{ env.BRANCH }} -X github.com/matrix-org/dendrite/internal.build=${{ env.BUILD }} file: ./build/docker/Dockerfile.demo-yggdrasil platforms: ${{ env.PLATFORMS }} push: true diff --git a/.golangci.yml b/.golangci.yml index bb8d38a8b..5bee0a885 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -180,7 +180,6 @@ linters-settings: linters: enable: - errcheck - - goconst - gocyclo - goimports # Does everything gofmt does - gosimple @@ -211,6 +210,7 @@ linters: - stylecheck - typecheck # Should turn back on soon - unconvert # Should turn back on soon + - goconst # Slightly annoying, as it reports "issues" in SQL statements disable-all: false presets: fast: false diff --git a/CHANGES.md b/CHANGES.md index bdb6a796e..57e3a3d4f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,23 @@ # Changelog +## Dendrite 0.13.4 (2023-10-25) + +Upgrading to this version is **highly** recommended, as it fixes a long-standing bug in the state resolution +algorithm. + +### Fixes: + +- The "device list updater" now de-duplicates the servers to fetch devices from on startup. (This also + avoids spamming the logs when shutting down.) +- A bug in the state resolution algorithm has been fixed. This bug could result in users "being reset" + out of rooms and other missing state events due to calculating the wrong state. +- A bug when setting notifications from Element Android has been fixed by implementing MSC3987 + +### Features + +- Updated dependencies + - Internal NATS Server has been updated from v2.9.19 to v2.9.23 + ## Dendrite 0.13.3 (2023-09-28) ### Fixes: diff --git a/Dockerfile b/Dockerfile index 4ee20933a..8c8f1588f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # # base installs required dependencies and runs go mod download to cache dependencies # -FROM --platform=${BUILDPLATFORM} docker.io/golang:1.20-alpine AS base +FROM --platform=${BUILDPLATFORM} docker.io/golang:1.21-alpine AS base RUN apk --update --no-cache add bash build-base curl git # @@ -13,7 +13,6 @@ FROM --platform=${BUILDPLATFORM} base AS build WORKDIR /src ARG TARGETOS ARG TARGETARCH -ARG FLAGS RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/go/pkg/mod \ @@ -21,7 +20,7 @@ RUN --mount=target=. \ GOARCH="$TARGETARCH" \ GOOS="linux" \ CGO_ENABLED=$([ "$TARGETARCH" = "$USERARCH" ] && echo "1" || echo "0") \ - go build -v -ldflags="${FLAGS}" -trimpath -o /out/ ./cmd/... + go build -v -trimpath -o /out/ ./cmd/... # diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index ddc24477b..eca63371d 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/appservice" @@ -32,6 +33,10 @@ import ( "github.com/matrix-org/dendrite/test/testrig" ) +var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) { + return &statistics.ServerStatistics{}, nil +} + func TestAppserviceInternalAPI(t *testing.T) { // Set expected results @@ -144,7 +149,7 @@ func TestAppserviceInternalAPI(t *testing.T) { cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI) runCases(t, asAPI) @@ -239,7 +244,7 @@ func TestAppserviceInternalAPI_UnixSocket_Simple(t *testing.T) { cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI) t.Run("UserIDExists", func(t *testing.T) { @@ -378,7 +383,7 @@ func TestRoomserverConsumerOneInvite(t *testing.T) { // Create required internal APIs rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // start the consumer appservice.NewInternalAPI(processCtx, cfg, natsInstance, usrAPI, rsAPI) diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go index 61baed902..6acc93c7b 100644 --- a/build/dendritejs-pinecone/main.go +++ b/build/dendritejs-pinecone/main.go @@ -38,6 +38,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib" @@ -190,13 +191,13 @@ func startup() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation) + fedSenderAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, federation, rsAPI, caches, keyRing, true) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation, caching.EnableMetrics, fedSenderAPI.IsBlacklistedOrBackingOff) asQuery := appservice.NewInternalAPI( processCtx, cfg, &natsInstance, userAPI, rsAPI, ) rsAPI.SetAppserviceAPI(asQuery) - fedSenderAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, federation, rsAPI, caches, keyRing, true) rsAPI.SetFederationAPI(fedSenderAPI, keyRing) monolith := setup.Monolith{ diff --git a/build/docker/Dockerfile.demo-yggdrasil b/build/docker/Dockerfile.demo-yggdrasil index efae5496c..b9e387666 100644 --- a/build/docker/Dockerfile.demo-yggdrasil +++ b/build/docker/Dockerfile.demo-yggdrasil @@ -1,4 +1,4 @@ -FROM docker.io/golang:1.19-alpine AS base +FROM docker.io/golang:1.21-alpine AS base # # Needs to be separate from the main Dockerfile for OpenShift, diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 720ce37eb..2b227d373 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -216,7 +216,7 @@ func (m *DendriteMonolith) Start() { processCtx, cfg, cm, &natsInstance, federation, rsAPI, caches, keyRing, true, ) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation, caching.EnableMetrics, fsAPI.IsBlacklistedOrBackingOff) asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 66667b03c..f0e5f004d 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -45,7 +45,7 @@ func TestAdminCreateToken(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, @@ -196,7 +196,7 @@ func TestAdminListRegistrationTokens(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, @@ -314,7 +314,7 @@ func TestAdminGetRegistrationToken(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, @@ -415,7 +415,7 @@ func TestAdminDeleteRegistrationToken(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, @@ -509,7 +509,7 @@ func TestAdminUpdateRegistrationToken(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) accessTokens := map[*test.User]userDevice{ aliceAdmin: {}, @@ -693,7 +693,7 @@ func TestAdminResetPassword(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) // Needed for changing the password/login - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -791,7 +791,7 @@ func TestPurgeRoom(t *testing.T) { fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) rsAPI.SetFederationAPI(fsAPI, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) // Create the room @@ -863,7 +863,7 @@ func TestAdminEvacuateRoom(t *testing.T) { fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) rsAPI.SetFederationAPI(fsAPI, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // Create the room if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", api.DoNotSendToOtherServers, nil, false); err != nil { @@ -964,7 +964,7 @@ func TestAdminEvacuateUser(t *testing.T) { fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, basepkg.CreateFederationClient(cfg, nil), rsAPI, caches, nil, true) rsAPI.SetFederationAPI(fsAPI, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // Create the room if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", api.DoNotSendToOtherServers, nil, false); err != nil { @@ -1055,7 +1055,7 @@ func TestAdminMarkAsStale(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go index 82ec9fea2..2ff4b6503 100644 --- a/clientapi/clientapi_test.go +++ b/clientapi/clientapi_test.go @@ -17,6 +17,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/routing" "github.com/matrix-org/dendrite/clientapi/threepid" + "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/pushrules" @@ -49,6 +50,10 @@ type userDevice struct { password string } +var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) { + return &statistics.ServerStatistics{}, nil +} + func TestGetPutDevices(t *testing.T) { alice := test.NewUser(t) bob := test.NewUser(t) @@ -121,7 +126,7 @@ func TestGetPutDevices(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -170,7 +175,7 @@ func TestDeleteDevice(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -275,7 +280,7 @@ func TestDeleteDevices(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -442,7 +447,7 @@ func TestSetDisplayname(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -554,7 +559,7 @@ func TestSetAvatarURL(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI) AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -632,7 +637,7 @@ func TestTyping(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) // Needed to create accounts - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -716,7 +721,7 @@ func TestMembership(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) // Needed to create accounts - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) rsAPI.SetUserAPI(userAPI) // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -955,7 +960,7 @@ func TestCapabilities(t *testing.T) { // Needed to create accounts rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -1002,7 +1007,7 @@ func TestTurnserver(t *testing.T) { // Needed to create accounts rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) //rsAPI.SetUserAPI(userAPI) // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -1100,7 +1105,7 @@ func Test3PID(t *testing.T) { // Needed to create accounts rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -1276,7 +1281,7 @@ func TestPushRules(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -1418,7 +1423,7 @@ func TestPushRules(t *testing.T) { validateFunc: func(t *testing.T, respBody *bytes.Buffer) { actions := gjson.GetBytes(respBody.Bytes(), "actions").Array() // only a basic check - assert.Equal(t, 1, len(actions)) + assert.Equal(t, 0, len(actions)) }, }, { @@ -1663,7 +1668,7 @@ func TestKeys(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) @@ -2125,7 +2130,7 @@ func TestKeyBackup(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics) diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go index 933ea8d3a..bd854efa8 100644 --- a/clientapi/routing/joinroom_test.go +++ b/clientapi/routing/joinroom_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/jetstream" @@ -21,6 +22,10 @@ import ( uapi "github.com/matrix-org/dendrite/userapi/api" ) +var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) { + return &statistics.ServerStatistics{}, nil +} + func TestJoinRoomByIDOrAlias(t *testing.T) { alice := test.NewUser(t) bob := test.NewUser(t) @@ -36,7 +41,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { natsInstance := jetstream.NATSInstance{} rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) // Create the users in the userapi diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go index a88522af2..e38ba5ba0 100644 --- a/clientapi/routing/login_test.go +++ b/clientapi/routing/login_test.go @@ -49,7 +49,7 @@ func TestLogin(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) // Needed for /login - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. Setup(routers, cfg, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, nil, caching.DisableMetrics) diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 8b8cc47bc..06683c47d 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -181,18 +181,6 @@ func SendKick( return *errRes } - pl, errRes := getPowerlevels(req, rsAPI, roomID) - if errRes != nil { - return *errRes - } - allowedToKick := pl.UserLevel(*senderID) >= pl.Kick - if !allowedToKick { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: spec.Forbidden("You don't have permission to kick this user, power level too low."), - } - } - bodyUserID, err := spec.NewUserID(body.UserID, true) if err != nil { return util.JSONResponse{ @@ -200,6 +188,19 @@ func SendKick( JSON: spec.BadJSON("body userID is invalid"), } } + + pl, errRes := getPowerlevels(req, rsAPI, roomID) + if errRes != nil { + return *errRes + } + allowedToKick := pl.UserLevel(*senderID) >= pl.Kick || bodyUserID.String() == deviceUserID.String() + if !allowedToKick { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, power level too low."), + } + } + var queryRes roomserverAPI.QueryMembershipForUserResponse err = rsAPI.QueryMembershipForUser(req.Context(), &roomserverAPI.QueryMembershipForUserRequest{ RoomID: roomID, diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 564cd588a..9959144c8 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -300,7 +300,7 @@ func updateProfile( }, e } - if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, false); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 22f0aa0a6..558418a6f 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -236,7 +236,7 @@ type authDict struct { // TODO: Lots of custom keys depending on the type } -// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#user-interactive-authentication-api +// https://spec.matrix.org/v1.7/client-server-api/#user-interactive-authentication-api type userInteractiveResponse struct { Flows []authtypes.Flow `json:"flows"` Completed []authtypes.LoginType `json:"completed"` @@ -256,7 +256,7 @@ func newUserInteractiveResponse( } } -// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register +// https://spec.matrix.org/v1.7/client-server-api/#post_matrixclientv3register type registerResponse struct { UserID string `json:"user_id"` AccessToken string `json:"access_token,omitempty"` @@ -462,7 +462,7 @@ func validateApplicationService( } // Register processes a /register request. -// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register +// https://spec.matrix.org/v1.7/client-server-api/#post_matrixclientv3register func Register( req *http.Request, userAPI userapi.ClientUserAPI, diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 0a1986cf7..7fa740e7f 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -298,25 +298,29 @@ func Test_register(t *testing.T) { guestsDisabled bool enableRecaptcha bool captchaBody string - wantResponse util.JSONResponse + // in case of an error, the expected response + wantErrorResponse util.JSONResponse + // in case of success, the expected username assigned + wantUsername string }{ { name: "disallow guests", kind: "guest", guestsDisabled: true, - wantResponse: util.JSONResponse{ + wantErrorResponse: util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden(`Guest registration is disabled on "test"`), }, }, { - name: "allow guests", - kind: "guest", + name: "allow guests", + kind: "guest", + wantUsername: "1", }, { name: "unknown login type", loginType: "im.not.known", - wantResponse: util.JSONResponse{ + wantErrorResponse: util.JSONResponse{ Code: http.StatusNotImplemented, JSON: spec.Unknown("unknown/unimplemented auth type"), }, @@ -324,25 +328,33 @@ func Test_register(t *testing.T) { { name: "disabled registration", registrationDisabled: true, - wantResponse: util.JSONResponse{ + wantErrorResponse: util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden(`Registration is disabled on "test"`), }, }, { - name: "successful registration, numeric ID", - username: "", - password: "someRandomPassword", - forceEmpty: true, + name: "successful registration, numeric ID", + username: "", + password: "someRandomPassword", + forceEmpty: true, + wantUsername: "2", }, { name: "successful registration", username: "success", }, + { + name: "successful registration, sequential numeric ID", + username: "", + password: "someRandomPassword", + forceEmpty: true, + wantUsername: "3", + }, { name: "failing registration - user already exists", username: "success", - wantResponse: util.JSONResponse{ + wantErrorResponse: util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.UserInUse("Desired user ID is already taken."), }, @@ -352,14 +364,14 @@ func Test_register(t *testing.T) { username: "LOWERCASED", // this is going to be lower-cased }, { - name: "invalid username", - username: "#totalyNotValid", - wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid), + name: "invalid username", + username: "#totalyNotValid", + wantErrorResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid), }, { name: "numeric username is forbidden", username: "1337", - wantResponse: util.JSONResponse{ + wantErrorResponse: util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.InvalidUsername("Numeric user IDs are reserved"), }, @@ -367,7 +379,7 @@ func Test_register(t *testing.T) { { name: "disabled recaptcha login", loginType: authtypes.LoginTypeRecaptcha, - wantResponse: util.JSONResponse{ + wantErrorResponse: util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Unknown(ErrCaptchaDisabled.Error()), }, @@ -376,7 +388,7 @@ func Test_register(t *testing.T) { name: "enabled recaptcha, no response defined", enableRecaptcha: true, loginType: authtypes.LoginTypeRecaptcha, - wantResponse: util.JSONResponse{ + wantErrorResponse: util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON(ErrMissingResponse.Error()), }, @@ -386,7 +398,7 @@ func Test_register(t *testing.T) { enableRecaptcha: true, loginType: authtypes.LoginTypeRecaptcha, captchaBody: `notvalid`, - wantResponse: util.JSONResponse{ + wantErrorResponse: util.JSONResponse{ Code: http.StatusUnauthorized, JSON: spec.BadJSON(ErrInvalidCaptcha.Error()), }, @@ -398,11 +410,11 @@ func Test_register(t *testing.T) { captchaBody: `success`, }, { - name: "captcha invalid from remote", - enableRecaptcha: true, - loginType: authtypes.LoginTypeRecaptcha, - captchaBody: `i should fail for other reasons`, - wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}}, + name: "captcha invalid from remote", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `i should fail for other reasons`, + wantErrorResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: spec.InternalServerError{}}, }, } @@ -416,7 +428,7 @@ func Test_register(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -486,8 +498,8 @@ func Test_register(t *testing.T) { t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, cfg.Derived.Registration.Flows) } case spec.MatrixError: - if !reflect.DeepEqual(tc.wantResponse, resp) { - t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) + if !reflect.DeepEqual(tc.wantErrorResponse, resp) { + t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantErrorResponse) } return case registerResponse: @@ -505,6 +517,13 @@ func Test_register(t *testing.T) { if r.DeviceID == "" { t.Fatalf("missing deviceID in response") } + // if an expected username is provided, assert that it is a match + if tc.wantUsername != "" { + wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", tc.wantUsername, "test")) + if wantUserID != r.UserID { + t.Fatalf("unexpected userID: %s, want %s", r.UserID, wantUserID) + } + } return default: t.Logf("Got response: %T", resp.JSON) @@ -541,44 +560,29 @@ func Test_register(t *testing.T) { resp = Register(req, userAPI, &cfg.ClientAPI) - switch resp.JSON.(type) { - case spec.InternalServerError: - if !reflect.DeepEqual(tc.wantResponse, resp) { - t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + switch rr := resp.JSON.(type) { + case spec.InternalServerError, spec.MatrixError, util.JSONResponse: + if !reflect.DeepEqual(tc.wantErrorResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantErrorResponse) } return - case spec.MatrixError: - if !reflect.DeepEqual(tc.wantResponse, resp) { - t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + case registerResponse: + // validate the response + if tc.wantUsername != "" { + // if an expected username is provided, assert that it is a match + wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", tc.wantUsername, "test")) + if wantUserID != rr.UserID { + t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID) + } } - return - case util.JSONResponse: - if !reflect.DeepEqual(tc.wantResponse, resp) { - t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + if rr.DeviceID != *reg.DeviceID { + t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID) } - return - } - - rr, ok := resp.JSON.(registerResponse) - if !ok { - t.Fatalf("expected a registerresponse, got %T", resp.JSON) - } - - // validate the response - if tc.forceEmpty { - // when not supplying a username, one will be generated. Given this _SHOULD_ be - // the second user, set the username accordingly - reg.Username = "2" - } - wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test")) - if wantUserID != rr.UserID { - t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID) - } - if rr.DeviceID != *reg.DeviceID { - t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID) - } - if rr.AccessToken == "" { - t.Fatalf("missing accessToken in response") + if rr.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + default: + t.Fatalf("expected one of internalservererror, matrixerror, jsonresponse, registerresponse, got %T", resp.JSON) } }) } @@ -596,7 +600,7 @@ func TestRegisterUserWithDisplayName(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) deviceName, deviceID := "deviceName", "deviceID" expectedDisplayName := "DisplayName" response := completeRegistration( @@ -637,7 +641,7 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) expectedDisplayName := "rabbit" jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go index 41af568a6..d9f44b5cc 100644 --- a/cmd/dendrite-demo-pinecone/monolith/monolith.go +++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go @@ -145,7 +145,7 @@ func (p *P2PMonolith) SetupDendrite( ) rsAPI.SetFederationAPI(fsAPI, keyRing) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation, enableMetrics, fsAPI.IsBlacklistedOrBackingOff) asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 3ec550113..1e5186348 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -213,14 +213,15 @@ func main() { natsInstance := jetstream.NATSInstance{} rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.EnableMetrics) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation) - - asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) - rsAPI.SetAppserviceAPI(asAPI) fsAPI := federationapi.NewInternalAPI( processCtx, cfg, cm, &natsInstance, federation, rsAPI, caches, keyRing, true, ) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federation, caching.EnableMetrics, fsAPI.IsBlacklistedOrBackingOff) + + asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) + rsAPI.SetAppserviceAPI(asAPI) + rsAPI.SetFederationAPI(fsAPI, keyRing) monolith := setup.Monolith{ diff --git a/cmd/dendrite/main.go b/cmd/dendrite/main.go index f3140b4e2..5234b7504 100644 --- a/cmd/dendrite/main.go +++ b/cmd/dendrite/main.go @@ -162,7 +162,7 @@ func main() { // dependency. Other components also need updating after their dependencies are up. rsAPI.SetFederationAPI(fsAPI, keyRing) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, federationClient, caching.EnableMetrics, fsAPI.IsBlacklistedOrBackingOff) asAPI := appservice.NewInternalAPI(processCtx, cfg, &natsInstance, userAPI, rsAPI) rsAPI.SetAppserviceAPI(asAPI) diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 5be449097..d6db72436 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -202,12 +202,25 @@ func main() { authEvents[i] = authEventEntries[i].PDU } + // Get the roomNID + roomInfo, err = roomserverDB.RoomInfo(ctx, authEvents[0].RoomID().String()) + if err != nil { + panic(err) + } + fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, + func(eventID string) bool { + isRejected, rejectedErr := roomserverDB.IsEventRejected(ctx, roomInfo.RoomNID, eventID) + if rejectedErr != nil { + return true + } + return isRejected + }, ) if err != nil { panic(err) diff --git a/dendrite-sample.yaml b/dendrite-sample.yaml index 7affc2599..e143a7398 100644 --- a/dendrite-sample.yaml +++ b/dendrite-sample.yaml @@ -325,6 +325,10 @@ user_api: auto_join_rooms: # - "#main:matrix.org" + # The number of workers to start for the DeviceListUpdater. Defaults to 8. + # This only needs updating if the "InputDeviceListUpdate" stream keeps growing indefinitely. + # worker_count: 8 + # Configuration for Opentracing. # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # how this works and how to set it up. diff --git a/docs/FAQ.md b/docs/FAQ.md index 570ba677e..82b1581ea 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -24,7 +24,7 @@ No, although a good portion of the Matrix specification has been implemented. Mo Dendrite development is currently supported by a small team of developers and due to those limited resources, the majority of the effort is focused on getting Dendrite to be specification complete. If there are major features you're requesting (e.g. new administration endpoints), we'd like to strongly encourage you to join the community in supporting -the development efforts through [contributing](../development/contributing). +the development efforts through [contributing](./development/CONTRIBUTING.md). ## Is there a migration path from Synapse to Dendrite? @@ -105,7 +105,7 @@ This can be done by performing a room upgrade. Use the command `/upgraderoom # If set to "-", storageClassName: "", which disables dynamic provisioning # If undefined (the default) or set to null, no storageClassName spec is # set, choosing the default provisioner. (gp2 on AWS, standard on # GKE, AWS & OpenStack) # storageClass: "" | | persistence.jetstream.existingClaim | string | `""` | Use an existing volume claim for jetstream | | persistence.jetstream.capacity | string | `"1Gi"` | PVC Storage Request for the jetstream volume | -| persistence.jetstream.storageClass | string | `""` | The storage class to use for volume claims. Defaults to persistence.storageClass | | persistence.media.existingClaim | string | `""` | Use an existing volume claim for media files | | persistence.media.capacity | string | `"1Gi"` | PVC Storage Request for the media volume | -| persistence.media.storageClass | string | `""` | The storage class to use for volume claims. Defaults to persistence.storageClass | | persistence.search.existingClaim | string | `""` | Use an existing volume claim for the fulltext search index | | persistence.search.capacity | string | `"1Gi"` | PVC Storage Request for the search volume | -| persistence.search.storageClass | string | `""` | The storage class to use for volume claims. Defaults to persistence.storageClass | | extraVolumes | list | `[]` | Add additional volumes to the Dendrite Pod | | extraVolumeMounts | list | `[]` | Configure additional mount points volumes in the Dendrite Pod | | strategy.type | string | `"RollingUpdate"` | Strategy to use for rolling updates (e.g. Recreate, RollingUpdate) If you are using ReadWriteOnce volumes, you should probably use Recreate | diff --git a/helm/dendrite/grafana_dashboards/dendrite-rev2.json b/helm/dendrite/grafana_dashboards/dendrite-rev2.json index 817f950b3..420d8bf1b 100644 --- a/helm/dendrite/grafana_dashboards/dendrite-rev2.json +++ b/helm/dendrite/grafana_dashboards/dendrite-rev2.json @@ -119,7 +119,7 @@ "refId": "A" } ], - "title": "Registerd Users", + "title": "Registered Users", "type": "stat" }, { diff --git a/helm/dendrite/templates/pvc.yaml b/helm/dendrite/templates/pvc.yaml index 88eff3bed..70b1ce563 100644 --- a/helm/dendrite/templates/pvc.yaml +++ b/helm/dendrite/templates/pvc.yaml @@ -12,7 +12,14 @@ spec: resources: requests: storage: {{ .Values.persistence.media.capacity }} - storageClassName: {{ default .Values.persistence.storageClass .Values.persistence.media.storageClass }} + {{ $storageClass := .Values.persistence.media.storageClass | default .Values.persistence.storageClass }} + {{- if $storageClass }} + {{- if (eq "-" $storageClass) }} + storageClassName: "" + {{- else }} + storageClassName: "{{ $storageClass }}" + {{- end }} + {{- end }} {{ end }} {{ if not .Values.persistence.jetstream.existingClaim }} --- @@ -28,7 +35,14 @@ spec: resources: requests: storage: {{ .Values.persistence.jetstream.capacity }} - storageClassName: {{ default .Values.persistence.storageClass .Values.persistence.jetstream.storageClass }} + {{ $storageClass := .Values.persistence.jetstream.storageClass | default .Values.persistence.storageClass }} + {{- if $storageClass }} + {{- if (eq "-" $storageClass) }} + storageClassName: "" + {{- else }} + storageClassName: "{{ $storageClass }}" + {{- end }} + {{- end }} {{ end }} {{ if not .Values.persistence.search.existingClaim }} --- @@ -44,5 +58,12 @@ spec: resources: requests: storage: {{ .Values.persistence.search.capacity }} - storageClassName: {{ default .Values.persistence.storageClass .Values.persistence.search.storageClass }} + {{ $storageClass := .Values.persistence.search.storageClass | default .Values.persistence.storageClass }} + {{- if $storageClass }} + {{- if (eq "-" $storageClass) }} + storageClassName: "" + {{- else }} + storageClassName: "{{ $storageClass }}" + {{- end }} + {{- end }} {{ end }} diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index 8a72f6693..afce1d930 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -26,7 +26,13 @@ persistence: # -- The storage class to use for volume claims. # Used unless specified at the specific component. # Defaults to the cluster default storage class. - storageClass: "" + ## If defined, storageClassName: + ## If set to "-", storageClassName: "", which disables dynamic provisioning + ## If undefined (the default) or set to null, no storageClassName spec is + ## set, choosing the default provisioner. (gp2 on AWS, standard on + ## GKE, AWS & OpenStack) + ## + # storageClass: "" jetstream: # -- Use an existing volume claim for jetstream existingClaim: "" @@ -34,7 +40,13 @@ persistence: capacity: "1Gi" # -- The storage class to use for volume claims. # Defaults to persistence.storageClass - storageClass: "" + ## If defined, storageClassName: + ## If set to "-", storageClassName: "", which disables dynamic provisioning + ## If undefined (the default) or set to null, no storageClassName spec is + ## set, choosing the default provisioner. (gp2 on AWS, standard on + ## GKE, AWS & OpenStack) + ## + # storageClass: "" media: # -- Use an existing volume claim for media files existingClaim: "" @@ -42,7 +54,13 @@ persistence: capacity: "1Gi" # -- The storage class to use for volume claims. # Defaults to persistence.storageClass - storageClass: "" + ## If defined, storageClassName: + ## If set to "-", storageClassName: "", which disables dynamic provisioning + ## If undefined (the default) or set to null, no storageClassName spec is + ## set, choosing the default provisioner. (gp2 on AWS, standard on + ## GKE, AWS & OpenStack) + ## + # storageClass: "" search: # -- Use an existing volume claim for the fulltext search index existingClaim: "" @@ -50,7 +68,13 @@ persistence: capacity: "1Gi" # -- The storage class to use for volume claims. # Defaults to persistence.storageClass - storageClass: "" + ## If defined, storageClassName: + ## If set to "-", storageClassName: "", which disables dynamic provisioning + ## If undefined (the default) or set to null, no storageClassName spec is + ## set, choosing the default provisioner. (gp2 on AWS, standard on + ## GKE, AWS & OpenStack) + ## + # storageClass: "" # -- Add additional volumes to the Dendrite Pod extraVolumes: [] diff --git a/internal/pushrules/action_test.go b/internal/pushrules/action_test.go index 72db9c998..d1eb16ef1 100644 --- a/internal/pushrules/action_test.go +++ b/internal/pushrules/action_test.go @@ -13,8 +13,6 @@ func TestActionJSON(t *testing.T) { Want Action }{ {Action{Kind: NotifyAction}}, - {Action{Kind: DontNotifyAction}}, - {Action{Kind: CoalesceAction}}, {Action{Kind: SetTweakAction}}, {Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go index f97427b71..cb825af5f 100644 --- a/internal/pushrules/default_override.go +++ b/internal/pushrules/default_override.go @@ -10,6 +10,7 @@ func defaultOverrideRules(userID string) []*Rule { &mRuleRoomNotifDefinition, &mRuleTombstoneDefinition, &mRuleReactionDefinition, + &mRuleACLsDefinition, } } @@ -30,7 +31,7 @@ var ( RuleID: MRuleMaster, Default: true, Enabled: false, - Actions: []*Action{{Kind: DontNotifyAction}}, + Actions: []*Action{}, } mRuleSuppressNoticesDefinition = Rule{ RuleID: MRuleSuppressNotices, @@ -43,7 +44,7 @@ var ( Pattern: pointer("m.notice"), }, }, - Actions: []*Action{{Kind: DontNotifyAction}}, + Actions: []*Action{}, } mRuleMemberEventDefinition = Rule{ RuleID: MRuleMemberEvent, @@ -56,7 +57,7 @@ var ( Pattern: pointer("m.room.member"), }, }, - Actions: []*Action{{Kind: DontNotifyAction}}, + Actions: []*Action{}, } mRuleContainsDisplayNameDefinition = Rule{ RuleID: MRuleContainsDisplayName, @@ -152,9 +153,7 @@ var ( Pattern: pointer("m.reaction"), }, }, - Actions: []*Action{ - {Kind: DontNotifyAction}, - }, + Actions: []*Action{}, } ) diff --git a/internal/pushrules/default_pushrules_test.go b/internal/pushrules/default_pushrules_test.go index dea829842..506fe3cc8 100644 --- a/internal/pushrules/default_pushrules_test.go +++ b/internal/pushrules/default_pushrules_test.go @@ -21,12 +21,12 @@ func TestDefaultRules(t *testing.T) { // Default override rules { name: ".m.rule.master", - inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"actions":["dont_notify"]}`), + inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"actions":[]}`), want: mRuleMasterDefinition, }, { name: ".m.rule.suppress_notices", - inputBytes: []byte(`{"rule_id":".m.rule.suppress_notices","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.msgtype","pattern":"m.notice"}],"actions":["dont_notify"]}`), + inputBytes: []byte(`{"rule_id":".m.rule.suppress_notices","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.msgtype","pattern":"m.notice"}],"actions":[]}`), want: mRuleSuppressNoticesDefinition, }, { @@ -36,7 +36,7 @@ func TestDefaultRules(t *testing.T) { }, { name: ".m.rule.member_event", - inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"}],"actions":["dont_notify"]}`), + inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"}],"actions":[]}`), want: mRuleMemberEventDefinition, }, { diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go index de8fe5cd0..e2821b57a 100644 --- a/internal/pushrules/util.go +++ b/internal/pushrules/util.go @@ -16,10 +16,7 @@ func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) { for _, a := range as { switch a.Kind { - case DontNotifyAction: - // Don't bother processing any further - return DontNotifyAction, nil, nil - + case DontNotifyAction: // Ignored case SetTweakAction: if tweaks == nil { tweaks = map[string]interface{}{} diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go index 89f8243d9..83eee7ede 100644 --- a/internal/pushrules/util_test.go +++ b/internal/pushrules/util_test.go @@ -17,17 +17,16 @@ func TestActionsToTweaks(t *testing.T) { {"empty", nil, UnknownAction, nil}, {"zero", []*Action{{}}, UnknownAction, nil}, {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil}, - {"onlyPrimaryDontNotify", []*Action{{Kind: DontNotifyAction}}, DontNotifyAction, nil}, + {"onlyPrimaryDontNotify", []*Action{}, UnknownAction, nil}, {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}}, {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}}, { "all", []*Action{ - {Kind: CoalesceAction}, {Kind: SetTweakAction, Tweak: HighlightTweak}, {Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}, }, - CoalesceAction, + UnknownAction, map[string]interface{}{"highlight": nil, "sound": "default"}, }, } diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go index b54ec3fb0..4cc479345 100644 --- a/internal/pushrules/validate.go +++ b/internal/pushrules/validate.go @@ -18,7 +18,7 @@ func ValidateRule(kind Kind, rule *Rule) []error { errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID)) } - if len(rule.Actions) == 0 { + if rule.Actions == nil { errs = append(errs, fmt.Errorf("missing actions")) } for _, action := range rule.Actions { diff --git a/internal/validate.go b/internal/validate.go index a90a6636f..da8b35cd3 100644 --- a/internal/validate.go +++ b/internal/validate.go @@ -28,7 +28,7 @@ import ( ) const ( - maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain + maxUsernameLength = 254 // https://spec.matrix.org/v1.7/appendices/#user-identifiers TODO account for domain minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 diff --git a/internal/version.go b/internal/version.go index 1f8a62bc2..55726ddcb 100644 --- a/internal/version.go +++ b/internal/version.go @@ -18,7 +18,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 13 - VersionPatch = 3 + VersionPatch = 4 VersionTag = "" // example: "rc1" gitRevLen = 7 // 7 matches the displayed characters on github.com diff --git a/relayapi/storage/postgres/relay_queue_json_table.go b/relayapi/storage/postgres/relay_queue_json_table.go index 74410fc88..94ae41407 100644 --- a/relayapi/storage/postgres/relay_queue_json_table.go +++ b/relayapi/storage/postgres/relay_queue_json_table.go @@ -109,5 +109,5 @@ func (s *relayQueueJSONStatements) SelectQueueJSON( } blobs[nid] = blob } - return blobs, err + return blobs, rows.Err() } diff --git a/relayapi/storage/sqlite3/relay_queue_json_table.go b/relayapi/storage/sqlite3/relay_queue_json_table.go index 502da3b00..a1af82aa0 100644 --- a/relayapi/storage/sqlite3/relay_queue_json_table.go +++ b/relayapi/storage/sqlite3/relay_queue_json_table.go @@ -133,5 +133,5 @@ func (s *relayQueueJSONStatements) SelectQueueJSON( } blobs[nid] = blob } - return blobs, err + return blobs, rows.Err() } diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index 601ce9063..e247c7553 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -29,6 +29,8 @@ import ( "github.com/sirupsen/logrus" ) +const MRoomServerACL = "m.room.server_acl" + type ServerACLDatabase interface { // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) @@ -57,7 +59,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { // do then we'll process it into memory so that we have the regexes to // hand. for _, room := range rooms { - state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "") + state, err := db.GetStateEvent(ctx, room, MRoomServerACL, "") if err != nil { logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room) continue diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 77b50d0e2..520f82a80 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -33,6 +33,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/internal/helpers" userAPI "github.com/matrix-org/dendrite/userapi/api" @@ -491,6 +492,27 @@ func (r *Inputer) processRoomEvent( } } + // If this is a membership event, it is possible we newly joined a federated room and eventually + // missed to update our m.room.server_acl - the following ensures we set the ACLs + // TODO: This probably performs badly in benchmarks + if event.Type() == spec.MRoomMember { + membership, _ := event.Membership() + if membership == spec.Join { + _, serverName, _ := gomatrixserverlib.SplitID('@', *event.StateKey()) + // only handle local membership events + if r.Cfg.Matrix.IsLocalServerName(serverName) { + var aclEvent *types.HeaderedEvent + aclEvent, err = r.DB.GetStateEvent(ctx, event.RoomID().String(), acls.MRoomServerACL, "") + if err != nil { + logrus.WithError(err).Error("failed to get server ACLs") + } + if aclEvent != nil { + r.ACLs.OnServerACLUpdate(aclEvent) + } + } + } + } + // Handle remote room upgrades, e.g. remove published room if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) { if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil { diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index d9ab291e9..21493287e 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -498,6 +498,13 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID) }, + func(eventID string) bool { + isRejected, err := t.db.IsEventRejected(ctx, t.roomInfo.RoomNID, eventID) + if err != nil { + return true + } + return isRejected + }, ) if err != nil { return nil, err diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index f87a3f7ed..74b010281 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -165,6 +165,13 @@ func (r *Queryer) QueryStateAfterEvents( info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.QueryUserIDForSender(ctx, roomID, senderID) }, + func(eventID string) bool { + isRejected, rejectedErr := r.DB.IsEventRejected(ctx, info.RoomNID, eventID) + if rejectedErr != nil { + return true + } + return isRejected + }, ) if err != nil { return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) @@ -676,6 +683,13 @@ func (r *Queryer) QueryStateAndAuthChain( info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.QueryUserIDForSender(ctx, roomID, senderID) }, + func(eventID string) bool { + isRejected, rejectedErr := r.DB.IsEventRejected(ctx, info.RoomNID, eventID) + if rejectedErr != nil { + return true + } + return isRejected + }, ) if err != nil { return err diff --git a/roomserver/internal/query/query_room_hierarchy.go b/roomserver/internal/query/query_room_hierarchy.go index 76eba12be..5f55980f0 100644 --- a/roomserver/internal/query/query_room_hierarchy.go +++ b/roomserver/internal/query/query_room_hierarchy.go @@ -17,6 +17,7 @@ package query import ( "context" "encoding/json" + "errors" "fmt" "sort" @@ -56,6 +57,12 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r break } + // If the context is canceled, we might still have discovered rooms + // return them to the client and let the client know there _may_ be more rooms. + if errors.Is(ctx.Err(), context.Canceled) { + break + } + // pop the stack queuedRoom := unvisited[len(unvisited)-1] unvisited = unvisited[:len(unvisited)-1] @@ -112,6 +119,11 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r pubRoom := publicRoomsChunk(ctx, querier, queuedRoom.RoomID) + if pubRoom == nil { + util.GetLogger(ctx).WithField("room_id", queuedRoom.RoomID).Debug("unable to get publicRoomsChunk") + continue + } + discoveredRooms = append(discoveredRooms, fclient.RoomHierarchyRoom{ PublicRoom: *pubRoom, RoomType: roomType, diff --git a/roomserver/producers/roomevent.go b/roomserver/producers/roomevent.go index 165304d49..af7e10580 100644 --- a/roomserver/producers/roomevent.go +++ b/roomserver/producers/roomevent.go @@ -73,7 +73,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu } } - if eventType == "m.room.server_acl" && update.NewRoomEvent.Event.StateKeyEquals("") { + if eventType == acls.MRoomServerACL && update.NewRoomEvent.Event.StateKeyEquals("") { ev := update.NewRoomEvent.Event.PDU defer r.ACLs.OnServerACLUpdate(ev) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 47626b30a..e9cd926d7 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/httputil" @@ -15,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/userapi" @@ -34,6 +36,10 @@ import ( "github.com/matrix-org/dendrite/test/testrig" ) +var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) { + return &statistics.ServerStatistics{}, nil +} + type FakeQuerier struct { api.QuerySenderIDAPI } @@ -58,7 +64,7 @@ func TestUsers(t *testing.T) { }) t.Run("kick users", func(t *testing.T) { - usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + usrAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) rsAPI.SetUserAPI(usrAPI) testKickUsers(t, rsAPI, usrAPI) }) @@ -258,7 +264,7 @@ func TestPurgeRoom(t *testing.T) { fsAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, rsAPI, caches, nil, true) rsAPI.SetFederationAPI(fsAPI, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, fsAPI.IsBlacklistedOrBackingOff) syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, &natsInstance, userAPI, rsAPI, caches, caching.DisableMetrics) // Create the room @@ -407,7 +413,9 @@ type fledglingEvent struct { RoomID string Redacts string Depth int64 - PrevEvents []interface{} + PrevEvents []any + AuthEvents []any + Content map[string]any } func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEvent) { @@ -424,7 +432,13 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *types.HeaderedEve Depth: ev.Depth, PrevEvents: ev.PrevEvents, }) - err := eb.SetContent(map[string]interface{}{}) + if ev.Content == nil { + ev.Content = map[string]any{} + } + if ev.AuthEvents != nil { + eb.AuthEvents = ev.AuthEvents + } + err := eb.SetContent(ev.Content) if err != nil { t.Fatalf("mustCreateEvent: failed to marshal event content %v", err) } @@ -1042,7 +1056,7 @@ func TestUpgrade(t *testing.T) { rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) - userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) + userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) rsAPI.SetUserAPI(userAPI) for _, tc := range testCases { @@ -1077,3 +1091,143 @@ func TestUpgrade(t *testing.T) { } }) } + +func TestStateReset(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + // Prepare APIs + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := jetstream.NATSInstance{} + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + + // create a new room + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) + + // join with Bob and Charlie + bobJoinEv := room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]any{"membership": "join"}, test.WithStateKey(bob.ID)) + charlieJoinEv := room.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]any{"membership": "join"}, test.WithStateKey(charlie.ID)) + + // Send and create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // send a message + bobMsg := room.CreateAndInsert(t, bob, "m.room.message", map[string]any{"body": "hello world"}) + charlieMsg := room.CreateAndInsert(t, charlie, "m.room.message", map[string]any{"body": "hello world"}) + + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{bobMsg, charlieMsg}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Bob changes his name + expectedDisplayname := "Bob!" + bobDisplayname := room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]any{"membership": "join", "displayname": expectedDisplayname}, test.WithStateKey(bob.ID)) + + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{bobDisplayname}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Change another state event + jrEv := room.CreateAndInsert(t, alice, spec.MRoomJoinRules, gomatrixserverlib.JoinRuleContent{JoinRule: "invite"}, test.WithStateKey("")) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{jrEv}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // send a message + bobMsg = room.CreateAndInsert(t, bob, "m.room.message", map[string]any{"body": "hello world"}) + charlieMsg = room.CreateAndInsert(t, charlie, "m.room.message", map[string]any{"body": "hello world"}) + + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{bobMsg, charlieMsg}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Craft the state reset message, which is using Bobs initial join event and the + // last message Charlie sent as the prev_events. This should trigger the recalculation + // of the "current" state, since the message event does not have state and no missing events in the DB. + stateResetMsg := mustCreateEvent(t, fledglingEvent{ + Type: "m.room.message", + SenderID: charlie.ID, + RoomID: room.ID, + Depth: charlieMsg.Depth() + 1, + PrevEvents: []any{ + bobJoinEv.EventID(), + charlieMsg.EventID(), + }, + AuthEvents: []any{ + room.Events()[0].EventID(), // create event + room.Events()[2].EventID(), // PL event + charlieJoinEv.EventID(), // Charlie join event + }, + }) + + // Send the state reset message + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*types.HeaderedEvent{stateResetMsg}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Validate that there is a membership event for Bob + bobMembershipEv := api.GetStateEvent(ctx, rsAPI, room.ID, gomatrixserverlib.StateKeyTuple{ + EventType: spec.MRoomMember, + StateKey: bob.ID, + }) + + if bobMembershipEv == nil { + t.Fatalf("Membership event for Bob does not exist. State reset?") + } else { + // Validate it's the correct membership event + if dn := gjson.GetBytes(bobMembershipEv.Content(), "displayname").Str; dn != expectedDisplayname { + t.Fatalf("Expected displayname to be %q, got %q", expectedDisplayname, dn) + } + } + }) +} + +func TestNewServerACLs(t *testing.T) { + alice := test.NewUser(t) + roomWithACL := test.NewRoom(t, alice) + + roomWithACL.CreateAndInsert(t, alice, acls.MRoomServerACL, acls.ServerACL{ + Allowed: []string{"*"}, + Denied: []string{"localhost"}, + AllowIPLiterals: false, + }, test.WithStateKey("")) + + roomWithoutACL := test.NewRoom(t, alice) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + // start JetStream listeners + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + + // let the RS create the events + err := api.SendEvents(context.Background(), rsAPI, api.KindNew, roomWithACL.Events(), "test", "test", "test", nil, false) + assert.NoError(t, err) + err = api.SendEvents(context.Background(), rsAPI, api.KindNew, roomWithoutACL.Events(), "test", "test", "test", nil, false) + assert.NoError(t, err) + + db, err := storage.Open(processCtx.Context(), cm, &cfg.RoomServer.Database, caches) + assert.NoError(t, err) + // create new server ACLs and verify server is banned/not banned + serverACLs := acls.NewServerACLs(db) + banned := serverACLs.IsServerBannedFromRoom("localhost", roomWithACL.ID) + assert.Equal(t, true, banned) + banned = serverACLs.IsServerBannedFromRoom("localhost", roomWithoutACL.ID) + assert.Equal(t, false, banned) + }) +} diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 1e776ff6c..dfd439a22 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -45,6 +45,7 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) + IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (bool, error) } type StateResolution struct { @@ -1066,6 +1067,13 @@ func (v *StateResolution) resolveConflictsV2( func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return v.Querier.QueryUserIDForSender(ctx, roomID, senderID) }, + func(eventID string) bool { + isRejected, err := v.db.IsEventRejected(ctx, v.roomInfo.RoomNID, eventID) + if err != nil { + return true + } + return isRejected + }, ) }() diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index a00b4b1d7..1c9cd1599 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -249,6 +249,7 @@ func (s *eventStatements) BulkSelectSnapshotsFromEventIDs( if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, rows, "BulkSelectSnapshotsFromEventIDs: rows.close() failed") var eventID string var stateNID types.StateSnapshotNID @@ -563,7 +564,7 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( } result[eventNID] = roomNID } - return result, nil + return result, rows.Err() } func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 835a43b2d..1a96e3527 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -363,7 +363,7 @@ func (s *membershipStatements) SelectRoomsWithMembership( } roomNIDs = append(roomNIDs, roomNID) } - return roomNIDs, nil + return roomNIDs, rows.Err() } func (s *membershipStatements) SelectJoinedUsersSetForRooms( diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index c8346733d..bc3820b2c 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -137,7 +137,7 @@ func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.T } roomIDs = append(roomIDs, roomID) } - return roomIDs, nil + return roomIDs, rows.Err() } func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, @@ -255,7 +255,7 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( } result[roomNID] = roomVersion } - return result, nil + return result, rows.Err() } func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { @@ -277,7 +277,7 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo } roomIDs = append(roomIDs, roomID) } - return roomIDs, nil + return roomIDs, rows.Err() } func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { @@ -299,7 +299,7 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro } roomNIDs = append(roomNIDs, roomNID) } - return roomNIDs, nil + return roomNIDs, rows.Err() } func roomNIDsAsArray(roomNIDs []types.RoomNID) pq.Int64Array { diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go index 217ee957f..57e8f213b 100644 --- a/roomserver/storage/postgres/user_room_keys_table.go +++ b/roomserver/storage/postgres/user_room_keys_table.go @@ -162,6 +162,7 @@ func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, if errors.Is(err, sql.ErrNoRows) { return nil, nil } + defer internal.CloseAndLogIfError(ctx, rows, "SelectAllPublicKeysForUser: failed to close rows") resultMap := make(map[types.RoomNID]ed25519.PublicKey) @@ -173,5 +174,5 @@ func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, } resultMap[roomNID] = pubkey } - return resultMap, err + return resultMap, rows.Err() } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 70672a33e..06284d2e3 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -250,3 +250,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } + +func (u *RoomUpdater) IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (bool, error) { + return u.d.IsEventRejected(ctx, roomNID, eventID) +} diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index dc26885bb..325951c7f 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -109,5 +109,5 @@ func (s *eventJSONStatements) BulkSelectEventJSON( } result.EventNID = types.EventNID(eventNID) } - return results[:i], nil + return results[:i], rows.Err() } diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 347524a81..a052d69ed 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -136,7 +136,7 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } result[stateKey] = types.EventStateKeyNID(stateKeyNID) } - return result, nil + return result, rows.Err() } func (s *eventStateKeyStatements) BulkSelectEventStateKey( @@ -167,5 +167,5 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKey( } result[types.EventStateKeyNID(stateKeyNID)] = stateKey } - return result, nil + return result, rows.Err() } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 0581ec194..c030fffea 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -147,5 +147,5 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( } result[eventType] = types.EventTypeNID(eventTypeNID) } - return result, nil + return result, rows.Err() } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index c49c6dc38..2c269bced 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -310,6 +310,9 @@ func (s *eventStatements) BulkSelectStateEventByID( } results = append(results, result) } + if err = rows.Err(); err != nil { + return nil, err + } if !excludeRejected && i != len(eventIDs) { // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. // We don't know which ones were missing because we don't return the string IDs in the query. @@ -377,7 +380,7 @@ func (s *eventStatements) BulkSelectStateEventByNID( return nil, err } } - return results[:i], err + return results[:i], rows.Err() } // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. @@ -425,6 +428,9 @@ func (s *eventStatements) BulkSelectStateAtEventByID( ) } } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(eventIDs) { return nil, types.MissingEventError( fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)), @@ -507,6 +513,9 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) result.EventID = eventID } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) } @@ -544,6 +553,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev } results[types.EventNID(eventNID)] = eventID } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(eventNIDs) { return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) } @@ -602,7 +614,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e RoomNID: types.RoomNID(roomNID), } } - return results, nil + return results, rows.Err() } func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { @@ -652,7 +664,7 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( } result[eventNID] = roomNID } - return result, nil + return result, rows.Err() } func eventNIDsAsArray(eventNIDs []types.EventNID) string { diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index ca6e7c511..b678d8add 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -126,6 +126,9 @@ func (s *inviteStatements) UpdateInviteRetired( } eventIDs = append(eventIDs, inviteEventID) } + if err = rows.Err(); err != nil { + return + } // now retire the invites stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) @@ -157,5 +160,5 @@ func (s *inviteStatements) SelectInviteActiveForUserInRoom( result = append(result, types.EventStateKeyNID(senderUserNID)) eventIDs = append(eventIDs, eventID) } - return result, eventIDs, eventJSON, nil + return result, eventIDs, eventJSON, rows.Err() } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 977788d50..1012c074a 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -250,6 +250,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } eventNIDs = append(eventNIDs, eNID) } + err = rows.Err() return } @@ -277,6 +278,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } eventNIDs = append(eventNIDs, eNID) } + err = rows.Err() return } @@ -313,7 +315,7 @@ func (s *membershipStatements) SelectRoomsWithMembership( } roomNIDs = append(roomNIDs, roomNID) } - return roomNIDs, nil + return roomNIDs, rows.Err() } func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error) { diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 3bdbbaa35..815b42a27 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -121,7 +121,7 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( aliases = append(aliases, alias) } - + err = rows.Err() return } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 7556b3461..22700a710 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -128,7 +128,7 @@ func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.T } roomIDs = append(roomIDs, roomID) } - return roomIDs, nil + return roomIDs, rows.Err() } func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { @@ -265,7 +265,7 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( } result[roomNID] = roomVersion } - return result, nil + return result, rows.Err() } func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { @@ -293,7 +293,7 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo } roomIDs = append(roomIDs, roomID) } - return roomIDs, nil + return roomIDs, rows.Err() } func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { @@ -321,5 +321,5 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro } roomNIDs = append(roomNIDs, roomNID) } - return roomNIDs, nil + return roomNIDs, rows.Err() } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 2edff0ba8..dcac0b07c 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -133,13 +133,16 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( var stateBlockNIDsJSON string for ; rows.Next(); i++ { result := &results[i] - if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil { + if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil { return nil, err } - if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil { + if err = json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil { return nil, err } } + if err = rows.Err(); err != nil { + return nil, err + } if i != len(stateNIDs) { return nil, types.MissingStateError(fmt.Sprintf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))) } diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go index 434bad295..13906f771 100644 --- a/roomserver/storage/sqlite3/user_room_keys_table.go +++ b/roomserver/storage/sqlite3/user_room_keys_table.go @@ -177,6 +177,7 @@ func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, if errors.Is(err, sql.ErrNoRows) { return nil, nil } + defer internal.CloseAndLogIfError(ctx, rows, "SelectAllPublicKeysForUser: failed to close rows") resultMap := make(map[types.RoomNID]ed25519.PublicKey) @@ -188,5 +189,5 @@ func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, } resultMap[roomNID] = pubkey } - return resultMap, err + return resultMap, rows.Err() } diff --git a/roomserver/storage/tables/interface_test.go b/roomserver/storage/tables/interface_test.go new file mode 100644 index 000000000..8727e2436 --- /dev/null +++ b/roomserver/storage/tables/interface_test.go @@ -0,0 +1,76 @@ +package tables + +import ( + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/stretchr/testify/assert" +) + +func TestExtractContentValue(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + tests := []struct { + name string + event *types.HeaderedEvent + want string + }{ + { + name: "returns creator ID for create events", + event: room.Events()[0], + want: alice.ID, + }, + { + name: "returns the alias for canonical alias events", + event: room.CreateEvent(t, alice, spec.MRoomCanonicalAlias, map[string]string{"alias": "#test:test"}), + want: "#test:test", + }, + { + name: "returns the history_visibility for history visibility events", + event: room.CreateEvent(t, alice, spec.MRoomHistoryVisibility, map[string]string{"history_visibility": "shared"}), + want: "shared", + }, + { + name: "returns the join rules for join_rules events", + event: room.CreateEvent(t, alice, spec.MRoomJoinRules, map[string]string{"join_rule": "public"}), + want: "public", + }, + { + name: "returns the membership for room_member events", + event: room.CreateEvent(t, alice, spec.MRoomMember, map[string]string{"membership": "join"}, test.WithStateKey(alice.ID)), + want: "join", + }, + { + name: "returns the room name for room_name events", + event: room.CreateEvent(t, alice, spec.MRoomName, map[string]string{"name": "testing"}, test.WithStateKey(alice.ID)), + want: "testing", + }, + { + name: "returns the room avatar for avatar events", + event: room.CreateEvent(t, alice, spec.MRoomAvatar, map[string]string{"url": "mxc://testing"}, test.WithStateKey(alice.ID)), + want: "mxc://testing", + }, + { + name: "returns the room topic for topic events", + event: room.CreateEvent(t, alice, spec.MRoomTopic, map[string]string{"topic": "testing"}, test.WithStateKey(alice.ID)), + want: "testing", + }, + { + name: "returns guest_access for guest access events", + event: room.CreateEvent(t, alice, "m.room.guest_access", map[string]string{"guest_access": "forbidden"}, test.WithStateKey(alice.ID)), + want: "forbidden", + }, + { + name: "returns empty string if key can't be found or unknown event", + event: room.CreateEvent(t, alice, "idontexist", nil), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, ExtractContentValue(tt.event), "ExtractContentValue(%v)", tt.event) + }) + } +} diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 44136e2a0..85dfe0beb 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -93,17 +93,15 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors) { checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass) } // Ensure there is any spam counter measure when enabling registration - if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled { - if !c.RecaptchaEnabled { - configErrs.Add( - "You have tried to enable open registration without any secondary verification methods " + - "(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " + - "increasing the risk that your server will be used to send spam or abuse, and may result in " + - "your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " + - "start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " + - "should set the registration_disabled option in your Dendrite config.", - ) - } + if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled && !c.RecaptchaEnabled { + configErrs.Add( + "You have tried to enable open registration without any secondary verification methods " + + "(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " + + "increasing the risk that your server will be used to send spam or abuse, and may result in " + + "your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " + + "start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " + + "should set the registration_disabled option in your Dendrite config.", + ) } } diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index e64a3910c..559de72ac 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -21,6 +21,10 @@ type UserAPI struct { // Users who register on this homeserver will automatically // be joined to the rooms listed under this option. AutoJoinRooms []string `yaml:"auto_join_rooms"` + + // The number of workers to start for the DeviceListUpdater. Defaults to 8. + // This only needs updating if the "InputDeviceListUpdate" stream keeps growing indefinitely. + WorkerCount int `yaml:"worker_count"` } const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes @@ -28,6 +32,7 @@ const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes func (c *UserAPI) Defaults(opts DefaultOpts) { c.BCryptCost = bcrypt.DefaultCost c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS + c.WorkerCount = 8 if opts.Generate { if !opts.SingleDatabase { c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" diff --git a/setup/mscs/msc2836/storage.go b/setup/mscs/msc2836/storage.go index ade2a1616..696d0b0da 100644 --- a/setup/mscs/msc2836/storage.go +++ b/setup/mscs/msc2836/storage.go @@ -301,7 +301,7 @@ func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, rec } children = append(children, evInfo) } - return children, nil + return children, rows.Err() } func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) { diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index c089539f0..d0227f4ea 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -44,7 +44,7 @@ func GetEvent( rsAPI api.SyncRoomserverAPI, ) util.JSONResponse { ctx := req.Context() - db, err := syncDB.NewDatabaseTransaction(ctx) + db, err := syncDB.NewDatabaseSnapshot(ctx) logger := util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": eventID, "room_id": rawRoomID, @@ -56,6 +56,7 @@ func GetEvent( JSON: spec.InternalServerError{}, } } + defer db.Rollback() // nolint: errcheck roomID, err := spec.NewRoomID(rawRoomID) if err != nil { diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index b0148bef5..ec0b27adc 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -392,7 +392,7 @@ func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, er }) } - return events, nil + return events, rows.Err() } func rowsToEvents(rows *sql.Rows) ([]*rstypes.HeaderedEvent, error) { diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 4fe4260da..e5208b891 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" @@ -160,6 +161,7 @@ func (s *membershipsStatements) SelectMemberships( if err != nil { return } + defer internal.CloseAndLogIfError(ctx, rows, "SelectMemberships: failed to close rows") var ( eventID string ) diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go index 64183073d..1120dce0b 100644 --- a/syncapi/storage/postgres/peeks_table.go +++ b/syncapi/storage/postgres/peeks_table.go @@ -164,7 +164,7 @@ func (s *peekStatements) SelectPeekingDevices( devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID}) result[roomID] = devices } - return result, nil + return result, rows.Err() } func (s *peekStatements) SelectMaxPeekID( diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index f37b5331e..53acecce5 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -144,7 +144,7 @@ func (p *presenceStatements) GetPresenceForUsers( presence.ClientFields.Presence = presence.Presence.String() result = append(result, presence) } - return result, err + return result, rows.Err() } func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 78b2e397c..f430fcc05 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -177,7 +177,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( users = append(users, userID) result[roomID] = users } - return result, nil + return result, rows.Err() } // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. @@ -236,7 +236,7 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( } result = append(result, roomID) } - return result, nil + return result, rows.Err() } // SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. @@ -419,7 +419,7 @@ func currentRoomStateRowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, er }) } - return events, nil + return events, rows.Err() } func rowsToEvents(rows *sql.Rows) ([]*rstypes.HeaderedEvent, error) { diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index ebb469d24..e50b5cbf8 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -176,7 +176,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( if lastPos == 0 { lastPos = r.To } - return result, retired, lastPos, nil + return result, retired, lastPos, rows.Err() } func (s *inviteEventsStatements) SelectMaxInviteID( diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index a1b16306c..9e50422e5 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" @@ -163,6 +164,7 @@ func (s *membershipsStatements) SelectMemberships( if err != nil { return } + defer internal.CloseAndLogIfError(ctx, rows, "SelectMemberships: failed to close rows") var eventID string for rows.Next() { if err = rows.Scan(&eventID); err != nil { diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 93caee806..c7b11d3ef 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -274,7 +274,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( } } - return stateNeeded, eventIDToEvent, nil + return stateNeeded, eventIDToEvent, rows.Err() } // MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied, @@ -520,7 +520,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { ExcludeFromSync: excludeFromSync, }) } - return result, nil + return result, rows.Err() } func (s *outputRoomEventsStatements) SelectContextEvent( ctx context.Context, txn *sql.Tx, roomID, eventID string, diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 36967d1e7..c00fb7a79 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" @@ -137,6 +138,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( } else if err != nil { return } + defer internal.CloseAndLogIfError(ctx, rows, "SelectEventIDsInRange: failed to close rows") // Return the IDs. var eventID string @@ -155,7 +157,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( start = tokens[0] end = tokens[len(tokens)-1] } - + err = rows.Err() return } diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 5d5200abc..d8998e2b8 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -184,7 +184,7 @@ func (s *peekStatements) SelectPeekingDevices( devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID}) result[roomID] = devices } - return result, nil + return result, rows.Err() } func (s *peekStatements) SelectMaxPeekID( diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index 573fbad6c..40b57e75d 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -169,7 +169,7 @@ func (p *presenceStatements) GetPresenceForUsers( presence.ClientFields.Presence = presence.Presence.String() result = append(result, presence) } - return result, err + return result, rows.Err() } func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { diff --git a/userapi/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go index 3389bb808..b3ccb573b 100644 --- a/userapi/consumers/devicelistupdate.go +++ b/userapi/consumers/devicelistupdate.go @@ -17,6 +17,7 @@ package consumers import ( "context" "encoding/json" + "time" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/gomatrixserverlib" @@ -82,7 +83,10 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M return true } - err := t.updater.Update(ctx, m) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + + err := t.updater.Update(timeoutCtx, m) if err != nil { logrus.WithFields(logrus.Fields{ "user_id": m.UserID, diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 6da41f8a1..fca741298 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -538,8 +538,8 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype if err != nil { return fmt.Errorf("pushrules.ActionsToTweaks: %w", err) } - // TODO: support coalescing. - if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { + + if a != pushrules.NotifyAction { log.WithFields(log.Fields{ "event_id": event.EventID(), "room_id": event.RoomID().String(), diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 49dd5b238..7b7c08618 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -81,12 +81,8 @@ func Test_evaluatePushRules(t *testing.T) { { name: "m.reaction doesn't notify", eventContent: `{"type":"m.reaction","room_id":"!room:example.com"}`, - wantAction: pushrules.DontNotifyAction, - wantActions: []*pushrules.Action{ - { - Kind: pushrules.DontNotifyAction, - }, - }, + wantAction: pushrules.UnknownAction, + wantActions: []*pushrules.Action{}, }, { name: "m.room.message notifies", @@ -136,7 +132,7 @@ func Test_evaluatePushRules(t *testing.T) { t.Fatalf("expected action to be '%s', got '%s'", tc.wantAction, gotAction) } // this is taken from `notifyLocal` - if tc.wantNotify && gotAction != pushrules.NotifyAction && gotAction != pushrules.CoalesceAction { + if tc.wantNotify && gotAction != pushrules.NotifyAction { t.Fatalf("expected to notify but didn't") } }) diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go index 3fccf56bb..b40635160 100644 --- a/userapi/internal/device_list_update.go +++ b/userapi/internal/device_list_update.go @@ -21,9 +21,11 @@ import ( "fmt" "hash/fnv" "net" + "strconv" "sync" "time" + "github.com/matrix-org/dendrite/federationapi/statistics" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" @@ -107,6 +109,8 @@ type DeviceListUpdater struct { userIDToChan map[string]chan bool userIDToChanMu *sync.Mutex rsAPI rsapi.KeyserverRoomserverAPI + + isBlacklistedOrBackingOffFn func(s spec.ServerName) (*statistics.ServerStatistics, error) } // DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. @@ -142,26 +146,52 @@ type KeyChangeProducer interface { ProduceKeyChanges(keys []api.DeviceMessage) error } +var deviceListUpdaterBackpressure = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "dendrite", + Subsystem: "keyserver", + Name: "worker_backpressure", + Help: "How many device list updater requests are queued", + }, + []string{"worker_id"}, +) +var deviceListUpdaterServersRetrying = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "dendrite", + Subsystem: "keyserver", + Name: "worker_servers_retrying", + Help: "How many servers are queued for retry", + }, + []string{"worker_id"}, +) + // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - rsAPI rsapi.KeyserverRoomserverAPI, thisServer spec.ServerName, + rsAPI rsapi.KeyserverRoomserverAPI, + thisServer spec.ServerName, + enableMetrics bool, + isBlacklistedOrBackingOffFn func(s spec.ServerName) (*statistics.ServerStatistics, error), ) *DeviceListUpdater { + if enableMetrics { + prometheus.MustRegister(deviceListUpdaterBackpressure, deviceListUpdaterServersRetrying) + } return &DeviceListUpdater{ - process: process, - userIDToMutex: make(map[string]*sync.Mutex), - mu: &sync.Mutex{}, - db: db, - api: api, - producer: producer, - fedClient: fedClient, - thisServer: thisServer, - workerChans: make([]chan spec.ServerName, numWorkers), - userIDToChan: make(map[string]chan bool), - userIDToChanMu: &sync.Mutex{}, - rsAPI: rsAPI, + process: process, + userIDToMutex: make(map[string]*sync.Mutex), + mu: &sync.Mutex{}, + db: db, + api: api, + producer: producer, + fedClient: fedClient, + thisServer: thisServer, + workerChans: make([]chan spec.ServerName, numWorkers), + userIDToChan: make(map[string]chan bool), + userIDToChanMu: &sync.Mutex{}, + rsAPI: rsAPI, + isBlacklistedOrBackingOffFn: isBlacklistedOrBackingOffFn, } } @@ -173,18 +203,20 @@ func (u *DeviceListUpdater) Start() error { // to stop (in this transaction) until key requests can be made. ch := make(chan spec.ServerName, 10) u.workerChans[i] = ch - go u.worker(ch) + go u.worker(ch, i) } staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []spec.ServerName{}) if err != nil { return err } + + newStaleLists := dedupeStaleLists(staleLists) offset, step := time.Second*10, time.Second - if max := len(staleLists); max > 120 { + if max := len(newStaleLists); max > 120 { step = (time.Second * 120) / time.Duration(max) } - for _, userID := range staleLists { + for _, userID := range newStaleLists { userID := userID // otherwise we are only sending the last entry time.AfterFunc(offset, func() { u.notifyWorkers(userID) @@ -336,11 +368,22 @@ func (u *DeviceListUpdater) notifyWorkers(userID string) { if err != nil { return } + _, err = u.isBlacklistedOrBackingOffFn(remoteServer) + var federationClientError *fedsenderapi.FederationClientError + if errors.As(err, &federationClientError) { + if federationClientError.Blacklisted { + return + } + } + hash := fnv.New32a() _, _ = hash.Write([]byte(remoteServer)) index := int(int64(hash.Sum32()) % int64(len(u.workerChans))) ch := u.assignChannel(userID) + // Since workerChans are buffered, we only increment here and let the worker + // decrement it once it is done processing. + deviceListUpdaterBackpressure.With(prometheus.Labels{"worker_id": strconv.Itoa(index)}).Inc() u.workerChans[index] <- remoteServer select { case <-ch: @@ -370,27 +413,44 @@ func (u *DeviceListUpdater) clearChannel(userID string) { } } -func (u *DeviceListUpdater) worker(ch chan spec.ServerName) { +func (u *DeviceListUpdater) worker(ch chan spec.ServerName, workerID int) { retries := make(map[spec.ServerName]time.Time) retriesMu := &sync.Mutex{} // restarter goroutine which will inject failed servers into ch when it is time go func() { var serversToRetry []spec.ServerName for { - serversToRetry = serversToRetry[:0] // reuse memory - time.Sleep(time.Second) + // nuke serversToRetry by re-slicing it to be "empty". + // The capacity of the slice is unchanged, which ensures we can reuse the memory. + serversToRetry = serversToRetry[:0] + + deviceListUpdaterServersRetrying.With(prometheus.Labels{"worker_id": strconv.Itoa(workerID)}).Set(float64(len(retries))) + time.Sleep(time.Second * 2) + + // -2, so we have space for incoming device list updates over federation + maxServers := (cap(ch) - len(ch)) - 2 + if maxServers <= 0 { + continue + } + retriesMu.Lock() now := time.Now() for srv, retryAt := range retries { if now.After(retryAt) { serversToRetry = append(serversToRetry, srv) + if maxServers == len(serversToRetry) { + break + } } } + for _, srv := range serversToRetry { delete(retries, srv) } retriesMu.Unlock() + for _, srv := range serversToRetry { + deviceListUpdaterBackpressure.With(prometheus.Labels{"worker_id": strconv.Itoa(workerID)}).Inc() ch <- srv } } @@ -399,8 +459,18 @@ func (u *DeviceListUpdater) worker(ch chan spec.ServerName) { retriesMu.Lock() _, exists := retries[serverName] retriesMu.Unlock() - if exists { - // Don't retry a server that we're already waiting for. + + // If the serverName is coming from retries, maybe it was + // blacklisted in the meantime. + _, err := u.isBlacklistedOrBackingOffFn(serverName) + var federationClientError *fedsenderapi.FederationClientError + // unwrap errors and check for FederationClientError, if found, federationClientError will be not nil + errors.As(err, &federationClientError) + isBlacklisted := federationClientError != nil && federationClientError.Blacklisted + + // Don't retry a server that we're already waiting for or is blacklisted by now. + if exists || isBlacklisted { + deviceListUpdaterBackpressure.With(prometheus.Labels{"worker_id": strconv.Itoa(workerID)}).Dec() continue } waitTime, shouldRetry := u.processServer(serverName) @@ -411,11 +481,18 @@ func (u *DeviceListUpdater) worker(ch chan spec.ServerName) { } retriesMu.Unlock() } + deviceListUpdaterBackpressure.With(prometheus.Labels{"worker_id": strconv.Itoa(workerID)}).Dec() } } func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Duration, bool) { ctx := u.process.Context() + // If the process.Context is canceled, there is no need to go further. + // This avoids spamming the logs when shutting down + if errors.Is(ctx.Err(), context.Canceled) { + return defaultWaitTime, false + } + logger := util.GetLogger(ctx).WithField("server_name", serverName) deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() @@ -428,13 +505,6 @@ func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Dura return waitTime, true } - defer func() { - for _, userID := range userIDs { - // always clear the channel to unblock Update calls regardless of success/failure - u.clearChannel(userID) - } - }() - for _, userID := range userIDs { userWait, err := u.processServerUser(ctx, serverName, userID) if err != nil { @@ -461,6 +531,11 @@ func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Dura func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName spec.ServerName, userID string) (time.Duration, error) { ctx, cancel := context.WithTimeout(ctx, requestTimeout) defer cancel() + + // If we are processing more than one user per server, this unblocks further calls to Update + // immediately instead of just after **all** users have been processed. + defer u.clearChannel(userID) + logger := util.GetLogger(ctx).WithFields(logrus.Fields{ "server_name": serverName, "user_id": userID, @@ -579,3 +654,24 @@ func (u *DeviceListUpdater) updateDeviceList(res *fclient.RespUserDevices) error } return nil } + +// dedupeStaleLists de-duplicates the stateList entries using the domain. +// This is used on startup, processServer is getting all users anyway, so +// there is no need to send every user to the workers. +func dedupeStaleLists(staleLists []string) []string { + seenDomains := make(map[spec.ServerName]struct{}) + newStaleLists := make([]string, 0, len(staleLists)) + for _, userID := range staleLists { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + // non-fatal and should not block starting up + continue + } + if _, ok := seenDomains[domain]; ok { + continue + } + newStaleLists = append(newStaleLists, userID) + seenDomains[domain] = struct{}{} + } + return newStaleLists +} diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go index 10b9c6521..a2f1869d1 100644 --- a/userapi/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -27,6 +27,9 @@ import ( "testing" "time" + api2 "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" @@ -128,6 +131,10 @@ type mockDeviceListUpdaterAPI struct { func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { } +var testIsBlacklistedOrBackingOff = func(s spec.ServerName) (*statistics.ServerStatistics, error) { + return &statistics.ServerStatistics{}, nil +} + type roundTripper struct { fn func(*http.Request) (*http.Response, error) } @@ -161,7 +168,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost", caching.DisableMetrics, testIsBlacklistedOrBackingOff) event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -233,7 +240,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test", caching.DisableMetrics, testIsBlacklistedOrBackingOff) if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -303,7 +310,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost", caching.DisableMetrics, testIsBlacklistedOrBackingOff) if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -406,7 +413,7 @@ func TestDeviceListUpdater_CleanUp(t *testing.T) { updater := NewDeviceListUpdater(processCtx, db, nil, nil, nil, - 0, rsAPI, "test") + 0, rsAPI, "test", caching.DisableMetrics, testIsBlacklistedOrBackingOff) if err := updater.CleanUp(); err != nil { t.Error(err) } @@ -428,3 +435,114 @@ func TestDeviceListUpdater_CleanUp(t *testing.T) { } }) } + +func Test_dedupeStateList(t *testing.T) { + alice := "@alice:localhost" + bob := "@bob:localhost" + charlie := "@charlie:notlocalhost" + invalidUserID := "iaminvalid:localhost" + + tests := []struct { + name string + staleLists []string + want []string + }{ + { + name: "empty stateLists", + staleLists: []string{}, + want: []string{}, + }, + { + name: "single entry", + staleLists: []string{alice}, + want: []string{alice}, + }, + { + name: "multiple entries without dupe servers", + staleLists: []string{alice, charlie}, + want: []string{alice, charlie}, + }, + { + name: "multiple entries with dupe servers", + staleLists: []string{alice, bob, charlie}, + want: []string{alice, charlie}, + }, + { + name: "list with invalid userID", + staleLists: []string{alice, bob, invalidUserID}, + want: []string{alice}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := dedupeStaleLists(tt.staleLists); !reflect.DeepEqual(got, tt.want) { + t.Errorf("dedupeStaleLists() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDeviceListUpdaterIgnoreBlacklisted(t *testing.T) { + unreachableServer := spec.ServerName("notlocalhost") + + updater := DeviceListUpdater{ + workerChans: make([]chan spec.ServerName, 1), + isBlacklistedOrBackingOffFn: func(s spec.ServerName) (*statistics.ServerStatistics, error) { + switch s { + case unreachableServer: + return nil, &api2.FederationClientError{Blacklisted: true} + } + return nil, nil + }, + mu: &sync.Mutex{}, + userIDToChanMu: &sync.Mutex{}, + userIDToChan: make(map[string]chan bool), + userIDToMutex: make(map[string]*sync.Mutex), + } + workerCh := make(chan spec.ServerName) + defer close(workerCh) + updater.workerChans[0] = workerCh + + // happy case + alice := "@alice:localhost" + aliceCh := updater.assignChannel(alice) + defer updater.clearChannel(alice) + + // failing case + bob := "@bob:" + unreachableServer + bobCh := updater.assignChannel(string(bob)) + defer updater.clearChannel(string(bob)) + + expectedServers := map[spec.ServerName]struct{}{ + "localhost": {}, + } + unexpectedServers := make(map[spec.ServerName]struct{}) + + go func() { + for serverName := range workerCh { + switch serverName { + case "localhost": + delete(expectedServers, serverName) + aliceCh <- true // unblock notifyWorkers + case unreachableServer: // this should not happen as it is "filtered" away by the blacklist + unexpectedServers[serverName] = struct{}{} + bobCh <- true + default: + unexpectedServers[serverName] = struct{}{} + } + } + }() + + // alice is not blacklisted + updater.notifyWorkers(alice) + // bob is blacklisted + updater.notifyWorkers(string(bob)) + + for server := range expectedServers { + t.Errorf("Server still in expectedServers map: %s", server) + } + + for server := range unexpectedServers { + t.Errorf("unexpected server in result: %s", server) + } +} diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go index 138b629d7..deb355718 100644 --- a/userapi/storage/postgres/cross_signing_keys_table.go +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -77,7 +77,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( for rows.Next() { var keyTypeInt int16 var keyData spec.Base64Bytes - if err := rows.Scan(&keyTypeInt, &keyData); err != nil { + if err = rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] @@ -86,6 +86,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( } r[keyType] = keyData } + err = rows.Err() return } diff --git a/userapi/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go index 61a381184..cba015e13 100644 --- a/userapi/storage/postgres/cross_signing_sigs_table.go +++ b/userapi/storage/postgres/cross_signing_sigs_table.go @@ -98,7 +98,7 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( var userID string var keyID gomatrixserverlib.KeyID var signature spec.Base64Bytes - if err := rows.Scan(&userID, &keyID, &signature); err != nil { + if err = rows.Scan(&userID, &keyID, &signature); err != nil { return nil, err } if _, ok := r[userID]; !ok { @@ -106,6 +106,7 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( } r[userID][keyID] = signature } + err = rows.Err() return } diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index 91a34c357..59944a125 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -162,5 +162,5 @@ func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api. roomData[key.SessionID] = key.KeyBackupSession result[key.RoomID] = roomData } - return result, nil + return result, rows.Err() } diff --git a/userapi/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go index a00494140..de3a9e9c8 100644 --- a/userapi/storage/postgres/key_changes_table.go +++ b/userapi/storage/postgres/key_changes_table.go @@ -115,7 +115,7 @@ func (s *keyChangesStatements) SelectKeyChanges( for rows.Next() { var userID string var offset int64 - if err := rows.Scan(&userID, &offset); err != nil { + if err = rows.Scan(&userID, &offset); err != nil { return nil, 0, err } if offset > latestOffset { @@ -123,5 +123,6 @@ func (s *keyChangesStatements) SelectKeyChanges( } userIDs = append(userIDs, userID) } + err = rows.Err() return } diff --git a/userapi/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go index 972a59147..a00f4d6f6 100644 --- a/userapi/storage/postgres/one_time_keys_table.go +++ b/userapi/storage/postgres/one_time_keys_table.go @@ -134,7 +134,7 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de } counts.KeyCount[algorithm] = count } - return counts, nil + return counts, rows.Err() } func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index e404c32f2..e4f55ed94 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -165,5 +165,5 @@ func (s *profilesStatements) SelectProfilesBySearch( profiles = append(profiles, profile) } } - return profiles, nil + return profiles, rows.Err() } diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 15b42a0a6..fc46061dc 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib/spec" @@ -94,6 +95,7 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( if err != nil { return } + defer internal.CloseAndLogIfError(ctx, rows, "SelectThreePIDsForLocalpart: failed to close rows") threepids = []authtypes.ThreePID{} for rows.Next() { @@ -107,7 +109,7 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( Medium: medium, }) } - + err = rows.Err() return } diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index 3a6367c45..240647b32 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib/spec" @@ -95,6 +96,7 @@ func (s *accountDataStatements) SelectAccountData( if err != nil { return nil, nil, err } + defer internal.CloseAndLogIfError(ctx, rows, "SelectAccountData: failed to close rows") global := map[string]json.RawMessage{} rooms := map[string]map[string]json.RawMessage{} @@ -118,7 +120,7 @@ func (s *accountDataStatements) SelectAccountData( } } - return global, rooms, nil + return global, rooms, rows.Err() } func (s *accountDataStatements) SelectAccountDataByType( diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go index 5c2ce7039..65b9ff1af 100644 --- a/userapi/storage/sqlite3/cross_signing_keys_table.go +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -76,7 +76,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( for rows.Next() { var keyTypeInt int16 var keyData spec.Base64Bytes - if err := rows.Scan(&keyTypeInt, &keyData); err != nil { + if err = rows.Scan(&keyTypeInt, &keyData); err != nil { return nil, err } keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] @@ -85,6 +85,7 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( } r[keyType] = keyData } + err = rows.Err() return } diff --git a/userapi/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go index 657264115..bf400a00e 100644 --- a/userapi/storage/sqlite3/cross_signing_sigs_table.go +++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go @@ -96,7 +96,7 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( var userID string var keyID gomatrixserverlib.KeyID var signature spec.Base64Bytes - if err := rows.Scan(&userID, &keyID, &signature); err != nil { + if err = rows.Scan(&userID, &keyID, &signature); err != nil { return nil, err } if _, ok := r[userID]; !ok { @@ -104,6 +104,7 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( } r[userID][keyID] = signature } + err = rows.Err() return } diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 23e823116..5ce285c87 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -296,6 +296,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( if err != nil { return devices, err } + defer internal.CloseAndLogIfError(ctx, rows, "SelectDevicesByLocalpart: failed to close rows") var dev api.Device var lastseents sql.NullInt64 @@ -325,7 +326,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( devices = append(devices, dev) } - return devices, nil + return devices, rows.Err() } func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index ed2746310..1cdaca180 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -162,5 +162,5 @@ func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api. roomData[key.SessionID] = key.KeyBackupSession result[key.RoomID] = roomData } - return result, nil + return result, rows.Err() } diff --git a/userapi/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go index 923bb57eb..7a4898cfb 100644 --- a/userapi/storage/sqlite3/key_changes_table.go +++ b/userapi/storage/sqlite3/key_changes_table.go @@ -113,7 +113,7 @@ func (s *keyChangesStatements) SelectKeyChanges( for rows.Next() { var userID string var offset int64 - if err := rows.Scan(&userID, &offset); err != nil { + if err = rows.Scan(&userID, &offset); err != nil { return nil, 0, err } if offset > latestOffset { @@ -121,5 +121,6 @@ func (s *keyChangesStatements) SelectKeyChanges( } userIDs = append(userIDs, userID) } + err = rows.Err() return } diff --git a/userapi/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go index a992d399c..2a5b1280b 100644 --- a/userapi/storage/sqlite3/one_time_keys_table.go +++ b/userapi/storage/sqlite3/one_time_keys_table.go @@ -140,7 +140,7 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de } counts.KeyCount[algorithm] = count } - return counts, nil + return counts, rows.Err() } func (s *oneTimeKeysStatements) InsertOneTimeKeys( diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index a20d7e848..7285110bc 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -173,5 +173,5 @@ func (s *profilesStatements) SelectProfilesBySearch( profiles = append(profiles, profile) } } - return profiles, nil + return profiles, rows.Err() } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index a46ee9ebb..5a789dfd7 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -1,6 +1,7 @@ package storage_test import ( + "bytes" "context" "encoding/json" "fmt" @@ -758,3 +759,53 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { } }) } + +func TestOneTimeKeys(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + userID := "@alice:localhost" + deviceID := "alice_device" + otk := api.OneTimeKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)}, + } + + // Add a one time key to the DB + _, err := db.StoreOneTimeKeys(ctx, otk) + MustNotError(t, err) + + // Check the count of one time keys is correct + count, err := db.OneTimeKeysCount(ctx, userID, deviceID) + MustNotError(t, err) + if count.KeyCount["curve25519"] != 1 { + t.Fatalf("Expected 1 key, got %d", count.KeyCount["curve25519"]) + } + + // Check the actual key contents are correct + keysJSON, err := db.ExistingOneTimeKeys(ctx, userID, deviceID, []string{"curve25519:KEY1"}) + MustNotError(t, err) + keyJSON, err := keysJSON["curve25519:KEY1"].MarshalJSON() + MustNotError(t, err) + if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) { + t.Fatalf("Existing keys do not match expected. Got %v", keysJSON["curve25519:KEY1"]) + } + + // Claim a one time key from the database. This should remove it from the database. + claimedKeys, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}}) + MustNotError(t, err) + + // Check the claimed key contents are correct + if !reflect.DeepEqual(claimedKeys[0], otk) { + t.Fatalf("Expected to claim stored key %v. Got %v", otk, claimedKeys[0]) + } + + // Check the count of one time keys is now zero + count, err = db.OneTimeKeysCount(ctx, userID, deviceID) + MustNotError(t, err) + if count.KeyCount["curve25519"] != 0 { + t.Fatalf("Expected 0 keys, got %d", count.KeyCount["curve25519"]) + } + }) +} diff --git a/userapi/userapi.go b/userapi/userapi.go index 6b6dac884..a1c9b94a9 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -18,10 +18,12 @@ import ( "time" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -46,6 +48,8 @@ func NewInternalAPI( natsInstance *jetstream.NATSInstance, rsAPI rsapi.UserRoomserverAPI, fedClient fedsenderapi.KeyserverFederationAPI, + enableMetrics bool, + blacklistedOrBackingOffFn func(s spec.ServerName) (*statistics.ServerStatistics, error), ) *internal.UserInternalAPI { js, _ := natsInstance.Prepare(processContext, &dendriteCfg.Global.JetStream) appServices := dendriteCfg.Derived.ApplicationServices @@ -99,7 +103,7 @@ func NewInternalAPI( FedClient: fedClient, } - updater := internal.NewDeviceListUpdater(processContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, dendriteCfg.Global.ServerName) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(processContext, keyDB, userAPI, keyChangeProducer, fedClient, dendriteCfg.UserAPI.WorkerCount, rsAPI, dendriteCfg.Global.ServerName, enableMetrics, blacklistedOrBackingOffFn) userAPI.Updater = updater // Remove users which we don't share a room with anymore if err := updater.CleanUp(); err != nil {