diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml
index 9b452d849..9fa6cf197 100644
--- a/.github/workflows/dendrite.yml
+++ b/.github/workflows/dendrite.yml
@@ -67,7 +67,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: golangci-lint
- uses: golangci/golangci-lint-action@v2
+ uses: golangci/golangci-lint-action@v3
# run go test with different go versions
test:
@@ -97,7 +97,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- go: ["1.18"]
+ go: ["1.18", "1.19"]
steps:
- uses: actions/checkout@v3
- name: Setup go
@@ -127,7 +127,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- go: ["1.18"]
+ go: ["1.18", "1.19"]
goos: ["linux"]
goarch: ["amd64", "386"]
steps:
@@ -151,6 +151,7 @@ jobs:
GOOS: ${{ matrix.goos }}
GOARCH: ${{ matrix.goarch }}
CGO_ENABLED: 1
+ CGO_CFLAGS: -fno-stack-protector
run: go build -trimpath -v -o "bin/" ./cmd/...
# build for Windows 64-bit
@@ -160,7 +161,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- go: ["1.18"]
+ go: ["1.18", "1.19"]
goos: ["windows"]
goarch: ["amd64"]
steps:
@@ -223,6 +224,31 @@ jobs:
- name: Test upgrade
run: ./dendrite-upgrade-tests --head .
+ # run database upgrade tests, skipping over one version
+ upgrade_test_direct:
+ name: Upgrade tests from HEAD-2
+ timeout-minutes: 20
+ needs: initial-tests-done
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Setup go
+ uses: actions/setup-go@v2
+ with:
+ go-version: "1.18"
+ - uses: actions/cache@v3
+ with:
+ path: |
+ ~/.cache/go-build
+ ~/go/pkg/mod
+ key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }}
+ restore-keys: |
+ ${{ runner.os }}-go-upgrade
+ - name: Build upgrade-tests
+ run: go build ./cmd/dendrite-upgrade-tests
+ - name: Test upgrade
+ run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head .
+
# run Sytest in different variations
sytest:
timeout-minutes: 20
@@ -359,7 +385,14 @@ jobs:
integration-tests-done:
name: Integration tests passed
- needs: [initial-tests-done, upgrade_test, sytest, complement]
+ needs:
+ [
+ initial-tests-done,
+ upgrade_test,
+ upgrade_test_direct,
+ sytest,
+ complement,
+ ]
runs-on: ubuntu-latest
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
steps:
diff --git a/CHANGES.md b/CHANGES.md
index 3df03b2f6..5dd8da362 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,5 +1,49 @@
# Changelog
+## Dendrite 0.9.1 (2022-08-03)
+
+### Fixes
+
+* Upgrades a dependency which caused issues building Dendrite with Go 1.19
+* The roomserver will no longer give up prematurely after failing to call `/state_ids`
+* Removes the faulty room info cache, which caused of a number of race conditions and occasional bugs (including when creating and joining rooms)
+* The media endpoint now sets the `Cache-Control` header correctly to prevent web-based clients from hitting media endpoints excessively
+* The sync API will now advance the PDU stream position correctly in all cases (contributed by [sergekh2](https://github.com/sergekh2))
+* The sync API will now delete the correct range of send-to-device messages when advancing the stream position
+* The device list `changed` key in the `/sync` response should now return the correct users
+* A data race when looking up missing state has been fixed
+* The `/send_join` API is now applying stronger validation to the received membership event
+
+## Dendrite 0.9.0 (2022-08-01)
+
+### Features
+
+* Dendrite now uses Ristretto for managing in-memory caches
+ * Should improve cache utilisation considerably over time by more intelligently selecting and managing cache entries compared to the previous LRU-based cache
+ * Defaults to a 1GB cache size if not configured otherwise
+ * The estimated cache size in memory and maximum age can now be configured with new [configuration options](https://github.com/matrix-org/dendrite/blob/e94ef84aaba30e12baf7f524c4e7a36d2fdeb189/dendrite-sample.monolith.yaml#L44-L61) to prevent unbounded cache growth
+* Added support for serving the `/.well-known/matrix/client` hint directly from Dendrite
+ * Configurable with the new [configuration option](https://github.com/matrix-org/dendrite/blob/e94ef84aaba30e12baf7f524c4e7a36d2fdeb189/dendrite-sample.monolith.yaml#L67-L69)
+* Refactored membership updater, which should eliminate some bugs caused by the membership table getting out of sync with the room state
+* The User API is now responsible for sending account data updates to other components, which may fix some races and duplicate account data events
+* Optimised database query for checking whether a remote server is allowed to request an event over federation without using anywhere near as much CPU time (PostgreSQL only)
+* Database migrations have been refactored to eliminate some problems that were present with `goose` and upgrading from older Dendrite versions
+* Media fetching will now use the `/v3` endpoints for downloading media from remote homeservers
+* HTTP 404 and HTTP 405 errors from the client-facing APIs should now be returned with CORS headers so that web-based clients do not produce incorrect access control warnings for unknown endpoints
+* Some preparation work for full history visibility support
+
+### Fixes
+
+* Fixes a crash that could occur during event redaction
+* The `/members` endpoint will no longer incorrectly return HTTP 500 as a result of some invite events
+* Send-to-device messages should now be ordered more reliably and the last position in the stream updated correctly
+* Parsing of appservice configuration files is now less strict (contributed by [Kab1r](https://github.com/Kab1r))
+* The sync API should now identify shared users correctly when waking up for E2EE key changes
+* The federation `/state` endpoint will now return a HTTP 403 when the state before an event isn't known instead of a HTTP 500
+* Presence timestamps should now be calculated with the correct precision
+* A race condition in the roomserver's room info has been fixed
+* A race condition in the sync API has been fixed
+
## Dendrite 0.8.9 (2022-07-01)
### Features
diff --git a/README.md b/README.md
index 8f54db7b7..10e2b1b86 100644
--- a/README.md
+++ b/README.md
@@ -21,8 +21,7 @@ As of October 2020 (current [progress below](#progress)), Dendrite has now enter
This does not mean:
- Dendrite is bug-free. It has not yet been battle-tested in the real world and so will be error prone initially.
-- All of the CS/Federation APIs are implemented. We are tracking progress via a script called 'Are We Synapse Yet?'. In particular,
- presence and push notifications are entirely missing from Dendrite. See [CHANGES.md](CHANGES.md) for updates.
+- Dendrite is feature-complete. There may be client or federation APIs that are not implemented.
- Dendrite is ready for massive homeserver deployments. You cannot shard each microservice, only run each one on a different machine.
Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices.
@@ -36,6 +35,9 @@ If you have further questions, please take a look at [our FAQ](docs/FAQ.md) or j
## Requirements
+See the [Planning your Installation](https://matrix-org.github.io/dendrite/installation/planning) page for
+more information on requirements.
+
To build Dendrite, you will need Go 1.18 or later.
For a usable federating Dendrite deployment, you will also need:
@@ -83,11 +85,11 @@ $ ./bin/create-account --config dendrite.yaml -username alice
Then point your favourite Matrix client at `http://localhost:8008` or `https://localhost:8448`.
-## Progress
+## Progress
We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver
test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it
-updates with CI. As of April 2022 we're at around 83% CS API coverage and 95% Federation coverage, though check
+updates with CI. As of August 2022 we're at around 83% CS API coverage and 95% Federation coverage, though check
CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse
servers such as matrix.org reasonably well, although there are still some missing features (like Search).
@@ -119,53 +121,8 @@ We would be grateful for any help on issues marked as
all have related Sytests which need to pass in order for the issue to be closed. Once you've written your
code, you can quickly run Sytest to ensure that the test names are now passing.
-For example, if the test `Local device key changes get to remote servers` was marked as failing, find the
-test file (e.g via `grep` or via the
-[CI log output](https://buildkite.com/matrix-dot-org/dendrite/builds/2826#39cff5de-e032-4ad0-ad26-f819e6919c42)
-it's `tests/50federation/40devicelists.pl` ) then to run Sytest:
-
-```
-docker run --rm --name sytest
--v "/Users/kegan/github/sytest:/sytest"
--v "/Users/kegan/github/dendrite:/src"
--v "/Users/kegan/logs:/logs"
--v "/Users/kegan/go/:/gopath"
--e "POSTGRES=1" -e "DENDRITE_TRACE_HTTP=1"
-matrixdotorg/sytest-dendrite:latest tests/50federation/40devicelists.pl
-```
-
-See [sytest.md](docs/sytest.md) for the full description of these flags.
-
-You can try running sytest outside of docker for faster runs, but the dependencies can be temperamental
-and we recommend using docker where possible.
-
-```
-cd sytest
-export PERL5LIB=$HOME/lib/perl5
-export PERL_MB_OPT=--install_base=$HOME
-export PERL_MM_OPT=INSTALL_BASE=$HOME
-./install-deps.pl
-
-./run-tests.pl -I Dendrite::Monolith -d $PATH_TO_DENDRITE_BINARIES
-```
-
-Sometimes Sytest is testing the wrong thing or is flakey, so it will need to be patched.
-Ask on `#dendrite-dev:matrix.org` if you think this is the case for you and we'll be happy to help.
-
-If you're new to the project, see [CONTRIBUTING.md](docs/CONTRIBUTING.md) to get up to speed then
+If you're new to the project, see our
+[Contributing page](https://matrix-org.github.io/dendrite/development/contributing) to get up to speed, then
look for [Good First Issues](https://github.com/matrix-org/dendrite/labels/good%20first%20issue). If you're
familiar with the project, look for [Help Wanted](https://github.com/matrix-org/dendrite/labels/help-wanted)
issues.
-
-## Hardware requirements
-
-Dendrite in Monolith + SQLite works in a range of environments including iOS and in-browser via WASM.
-
-For small homeserver installations joined on ~10s rooms on matrix.org with ~100s of users in those rooms, including some
-encrypted rooms:
-
-- Memory: uses around 100MB of RAM, with peaks at around 200MB.
-- Disk space: After a few months of usage, the database grew to around 2GB (in Monolith mode).
-- CPU: Brief spikes when processing events, typically idles at 1% CPU.
-
-This means Dendrite should comfortably work on things like Raspberry Pis.
diff --git a/appservice/inthttp/client.go b/appservice/inthttp/client.go
index 0a8baea99..3ae2c9278 100644
--- a/appservice/inthttp/client.go
+++ b/appservice/inthttp/client.go
@@ -7,7 +7,6 @@ import (
"github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/internal/httputil"
- "github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP APIs
@@ -42,11 +41,10 @@ func (h *httpAppServiceQueryAPI) RoomAliasExists(
request *api.RoomAliasExistsRequest,
response *api.RoomAliasExistsResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceRoomAliasExists")
- defer span.Finish()
-
- apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "RoomAliasExists", h.appserviceURL+AppServiceRoomAliasExistsPath,
+ h.httpClient, ctx, request, response,
+ )
}
// UserIDExists implements AppServiceQueryAPI
@@ -55,9 +53,8 @@ func (h *httpAppServiceQueryAPI) UserIDExists(
request *api.UserIDExistsRequest,
response *api.UserIDExistsResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceUserIDExists")
- defer span.Finish()
-
- apiURL := h.appserviceURL + AppServiceUserIDExistsPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "UserIDExists", h.appserviceURL+AppServiceUserIDExistsPath,
+ h.httpClient, ctx, request, response,
+ )
}
diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go
index 645b43871..01d9f9895 100644
--- a/appservice/inthttp/server.go
+++ b/appservice/inthttp/server.go
@@ -1,43 +1,20 @@
package inthttp
import (
- "encoding/json"
- "net/http"
-
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/internal/httputil"
- "github.com/matrix-org/util"
)
// AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux.
func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
AppServiceRoomAliasExistsPath,
- httputil.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse {
- var request api.RoomAliasExistsRequest
- var response api.RoomAliasExistsResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := a.RoomAliasExists(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", a.RoomAliasExists),
)
+
internalAPIMux.Handle(
AppServiceUserIDExistsPath,
- httputil.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse {
- var request api.UserIDExistsRequest
- var response api.UserIDExistsResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := a.UserIDExists(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists),
)
}
diff --git a/build/docker/Dockerfile.monolith b/build/docker/Dockerfile.monolith
index 891a3a9e0..bb02934cd 100644
--- a/build/docker/Dockerfile.monolith
+++ b/build/docker/Dockerfile.monolith
@@ -8,7 +8,6 @@ COPY . /build
RUN mkdir -p bin
RUN go build -trimpath -o bin/ ./cmd/dendrite-monolith-server
-RUN go build -trimpath -o bin/ ./cmd/goose
RUN go build -trimpath -o bin/ ./cmd/create-account
RUN go build -trimpath -o bin/ ./cmd/generate-keys
diff --git a/build/docker/Dockerfile.polylith b/build/docker/Dockerfile.polylith
index ffdc35586..166ea99cb 100644
--- a/build/docker/Dockerfile.polylith
+++ b/build/docker/Dockerfile.polylith
@@ -8,7 +8,6 @@ COPY . /build
RUN mkdir -p bin
RUN go build -trimpath -o bin/ ./cmd/dendrite-polylith-multi
-RUN go build -trimpath -o bin/ ./cmd/goose
RUN go build -trimpath -o bin/ ./cmd/create-account
RUN go build -trimpath -o bin/ ./cmd/generate-keys
diff --git a/build/scripts/build-test-lint.sh b/build/scripts/build-test-lint.sh
index 8f0b775b1..32f89c076 100755
--- a/build/scripts/build-test-lint.sh
+++ b/build/scripts/build-test-lint.sh
@@ -13,4 +13,4 @@ go build ./cmd/...
./build/scripts/find-lint.sh
echo "Testing..."
-go test -v ./...
+go test --race -v ./...
diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go
index 5f51c662a..5467e814d 100644
--- a/clientapi/auth/login.go
+++ b/clientapi/auth/login.go
@@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"io"
- "io/ioutil"
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@@ -34,7 +33,7 @@ import (
// If the final return value is non-nil, an error occurred and the cleanup function
// is nil.
func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserLoginAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) {
- reqBytes, err := ioutil.ReadAll(r)
+ reqBytes, err := io.ReadAll(r)
if err != nil {
err := &util.JSONResponse{
Code: http.StatusBadRequest,
diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go
index f550c29bb..080d4d9fa 100644
--- a/clientapi/clientapi.go
+++ b/clientapi/clientapi.go
@@ -48,7 +48,6 @@ func AddPublicRoutes(
syncProducer := &producers.SyncAPIProducer{
JetStream: js,
- TopicClientData: cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData),
TopicReceiptEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent),
TopicSendToDeviceEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
TopicTypingEvent: cfg.Matrix.JetStream.Prefixed(jetstream.OutputTypingEvent),
@@ -59,6 +58,7 @@ func AddPublicRoutes(
routing.Setup(
base.PublicClientAPIMux,
+ base.PublicWellKnownAPIMux,
base.SynapseAdminMux,
base.DendriteAdminMux,
cfg, rsAPI, asAPI,
diff --git a/clientapi/httputil/httputil.go b/clientapi/httputil/httputil.go
index b47701368..74f84f1e7 100644
--- a/clientapi/httputil/httputil.go
+++ b/clientapi/httputil/httputil.go
@@ -16,7 +16,7 @@ package httputil
import (
"encoding/json"
- "io/ioutil"
+ "io"
"net/http"
"unicode/utf8"
@@ -29,9 +29,9 @@ import (
func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONResponse {
// encoding/json allows invalid utf-8, matrix does not
// https://matrix.org/docs/spec/client_server/r0.6.1#api-standards
- body, err := ioutil.ReadAll(req.Body)
+ body, err := io.ReadAll(req.Body)
if err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("ioutil.ReadAll failed")
+ util.GetLogger(req.Context()).WithError(err).Error("io.ReadAll failed")
resp := jsonerror.InternalServerError()
return &resp
}
diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go
index 70bac61dc..be7d13a96 100644
--- a/clientapi/jsonerror/jsonerror.go
+++ b/clientapi/jsonerror/jsonerror.go
@@ -15,11 +15,13 @@
package jsonerror
import (
+ "context"
"fmt"
"net/http"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
)
// MatrixError represents the "standard error response" in Matrix.
@@ -213,3 +215,15 @@ func NotTrusted(serverName string) *MatrixError {
Err: fmt.Sprintf("Untrusted server '%s'", serverName),
}
}
+
+// InternalAPIError is returned when Dendrite failed to reach an internal API.
+func InternalAPIError(ctx context.Context, err error) util.JSONResponse {
+ logrus.WithContext(ctx).WithError(err).Error("Error reaching an internal API")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: &MatrixError{
+ ErrCode: "M_INTERNAL_SERVER_ERROR",
+ Err: "Dendrite encountered an error reaching an internal API.",
+ },
+ }
+}
diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go
index 0ac637793..5933ce1a8 100644
--- a/clientapi/producers/syncapi.go
+++ b/clientapi/producers/syncapi.go
@@ -21,7 +21,6 @@ import (
"strconv"
"time"
- "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
@@ -32,7 +31,6 @@ import (
// SyncAPIProducer produces events for the sync API server to consume
type SyncAPIProducer struct {
- TopicClientData string
TopicReceiptEvent string
TopicSendToDeviceEvent string
TopicTypingEvent string
@@ -42,36 +40,6 @@ type SyncAPIProducer struct {
UserAPI userapi.ClientUserAPI
}
-// SendData sends account data to the sync API server
-func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON, ignoredUsers *types.IgnoredUsers) error {
- m := &nats.Msg{
- Subject: p.TopicClientData,
- Header: nats.Header{},
- }
- m.Header.Set(jetstream.UserID, userID)
-
- data := eventutil.AccountData{
- RoomID: roomID,
- Type: dataType,
- ReadMarker: readMarker,
- IgnoredUsers: ignoredUsers,
- }
- var err error
- m.Data, err = json.Marshal(data)
- if err != nil {
- return err
- }
-
- log.WithFields(log.Fields{
- "user_id": userID,
- "room_id": roomID,
- "data_type": dataType,
- }).Tracef("Producing to topic '%s'", p.TopicClientData)
-
- _, err = p.JetStream.PublishMsg(m)
- return err
-}
-
func (p *SyncAPIProducer) SendReceipt(
ctx context.Context,
userID, roomID, eventID, receiptType string, timestamp gomatrixserverlib.Timestamp,
diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go
index a5a3014ab..b28f0bb1f 100644
--- a/clientapi/routing/account_data.go
+++ b/clientapi/routing/account_data.go
@@ -17,7 +17,7 @@ package routing
import (
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
"github.com/matrix-org/dendrite/clientapi/httputil"
@@ -25,7 +25,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/internal/eventutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
@@ -102,9 +101,9 @@ func SaveAccountData(
}
}
- body, err := ioutil.ReadAll(req.Body)
+ body, err := io.ReadAll(req.Body)
if err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("ioutil.ReadAll failed")
+ util.GetLogger(req.Context()).WithError(err).Error("io.ReadAll failed")
return jsonerror.InternalServerError()
}
@@ -127,18 +126,6 @@ func SaveAccountData(
return util.ErrorResponse(err)
}
- var ignoredUsers *types.IgnoredUsers
- if dataType == "m.ignored_user_list" {
- ignoredUsers = &types.IgnoredUsers{}
- _ = json.Unmarshal(body, ignoredUsers)
- }
-
- // TODO: user API should do this since it's account data
- if err := syncProducer.SendData(userID, roomID, dataType, nil, ignoredUsers); err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
- return jsonerror.InternalServerError()
- }
-
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
@@ -191,11 +178,6 @@ func SaveReadMarker(
return util.ErrorResponse(err)
}
- if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r, nil); err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
- return jsonerror.InternalServerError()
- }
-
// Handle the read receipt that may be included in the read marker
if r.Read != "" {
return SetReceipt(req, syncProducer, device, roomID, "m.read", r.Read)
diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go
index 523b88c99..a8dd0e64f 100644
--- a/clientapi/routing/admin.go
+++ b/clientapi/routing/admin.go
@@ -30,13 +30,15 @@ func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserv
}
}
res := &roomserverAPI.PerformAdminEvacuateRoomResponse{}
- rsAPI.PerformAdminEvacuateRoom(
+ if err := rsAPI.PerformAdminEvacuateRoom(
req.Context(),
&roomserverAPI.PerformAdminEvacuateRoomRequest{
RoomID: roomID,
},
res,
- )
+ ); err != nil {
+ return util.ErrorResponse(err)
+ }
if err := res.Error; err != nil {
return err.JSONResponse()
}
@@ -67,13 +69,15 @@ func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserv
}
}
res := &roomserverAPI.PerformAdminEvacuateUserResponse{}
- rsAPI.PerformAdminEvacuateUser(
+ if err := rsAPI.PerformAdminEvacuateUser(
req.Context(),
&roomserverAPI.PerformAdminEvacuateUserRequest{
UserID: userID,
},
res,
- )
+ ); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
if err := res.Error; err != nil {
return err.JSONResponse()
}
diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go
index 3f92b7ba6..874908639 100644
--- a/clientapi/routing/createroom.go
+++ b/clientapi/routing/createroom.go
@@ -556,10 +556,12 @@ func createRoom(
if r.Visibility == "public" {
// expose this room in the published room list
var pubRes roomserverAPI.PerformPublishResponse
- rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
+ if err := rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{
RoomID: roomID,
Visibility: "public",
- }, &pubRes)
+ }, &pubRes); err != nil {
+ return jsonerror.InternalAPIError(ctx, err)
+ }
if pubRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public")
diff --git a/clientapi/routing/deactivate.go b/clientapi/routing/deactivate.go
index c8aa6a3bc..f213db7f3 100644
--- a/clientapi/routing/deactivate.go
+++ b/clientapi/routing/deactivate.go
@@ -1,7 +1,7 @@
package routing
import (
- "io/ioutil"
+ "io"
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth"
@@ -20,7 +20,7 @@ func Deactivate(
) util.JSONResponse {
ctx := req.Context()
defer req.Body.Close() // nolint:errcheck
- bodyBytes, err := ioutil.ReadAll(req.Body)
+ bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go
index bb1cf47bd..e3a02661c 100644
--- a/clientapi/routing/device.go
+++ b/clientapi/routing/device.go
@@ -15,7 +15,7 @@
package routing
import (
- "io/ioutil"
+ "io"
"net"
"net/http"
@@ -175,7 +175,7 @@ func DeleteDeviceById(
}()
ctx := req.Context()
defer req.Body.Close() // nolint:errcheck
- bodyBytes, err := ioutil.ReadAll(req.Body)
+ bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go
index 53ba3f190..836d9e152 100644
--- a/clientapi/routing/directory.go
+++ b/clientapi/routing/directory.go
@@ -302,10 +302,12 @@ func SetVisibility(
}
var publishRes roomserverAPI.PerformPublishResponse
- rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
+ if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{
RoomID: roomID,
Visibility: v.Visibility,
- }, &publishRes)
+ }, &publishRes); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
if publishRes.Error != nil {
util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed")
return publishRes.Error.JSONResponse()
diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go
index c3e6141b2..8ddb3267a 100644
--- a/clientapi/routing/directory_public.go
+++ b/clientapi/routing/directory_public.go
@@ -23,13 +23,14 @@ import (
"strings"
"sync"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+
"github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
)
var (
@@ -196,14 +197,14 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO
// sliceInto returns a subslice of `slice` which honours the since/limit values given.
//
-// 0 1 2 3 4 5 6 index
-// [A, B, C, D, E, F, G] slice
+// 0 1 2 3 4 5 6 index
+// [A, B, C, D, E, F, G] slice
//
-// limit=3 => A,B,C (prev='', next='3')
-// limit=3&since=3 => D,E,F (prev='0', next='6')
-// limit=3&since=6 => G (prev='3', next='')
+// limit=3 => A,B,C (prev='', next='3')
+// limit=3&since=3 => D,E,F (prev='0', next='6')
+// limit=3&since=6 => G (prev='3', next='')
//
-// A value of '-1' for prev/next indicates no position.
+// A value of '-1' for prev/next indicates no position.
func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int16) (subset []gomatrixserverlib.PublicRoom, prev, next int) {
prev = -1
next = -1
diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go
index 4e6acebc3..c50e552bd 100644
--- a/clientapi/routing/joinroom.go
+++ b/clientapi/routing/joinroom.go
@@ -81,8 +81,9 @@ func JoinRoomByIDOrAlias(
done := make(chan util.JSONResponse, 1)
go func() {
defer close(done)
- rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes)
- if joinRes.Error != nil {
+ if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
+ done <- jsonerror.InternalAPIError(req.Context(), err)
+ } else if joinRes.Error != nil {
done <- joinRes.Error.JSONResponse()
} else {
done <- util.JSONResponse{
diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go
index 28c80415b..b6f8fe1b9 100644
--- a/clientapi/routing/key_backup.go
+++ b/clientapi/routing/key_backup.go
@@ -91,10 +91,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse
- userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
+ if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID,
Version: version,
- }, &queryResp)
+ }, &queryResp); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
}
@@ -233,13 +235,15 @@ func GetBackupKeys(
req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string,
) util.JSONResponse {
var queryResp userapi.QueryKeyBackupResponse
- userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
+ if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
UserID: device.UserID,
Version: version,
ReturnKeys: true,
KeysForRoomID: roomID,
KeysForSessionID: sessionID,
- }, &queryResp)
+ }, &queryResp); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
if queryResp.Error != "" {
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
}
diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go
index 8fbb86f7a..2570db09c 100644
--- a/clientapi/routing/key_crosssigning.go
+++ b/clientapi/routing/key_crosssigning.go
@@ -72,7 +72,9 @@ func UploadCrossSigningDeviceKeys(
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
uploadReq.UserID = device.UserID
- keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)
+ if err := keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
if err := uploadRes.Error; err != nil {
switch {
@@ -114,7 +116,9 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie
}
uploadReq.UserID = device.UserID
- keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes)
+ if err := keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
if err := uploadRes.Error; err != nil {
switch {
diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go
index fdda34a53..b7a76b47e 100644
--- a/clientapi/routing/keys.go
+++ b/clientapi/routing/keys.go
@@ -62,7 +62,9 @@ func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devi
}
var uploadRes api.PerformUploadKeysResponse
- keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes)
+ if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
+ return util.ErrorResponse(err)
+ }
if uploadRes.Error != nil {
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
return jsonerror.InternalServerError()
@@ -107,12 +109,14 @@ func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devic
return *resErr
}
queryRes := api.QueryKeysResponse{}
- keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
+ if err := keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{
UserID: device.UserID,
UserToDevices: r.DeviceKeys,
Timeout: r.GetTimeout(),
// TODO: Token?
- }, &queryRes)
+ }, &queryRes); err != nil {
+ return util.ErrorResponse(err)
+ }
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
@@ -145,10 +149,12 @@ func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse {
return *resErr
}
claimRes := api.PerformClaimKeysResponse{}
- keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{
+ if err := keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: r.OneTimeKeys,
Timeout: r.GetTimeout(),
- }, &claimRes)
+ }, &claimRes); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
if claimRes.Error != nil {
util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys")
return jsonerror.InternalServerError()
diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go
index d0eeccf17..9b2592eb5 100644
--- a/clientapi/routing/peekroom.go
+++ b/clientapi/routing/peekroom.go
@@ -17,6 +17,7 @@ package routing
import (
"net/http"
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@@ -54,7 +55,9 @@ func PeekRoomByIDOrAlias(
}
// Ask the roomserver to perform the peek.
- rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes)
+ if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil {
+ return util.ErrorResponse(err)
+ }
if peekRes.Error != nil {
return peekRes.Error.JSONResponse()
}
@@ -89,7 +92,9 @@ func UnpeekRoomByID(
}
unpeekRes := roomserverAPI.PerformUnpeekResponse{}
- rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes)
+ if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
if unpeekRes.Error != nil {
return unpeekRes.Error.JSONResponse()
}
diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go
index c4ac0f2e7..af0329a48 100644
--- a/clientapi/routing/register.go
+++ b/clientapi/routing/register.go
@@ -19,7 +19,7 @@ import (
"context"
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
"net/url"
"regexp"
@@ -371,7 +371,7 @@ func validateRecaptcha(
// Grab the body of the response from the captcha server
var r recaptchaResponse
- body, err := ioutil.ReadAll(resp.Body)
+ body, err := io.ReadAll(resp.Body)
if err != nil {
return &util.JSONResponse{
Code: http.StatusGatewayTimeout,
@@ -539,7 +539,7 @@ func Register(
cfg *config.ClientAPI,
) util.JSONResponse {
defer req.Body.Close() // nolint: errcheck
- reqBody, err := ioutil.ReadAll(req.Body)
+ reqBody, err := io.ReadAll(req.Body)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
diff --git a/clientapi/routing/register_secret_test.go b/clientapi/routing/register_secret_test.go
index e702b2152..a2ed35853 100644
--- a/clientapi/routing/register_secret_test.go
+++ b/clientapi/routing/register_secret_test.go
@@ -2,7 +2,7 @@ package routing
import (
"bytes"
- "io/ioutil"
+ "io"
"testing"
"github.com/patrickmn/go-cache"
@@ -13,7 +13,7 @@ func TestSharedSecretRegister(t *testing.T) {
jsonStr := []byte(`{"admin":false,"mac":"f1ba8d37123866fd659b40de4bad9b0f8965c565","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice"}`)
sharedSecret := "dendritetest"
- req, err := NewSharedSecretRegistrationRequest(ioutil.NopCloser(bytes.NewBuffer(jsonStr)))
+ req, err := NewSharedSecretRegistrationRequest(io.NopCloser(bytes.NewBuffer(jsonStr)))
if err != nil {
t.Fatalf("failed to read request: %s", err)
}
diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go
index 039289569..92b9e6655 100644
--- a/clientapi/routing/room_tagging.go
+++ b/clientapi/routing/room_tagging.go
@@ -18,8 +18,6 @@ import (
"encoding/json"
"net/http"
- "github.com/sirupsen/logrus"
-
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
@@ -98,10 +96,6 @@ func PutTag(
return jsonerror.InternalServerError()
}
- if err = syncProducer.SendData(userID, roomID, "m.tag", nil, nil); err != nil {
- logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
- }
-
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
@@ -150,11 +144,6 @@ func DeleteTag(
return jsonerror.InternalServerError()
}
- // TODO: user API should do this since it's account data
- if err := syncProducer.SendData(userID, roomID, "m.tag", nil, nil); err != nil {
- logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
- }
-
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index 0460850ef..ced4fdbcf 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -48,7 +48,7 @@ import (
// applied:
// nolint: gocyclo
func Setup(
- publicAPIMux, synapseAdminRouter, dendriteAdminRouter *mux.Router,
+ publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router,
cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI,
@@ -74,6 +74,26 @@ func Setup(
unstableFeatures["org.matrix."+msc] = true
}
+ if cfg.Matrix.WellKnownClientName != "" {
+ logrus.Infof("Setting m.homeserver base_url as %s at /.well-known/matrix/client", cfg.Matrix.WellKnownClientName)
+ wkMux.Handle("/client", httputil.MakeExternalAPI("wellknown", func(r *http.Request) util.JSONResponse {
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: struct {
+ HomeserverName struct {
+ BaseUrl string `json:"base_url"`
+ } `json:"m.homeserver"`
+ }{
+ HomeserverName: struct {
+ BaseUrl string `json:"base_url"`
+ }{
+ BaseUrl: cfg.Matrix.WellKnownClientName,
+ },
+ },
+ }
+ })).Methods(http.MethodGet, http.MethodOptions)
+ }
+
publicAPIMux.Handle("/versions",
httputil.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{
diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go
index 2e864adef..85f1053f3 100644
--- a/clientapi/routing/sendevent.go
+++ b/clientapi/routing/sendevent.go
@@ -63,9 +63,10 @@ var sendEventDuration = prometheus.NewHistogramVec(
)
// SendEvent implements:
-// /rooms/{roomID}/send/{eventType}
-// /rooms/{roomID}/send/{eventType}/{txnID}
-// /rooms/{roomID}/state/{eventType}/{stateKey}
+//
+// /rooms/{roomID}/send/{eventType}
+// /rooms/{roomID}/send/{eventType}/{txnID}
+// /rooms/{roomID}/state/{eventType}/{stateKey}
func SendEvent(
req *http.Request,
device *userapi.Device,
diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go
index 94b658ee3..4b7989ecb 100644
--- a/clientapi/routing/threepid.go
+++ b/clientapi/routing/threepid.go
@@ -38,8 +38,9 @@ type threePIDsResponse struct {
}
// RequestEmailToken implements:
-// POST /account/3pid/email/requestToken
-// POST /register/email/requestToken
+//
+// POST /account/3pid/email/requestToken
+// POST /register/email/requestToken
func RequestEmailToken(req *http.Request, threePIDAPI api.ClientUserAPI, cfg *config.ClientAPI) util.JSONResponse {
var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go
index 744e2d889..34c7eb004 100644
--- a/clientapi/routing/upgrade_room.go
+++ b/clientapi/routing/upgrade_room.go
@@ -64,7 +64,9 @@ func UpgradeRoom(
}
upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{}
- rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp)
+ if err := rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
if upgradeResp.Error != nil {
if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom {
diff --git a/clientapi/routing/voip.go b/clientapi/routing/voip.go
index c7ddaabcf..f0f69ce3c 100644
--- a/clientapi/routing/voip.go
+++ b/clientapi/routing/voip.go
@@ -22,15 +22,17 @@ import (
"net/http"
"time"
+ "github.com/matrix-org/gomatrix"
+ "github.com/matrix-org/util"
+
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrix"
- "github.com/matrix-org/util"
)
// RequestTurnServer implements:
-// GET /voip/turnServer
+//
+// GET /voip/turnServer
func RequestTurnServer(req *http.Request, device *api.Device, cfg *config.ClientAPI) util.JSONResponse {
turnConfig := cfg.TURN
diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go
index 7f6d5105e..92179a049 100644
--- a/cmd/create-account/main.go
+++ b/cmd/create-account/main.go
@@ -19,7 +19,6 @@ import (
"flag"
"fmt"
"io"
- "io/ioutil"
"os"
"regexp"
"strings"
@@ -157,7 +156,7 @@ func main() {
func getPassword(password, pwdFile string, pwdStdin bool, r io.Reader) (string, error) {
// read password from file
if pwdFile != "" {
- pw, err := ioutil.ReadFile(pwdFile)
+ pw, err := os.ReadFile(pwdFile)
if err != nil {
return "", fmt.Errorf("Unable to read password from file: %v", err)
}
@@ -166,7 +165,7 @@ func getPassword(password, pwdFile string, pwdStdin bool, r io.Reader) (string,
// read password from stdin
if pwdStdin {
- data, err := ioutil.ReadAll(r)
+ data, err := io.ReadAll(r)
if err != nil {
return "", fmt.Errorf("Unable to read password from stdin: %v", err)
}
diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go
index 8fa935ddf..75f29fe27 100644
--- a/cmd/dendrite-demo-pinecone/main.go
+++ b/cmd/dendrite-demo-pinecone/main.go
@@ -21,7 +21,6 @@ import (
"encoding/hex"
"flag"
"fmt"
- "io/ioutil"
"net"
"net/http"
"os"
@@ -76,11 +75,11 @@ func main() {
if pk, sk, err = ed25519.GenerateKey(nil); err != nil {
panic(err)
}
- if err = ioutil.WriteFile(keyfile, sk, 0644); err != nil {
+ if err = os.WriteFile(keyfile, sk, 0644); err != nil {
panic(err)
}
} else if err == nil {
- if sk, err = ioutil.ReadFile(keyfile); err != nil {
+ if sk, err = os.ReadFile(keyfile); err != nil {
panic(err)
}
if len(sk) != ed25519.PrivateKeySize {
diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/node.go b/cmd/dendrite-demo-yggdrasil/yggconn/node.go
index d93272e2e..ff3c73ec8 100644
--- a/cmd/dendrite-demo-yggdrasil/yggconn/node.go
+++ b/cmd/dendrite-demo-yggdrasil/yggconn/node.go
@@ -20,7 +20,6 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
- "io/ioutil"
"log"
"net"
"os"
@@ -69,7 +68,7 @@ func Setup(instanceName, storageDirectory, peerURI string) (*Node, error) {
yggfile := fmt.Sprintf("%s/%s-yggdrasil.conf", storageDirectory, instanceName)
if _, err := os.Stat(yggfile); !os.IsNotExist(err) {
- yggconf, e := ioutil.ReadFile(yggfile)
+ yggconf, e := os.ReadFile(yggfile)
if e != nil {
panic(err)
}
@@ -88,7 +87,7 @@ func Setup(instanceName, storageDirectory, peerURI string) (*Node, error) {
if err != nil {
panic(err)
}
- if e := ioutil.WriteFile(yggfile, j, 0600); e != nil {
+ if e := os.WriteFile(yggfile, j, 0600); e != nil {
n.log.Printf("Couldn't write private key to file '%s': %s\n", yggfile, e)
}
diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go
index cabd07e70..dce22472d 100644
--- a/cmd/dendrite-upgrade-tests/main.go
+++ b/cmd/dendrite-upgrade-tests/main.go
@@ -6,7 +6,7 @@ import (
"encoding/json"
"flag"
"fmt"
- "io/ioutil"
+ "io"
"log"
"net/http"
"os"
@@ -37,6 +37,7 @@ var (
flagBuildConcurrency = flag.Int("build-concurrency", runtime.NumCPU(), "The amount of build concurrency when building images")
flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github")
flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.")
+ flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done")
alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+")
)
@@ -46,7 +47,7 @@ const HEAD = "HEAD"
// We cannot use the dockerfile associated with the repo with each version sadly due to changes in
// Docker versions. Specifically, earlier Dendrite versions are incompatible with newer Docker clients
// due to the error:
-// When using COPY with more than one source file, the destination must be a directory and end with a /
+// When using COPY with more than one source file, the destination must be a directory and end with a /
// We need to run a postgres anyway, so use the dockerfile associated with Complement instead.
const Dockerfile = `FROM golang:1.18-stretch as build
RUN apt-get update && apt-get install -y postgresql
@@ -94,7 +95,9 @@ CMD /build/run_dendrite.sh `
const dendriteUpgradeTestLabel = "dendrite_upgrade_test"
// downloadArchive downloads an arbitrary github archive of the form:
-// https://github.com/matrix-org/dendrite/archive/v0.3.11.tar.gz
+//
+// https://github.com/matrix-org/dendrite/archive/v0.3.11.tar.gz
+//
// and re-tarballs it without the top-level directory which contains branch information. It inserts
// the contents of `dockerfile` as a root file `Dockerfile` in the re-tarballed directory such that
// you can directly feed the retarballed archive to `ImageBuild` to have it run said dockerfile.
@@ -125,7 +128,7 @@ func downloadArchive(cli *http.Client, tmpDir, archiveURL string, dockerfile []b
return nil, err
}
// add top level Dockerfile
- err = ioutil.WriteFile(path.Join(tmpDir, "Dockerfile"), dockerfile, os.ModePerm)
+ err = os.WriteFile(path.Join(tmpDir, "Dockerfile"), dockerfile, os.ModePerm)
if err != nil {
return nil, fmt.Errorf("failed to inject /Dockerfile: %w", err)
}
@@ -147,7 +150,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
if branchOrTagName == HEAD && *flagHead != "" {
log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead)
// add top level Dockerfile
- err = ioutil.WriteFile(path.Join(*flagHead, "Dockerfile"), []byte(Dockerfile), os.ModePerm)
+ err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), []byte(Dockerfile), os.ModePerm)
if err != nil {
return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err)
}
@@ -229,7 +232,7 @@ func getAndSortVersionsFromGithub(httpClient *http.Client) (semVers []*semver.Ve
return semVers, nil
}
-func calculateVersions(cli *http.Client, from, to string) []string {
+func calculateVersions(cli *http.Client, from, to string, direct bool) []string {
semvers, err := getAndSortVersionsFromGithub(cli)
if err != nil {
log.Fatalf("failed to collect semvers from github: %s", err)
@@ -284,6 +287,9 @@ func calculateVersions(cli *http.Client, from, to string) []string {
if to == HEAD {
versions = append(versions, HEAD)
}
+ if direct {
+ versions = []string{versions[0], versions[len(versions)-1]}
+ }
return versions
}
@@ -382,7 +388,7 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string)
})
// ignore errors when cannot get logs, it's just for debugging anyways
if err == nil {
- logbody, err := ioutil.ReadAll(logs)
+ logbody, err := io.ReadAll(logs)
if err == nil {
log.Printf("Container logs:\n\n%s\n\n", string(logbody))
}
@@ -461,7 +467,7 @@ func main() {
os.Exit(1)
}
cleanup(dockerClient)
- versions := calculateVersions(httpClient, *flagFrom, *flagTo)
+ versions := calculateVersions(httpClient, *flagFrom, *flagTo, *flagDirect)
log.Printf("Testing dendrite versions: %v\n", versions)
branchToImageID := buildDendriteImages(httpClient, dockerClient, *flagTempDir, *flagBuildConcurrency, versions)
diff --git a/cmd/dendrite-upgrade-tests/tests.go b/cmd/dendrite-upgrade-tests/tests.go
index e02af92a9..ff1e09dda 100644
--- a/cmd/dendrite-upgrade-tests/tests.go
+++ b/cmd/dendrite-upgrade-tests/tests.go
@@ -18,9 +18,9 @@ type user struct {
}
// runTests performs the following operations:
-// - register alice and bob with branch name muxed into the localpart
-// - create a DM room for the 2 users and exchange messages
-// - create/join a public #global room and exchange messages
+// - register alice and bob with branch name muxed into the localpart
+// - create a DM room for the 2 users and exchange messages
+// - create/join a public #global room and exchange messages
func runTests(baseURL, branchName string) error {
// register 2 users
users := []user{
diff --git a/cmd/furl/main.go b/cmd/furl/main.go
index 75e223388..f59f9c8ce 100644
--- a/cmd/furl/main.go
+++ b/cmd/furl/main.go
@@ -9,7 +9,6 @@ import (
"encoding/pem"
"flag"
"fmt"
- "io/ioutil"
"net/url"
"os"
@@ -30,7 +29,7 @@ func main() {
os.Exit(1)
}
- data, err := ioutil.ReadFile(*requestKey)
+ data, err := os.ReadFile(*requestKey)
if err != nil {
panic(err)
}
diff --git a/cmd/goose/README.md b/cmd/goose/README.md
deleted file mode 100644
index 725c6a586..000000000
--- a/cmd/goose/README.md
+++ /dev/null
@@ -1,109 +0,0 @@
-## Database migrations
-
-We use [goose](https://github.com/pressly/goose) to handle database migrations. This allows us to execute
-both SQL deltas (e.g `ALTER TABLE ...`) as well as manipulate data in the database in Go using Go functions.
-
-To run a migration, the `goose` binary in this directory needs to be built:
-```
-$ go build ./cmd/goose
-```
-
-This binary allows Dendrite databases to be upgraded and downgraded. Sample usage for upgrading the roomserver database:
-
-```
-# for sqlite
-$ ./goose -dir roomserver/storage/sqlite3/deltas sqlite3 ./roomserver.db up
-
-# for postgres
-$ ./goose -dir roomserver/storage/postgres/deltas postgres "user=dendrite dbname=dendrite sslmode=disable" up
-```
-
-For a full list of options, including rollbacks, see https://github.com/pressly/goose or use `goose` with no args.
-
-
-### Rationale
-
-Dendrite creates tables on startup using `CREATE TABLE IF NOT EXISTS`, so you might think that we should also
-apply version upgrades on startup as well. This is convenient and doesn't involve an additional binary to run
-which complicates upgrades. However, combining the upgrade mechanism and the server binary makes it difficult
-to handle rollbacks. Firstly, how do you specify you wish to rollback? We would have to add additional flags
-to the main server binary to say "rollback to version X". Secondly, if you roll back the server binary from
-version 5 to version 4, the version 4 binary doesn't know how to rollback the database from version 5 to
-version 4! For these reasons, we prefer to have a separate "upgrade" binary which is run for database upgrades.
-Rather than roll-our-own migration tool, we decided to use [goose](https://github.com/pressly/goose) as it supports
-complex migrations in Go code in addition to just executing SQL deltas. Other alternatives like
-`github.com/golang-migrate/migrate` [do not support](https://github.com/golang-migrate/migrate/issues/15) these
-kinds of complex migrations.
-
-### Adding new deltas
-
-You can add `.sql` or `.go` files manually or you can use goose to create them for you.
-
-If you only want to add a SQL delta then run:
-
-```
-$ ./goose -dir serverkeyapi/storage/sqlite3/deltas sqlite3 ./foo.db create new_col sql
-2020/09/09 14:37:43 Created new file: serverkeyapi/storage/sqlite3/deltas/20200909143743_new_col.sql
-```
-
-In this case, the version number is `20200909143743`. The important thing is that it is always increasing.
-
-Then add up/downgrade SQL commands to the created file which looks like:
-```sql
--- +goose Up
--- +goose StatementBegin
-SELECT 'up SQL query';
--- +goose StatementEnd
-
--- +goose Down
--- +goose StatementBegin
-SELECT 'down SQL query';
--- +goose StatementEnd
-
-```
-You __must__ keep the `+goose` annotations. You'll need to repeat this process for Postgres.
-
-For complex Go migrations:
-
-```
-$ ./goose -dir serverkeyapi/storage/sqlite3/deltas sqlite3 ./foo.db create complex_update go
-2020/09/09 14:40:38 Created new file: serverkeyapi/storage/sqlite3/deltas/20200909144038_complex_update.go
-```
-
-Then modify the created `.go` file which looks like:
-
-```go
-package migrations
-
-import (
- "database/sql"
- "fmt"
-
- "github.com/pressly/goose"
-)
-
-func init() {
- goose.AddMigration(upComplexUpdate, downComplexUpdate)
-}
-
-func upComplexUpdate(tx *sql.Tx) error {
- // This code is executed when the migration is applied.
- return nil
-}
-
-func downComplexUpdate(tx *sql.Tx) error {
- // This code is executed when the migration is rolled back.
- return nil
-}
-
-```
-
-You __must__ import the package in `/cmd/goose/main.go` so `func init()` gets called.
-
-
-#### Database limitations
-
-- SQLite3 does NOT support `ALTER TABLE table_name DROP COLUMN` - you would have to rename the column or drop the table
- entirely and recreate it. ([example](https://github.com/matrix-org/dendrite/blob/master/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.sql))
-
- More information: [sqlite.org](https://www.sqlite.org/lang_altertable.html)
diff --git a/cmd/goose/main.go b/cmd/goose/main.go
deleted file mode 100644
index 31a5b0050..000000000
--- a/cmd/goose/main.go
+++ /dev/null
@@ -1,154 +0,0 @@
-// This is custom goose binary
-
-package main
-
-import (
- "flag"
- "fmt"
- "log"
- "os"
-
- "github.com/pressly/goose"
-
- pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
- slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
-
- _ "github.com/lib/pq"
- _ "github.com/mattn/go-sqlite3"
-)
-
-const (
- AppService = "appservice"
- FederationSender = "federationapi"
- KeyServer = "keyserver"
- MediaAPI = "mediaapi"
- RoomServer = "roomserver"
- SigningKeyServer = "signingkeyserver"
- SyncAPI = "syncapi"
- UserAPI = "userapi"
-)
-
-var (
- dir = flags.String("dir", "", "directory with migration files")
- flags = flag.NewFlagSet("goose", flag.ExitOnError)
- component = flags.String("component", "", "dendrite component name")
- knownDBs = []string{
- AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
- }
-)
-
-// nolint: gocyclo
-func main() {
- err := flags.Parse(os.Args[1:])
- if err != nil {
- panic(err.Error())
- }
- args := flags.Args()
-
- if len(args) < 3 {
- fmt.Println(
- `Usage: goose [OPTIONS] DRIVER DBSTRING COMMAND
-
-Drivers:
- postgres
- sqlite3
-
-Examples:
- goose -component roomserver sqlite3 ./roomserver.db status
- goose -component roomserver sqlite3 ./roomserver.db up
-
- goose -component roomserver postgres "user=dendrite dbname=dendrite sslmode=disable" status
-
-Options:
- -component string
- Dendrite component name e.g roomserver, signingkeyserver, clientapi, syncapi
- -table string
- migrations table name (default "goose_db_version")
- -h print help
- -v enable verbose mode
- -dir string
- directory with migration files, only relevant when creating new migrations.
- -version
- print version
-
-Commands:
- up Migrate the DB to the most recent version available
- up-by-one Migrate the DB up by 1
- up-to VERSION Migrate the DB to a specific VERSION
- down Roll back the version by 1
- down-to VERSION Roll back to a specific VERSION
- redo Re-run the latest migration
- reset Roll back all migrations
- status Dump the migration status for the current DB
- version Print the current version of the database
- create NAME [sql|go] Creates new migration file with the current timestamp
- fix Apply sequential ordering to migrations`,
- )
- return
- }
-
- engine := args[0]
- if engine != "sqlite3" && engine != "postgres" {
- fmt.Println("engine must be one of 'sqlite3' or 'postgres'")
- return
- }
-
- knownComponent := false
- for _, c := range knownDBs {
- if c == *component {
- knownComponent = true
- break
- }
- }
- if !knownComponent {
- fmt.Printf("component must be one of %v\n", knownDBs)
- return
- }
-
- if engine == "sqlite3" {
- loadSQLiteDeltas(*component)
- } else {
- loadPostgresDeltas(*component)
- }
-
- dbstring, command := args[1], args[2]
-
- db, err := goose.OpenDBWithDriver(engine, dbstring)
- if err != nil {
- log.Fatalf("goose: failed to open DB: %v\n", err)
- }
-
- defer func() {
- if err := db.Close(); err != nil {
- log.Fatalf("goose: failed to close DB: %v\n", err)
- }
- }()
-
- arguments := []string{}
- if len(args) > 3 {
- arguments = append(arguments, args[3:]...)
- }
-
- // goose demands a directory even though we don't use it for upgrades
- d := *dir
- if d == "" {
- d = os.TempDir()
- }
- if err := goose.Run(command, db, d, arguments...); err != nil {
- log.Fatalf("goose %v: %v", command, err)
- }
-}
-
-func loadSQLiteDeltas(component string) {
- switch component {
- case UserAPI:
- slusers.LoadFromGoose()
- }
-}
-
-func loadPostgresDeltas(component string) {
- switch component {
- case UserAPI:
- pgusers.LoadFromGoose()
- }
-}
diff --git a/dendrite-sample.monolith.yaml b/dendrite-sample.monolith.yaml
index c6050e407..f753c3d9b 100644
--- a/dendrite-sample.monolith.yaml
+++ b/dendrite-sample.monolith.yaml
@@ -64,6 +64,10 @@ global:
# e.g. localhost:443
well_known_server_name: ""
+ # The server name to delegate client-server communications to, with optional port
+ # e.g. localhost:443
+ well_known_client_name: ""
+
# Lists of domains that the server will trust as identity servers to verify third
# party identifiers such as phone numbers and email addresses.
trusted_third_party_id_servers:
@@ -109,6 +113,11 @@ global:
addresses:
# - localhost:4222
+ # Disable the validation of TLS certificates of NATS. This is
+ # not recommended in production since it may allow NATS traffic
+ # to be sent to an insecure endpoint.
+ disable_tls_validation: false
+
# Persistent directory to store JetStream streams in. This directory should be
# preserved across Dendrite restarts.
storage_path: ./
@@ -169,13 +178,16 @@ client_api:
# TURN server information that this homeserver should send to clients.
turn:
- turn_user_lifetime: ""
+ turn_user_lifetime: "5m"
turn_uris:
# - turn:turn.server.org?transport=udp
# - turn:turn.server.org?transport=tcp
turn_shared_secret: ""
- turn_username: ""
- turn_password: ""
+ # If your TURN server requires static credentials, then you will need to enter
+ # them here instead of supplying a shared secret. Note that these credentials
+ # will be visible to clients!
+ # turn_username: ""
+ # turn_password: ""
# Settings for rate-limited endpoints. Rate limiting kicks in after the threshold
# number of "slots" have been taken by requests from a specific host. Each "slot"
@@ -183,7 +195,7 @@ client_api:
# and appservice users are exempt from rate limiting by default.
rate_limiting:
enabled: true
- threshold: 5
+ threshold: 20
cooloff_ms: 500
exempt_user_ids:
# - "@user:domain.com"
diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml
index ea3d9d689..856b4ab22 100644
--- a/dendrite-sample.polylith.yaml
+++ b/dendrite-sample.polylith.yaml
@@ -54,6 +54,10 @@ global:
# e.g. localhost:443
well_known_server_name: ""
+ # The server name to delegate client-server communications to, with optional port
+ # e.g. localhost:443
+ well_known_client_name: ""
+
# Lists of domains that the server will trust as identity servers to verify third
# party identifiers such as phone numbers and email addresses.
trusted_third_party_id_servers:
@@ -99,6 +103,11 @@ global:
addresses:
- hostname:4222
+ # Disable the validation of TLS certificates of NATS. This is
+ # not recommended in production since it may allow NATS traffic
+ # to be sent to an insecure endpoint.
+ disable_tls_validation: false
+
# The prefix to use for stream names for this homeserver - really only useful
# if you are running more than one Dendrite server on the same NATS deployment.
topic_prefix: Dendrite
@@ -125,7 +134,7 @@ app_service_api:
# Database configuration for this component.
database:
- connection_string: postgresql://username@password:hostname/dendrite_appservice?sslmode=disable
+ connection_string: postgresql://username:password@hostname/dendrite_appservice?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@@ -172,13 +181,16 @@ client_api:
# TURN server information that this homeserver should send to clients.
turn:
- turn_user_lifetime: ""
+ turn_user_lifetime: "5m"
turn_uris:
# - turn:turn.server.org?transport=udp
# - turn:turn.server.org?transport=tcp
turn_shared_secret: ""
- turn_username: ""
- turn_password: ""
+ # If your TURN server requires static credentials, then you will need to enter
+ # them here instead of supplying a shared secret. Note that these credentials
+ # will be visible to clients!
+ # turn_username: ""
+ # turn_password: ""
# Settings for rate-limited endpoints. Rate limiting kicks in after the threshold
# number of "slots" have been taken by requests from a specific host. Each "slot"
@@ -186,7 +198,7 @@ client_api:
# and appservice users are exempt from rate limiting by default.
rate_limiting:
enabled: true
- threshold: 5
+ threshold: 20
cooloff_ms: 500
exempt_user_ids:
# - "@user:domain.com"
@@ -199,7 +211,7 @@ federation_api:
external_api:
listen: http://[::]:8072
database:
- connection_string: postgresql://username@password:hostname/dendrite_federationapi?sslmode=disable
+ connection_string: postgresql://username:password@hostname/dendrite_federationapi?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@@ -236,7 +248,7 @@ key_server:
listen: http://[::]:7779 # The listen address for incoming API requests
connect: http://key_server:7779 # The connect address for other components to use
database:
- connection_string: postgresql://username@password:hostname/dendrite_keyserver?sslmode=disable
+ connection_string: postgresql://username:password@hostname/dendrite_keyserver?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@@ -249,7 +261,7 @@ media_api:
external_api:
listen: http://[::]:8074
database:
- connection_string: postgresql://username@password:hostname/dendrite_mediaapi?sslmode=disable
+ connection_string: postgresql://username:password@hostname/dendrite_mediaapi?sslmode=disable
max_open_conns: 5
max_idle_conns: 2
conn_max_lifetime: -1
@@ -286,7 +298,7 @@ mscs:
# - msc2836 # (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836)
# - msc2946 # (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946)
database:
- connection_string: postgresql://username@password:hostname/dendrite_mscs?sslmode=disable
+ connection_string: postgresql://username:password@hostname/dendrite_mscs?sslmode=disable
max_open_conns: 5
max_idle_conns: 2
conn_max_lifetime: -1
@@ -297,7 +309,7 @@ room_server:
listen: http://[::]:7770 # The listen address for incoming API requests
connect: http://room_server:7770 # The connect address for other components to use
database:
- connection_string: postgresql://username@password:hostname/dendrite_roomserver?sslmode=disable
+ connection_string: postgresql://username:password@hostname/dendrite_roomserver?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@@ -310,7 +322,7 @@ sync_api:
external_api:
listen: http://[::]:8073
database:
- connection_string: postgresql://username@password:hostname/dendrite_syncapi?sslmode=disable
+ connection_string: postgresql://username:password@hostname/dendrite_syncapi?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
@@ -326,7 +338,7 @@ user_api:
listen: http://[::]:7781 # The listen address for incoming API requests
connect: http://user_api:7781 # The connect address for other components to use
account_database:
- connection_string: postgresql://username@password:hostname/dendrite_userapi?sslmode=disable
+ connection_string: postgresql://username:password@hostname/dendrite_userapi?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md
index 5a89e6841..771af9ecf 100644
--- a/docs/CONTRIBUTING.md
+++ b/docs/CONTRIBUTING.md
@@ -24,7 +24,7 @@ Unfortunately we can't accept contributions without it.
## Getting up and running
-See the [Installation](INSTALL.md) section for information on how to build an
+See the [Installation](installation) section for information on how to build an
instance of Dendrite. You will likely need this in order to test your changes.
## Code style
@@ -64,7 +64,7 @@ comment. Please avoid doing this if you can.
We also have unit tests which we run via:
```bash
-go test ./...
+go test --race ./...
```
In general, we like submissions that come with tests. Anything that proves that the
diff --git a/docs/FAQ.md b/docs/FAQ.md
index 47f39b9e6..f8255684e 100644
--- a/docs/FAQ.md
+++ b/docs/FAQ.md
@@ -86,9 +86,12 @@ would be a huge help too, as that will help us to understand where the memory us
You may need to revisit the connection limit of your PostgreSQL server and/or make changes to the `max_connections` lines in your Dendrite configuration. Be aware that each Dendrite component opens its own database connections and has its own connection limit, even in monolith mode!
-## What is being reported when enabling anonymous stats?
+## What is being reported when enabling phone-home statistics?
-If anonymous stats reporting is enabled, the following data is send to the defined endpoint.
+Phone-home statistics contain your server's domain name, some configuration information about
+your deployment and aggregated information about active users on your deployment. They are sent
+to the endpoint URL configured in your Dendrite configuration file only. The following is an
+example of the data that is sent:
```json
{
@@ -106,7 +109,7 @@ If anonymous stats reporting is enabled, the following data is send to the defin
"go_arch": "amd64",
"go_os": "linux",
"go_version": "go1.16.13",
- "homeserver": "localhost:8800",
+ "homeserver": "my.domain.com",
"log_level": "trace",
"memory_rss": 93452,
"monolith": true,
diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock
index e62aa4ce3..88d33ae84 100644
--- a/docs/Gemfile.lock
+++ b/docs/Gemfile.lock
@@ -233,6 +233,8 @@ GEM
multipart-post (2.1.1)
nokogiri (1.13.6-arm64-darwin)
racc (~> 1.4)
+ nokogiri (1.13.6-x86_64-linux)
+ racc (~> 1.4)
octokit (4.22.0)
faraday (>= 0.9)
sawyer (~> 0.8.0, >= 0.5.3)
@@ -263,7 +265,7 @@ GEM
thread_safe (0.3.6)
typhoeus (1.4.0)
ethon (>= 0.9.0)
- tzinfo (1.2.9)
+ tzinfo (1.2.10)
thread_safe (~> 0.1)
unf (0.1.4)
unf_ext
@@ -273,11 +275,11 @@ GEM
PLATFORMS
arm64-darwin-21
+ x86_64-linux
DEPENDENCIES
github-pages (~> 226)
jekyll-feed (~> 0.15.1)
- minima (~> 2.5.1)
BUNDLED WITH
2.3.7
diff --git a/docs/administration/1_createusers.md b/docs/administration/1_createusers.md
index f40b7f576..61ec2299b 100644
--- a/docs/administration/1_createusers.md
+++ b/docs/administration/1_createusers.md
@@ -32,6 +32,15 @@ To create a new **admin account**, add the `-admin` flag:
./bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin
```
+An example of using `create-account` when running in **Docker**, having found the `CONTAINERNAME` from `docker ps`:
+
+```bash
+docker exec -it CONTAINERNAME /usr/bin/create-account -config /path/to/dendrite.yaml -username USERNAME
+```
+```bash
+docker exec -it CONTAINERNAME /usr/bin/create-account -config /path/to/dendrite.yaml -username USERNAME -admin
+```
+
## Using shared secret registration
Dendrite supports the Synapse-compatible shared secret registration endpoint.
diff --git a/docs/administration/5_troubleshooting.md b/docs/administration/5_troubleshooting.md
new file mode 100644
index 000000000..14df2e3fb
--- /dev/null
+++ b/docs/administration/5_troubleshooting.md
@@ -0,0 +1,81 @@
+---
+title: Troubleshooting
+parent: Administration
+permalink: /administration/troubleshooting
+---
+
+# Troubleshooting
+
+If your Dendrite installation is acting strangely, there are a few things you should
+check before seeking help.
+
+## 1. Logs
+
+Dendrite, by default, will log all warnings and errors to stdout, in addition to any
+other locations configured in the `dendrite.yaml` configuration file. Often there will
+be clues in the logs.
+
+You can increase this log level to the more verbose `debug` level if necessary by adding
+this to the config and restarting Dendrite:
+
+```
+logging:
+- type: std
+ level: debug
+```
+
+Look specifically for lines that contain `level=error` or `level=warning`.
+
+## 2. Federation tester
+
+If you are experiencing problems federating with other homeservers, you should check
+that the [Federation Tester](https://federationtester.matrix.org) is passing for your
+server.
+
+Common reasons that it may not pass include:
+
+1. Incorrect DNS configuration;
+2. Misconfigured DNS SRV entries or well-known files;
+3. Invalid TLS/SSL certificates;
+4. Reverse proxy configuration issues (if applicable).
+
+Correct any errors if shown and re-run the federation tester to check the results.
+
+## 3. System time
+
+Matrix relies heavily on TLS which requires the system time to be correct. If the clock
+drifts then you may find that federation no works reliably (or at all) and clients may
+struggle to connect to your Dendrite server.
+
+Ensure that your system time is correct and consider syncing to a reliable NTP source.
+
+## 4. Database connections
+
+If you are using the PostgreSQL database, you should ensure that Dendrite's configured
+number of database connections does not exceed the maximum allowed by PostgreSQL.
+
+Open your `postgresql.conf` configuration file and check the value of `max_connections`
+(which is typically `100` by default). Then open your `dendrite.yaml` configuration file
+and ensure that:
+
+1. If you are using the `global.database` section, that `max_open_conns` does not exceed
+ that number;
+2. If you are **not** using the `global.database` section, that the sum total of all
+ `max_open_conns` across all `database` blocks does not exceed that number.
+
+## 5. File descriptors
+
+Dendrite requires a sufficient number of file descriptors for every connection it makes
+to a remote server, every connection to the database engine and every file it is reading
+or writing to at a given time (media, logs etc). We recommend ensuring that the limit is
+no lower than 65535 for Dendrite.
+
+Dendrite will check at startup if there are a sufficient number of available descriptors.
+If there aren't, you will see a log lines like this:
+
+```
+level=warning msg="IMPORTANT: Process file descriptor limit is currently 65535, it is recommended to raise the limit for Dendrite to at least 65535 to avoid issues"
+```
+
+Follow the [Optimisation](../installation/10_optimisation.md) instructions to correct the
+available number of file descriptors.
diff --git a/docs/caddy/monolith/CaddyFile b/docs/caddy/monolith/CaddyFile
deleted file mode 100644
index cd93f9e10..000000000
--- a/docs/caddy/monolith/CaddyFile
+++ /dev/null
@@ -1,68 +0,0 @@
-{
- # debug
- admin off
- email example@example.com
- default_sni example.com
- # Debug endpoint
- # acme_ca https://acme-staging-v02.api.letsencrypt.org/directory
-}
-
-#######################################################################
-# Snippets
-#______________________________________________________________________
-
-(handle_errors_maintenance) {
- handle_errors {
- @maintenance expression {http.error.status_code} == 502
- rewrite @maintenance maintenance.html
- root * "/path/to/service/pages"
- file_server
- }
-}
-
-(matrix-well-known-header) {
- # Headers
- header Access-Control-Allow-Origin "*"
- header Access-Control-Allow-Methods "GET, POST, PUT, DELETE, OPTIONS"
- header Access-Control-Allow-Headers "Origin, X-Requested-With, Content-Type, Accept, Authorization"
- header Content-Type "application/json"
-}
-
-#######################################################################
-
-example.com {
-
- # ...
-
- handle /.well-known/matrix/server {
- import matrix-well-known-header
- respond `{ "m.server": "matrix.example.com:443" }` 200
- }
-
- handle /.well-known/matrix/client {
- import matrix-well-known-header
- respond `{ "m.homeserver": { "base_url": "https://matrix.example.com" } }` 200
- }
-
- import handle_errors_maintenance
-}
-
-example.com:8448 {
- # server<->server HTTPS traffic
- reverse_proxy http://dendrite-host:8008
-}
-
-matrix.example.com {
-
- handle /_matrix/* {
- # client<->server HTTPS traffic
- reverse_proxy http://dendrite-host:8008
- }
-
- handle_path /* {
- # Client webapp (Element SPA or ...)
- file_server {
- root /path/to/www/example.com/matrix-web-client/
- }
- }
-}
diff --git a/docs/caddy/monolith/Caddyfile b/docs/caddy/monolith/Caddyfile
new file mode 100644
index 000000000..82567c4a6
--- /dev/null
+++ b/docs/caddy/monolith/Caddyfile
@@ -0,0 +1,57 @@
+# Sample Caddyfile for using Caddy in front of Dendrite.
+#
+# Customize email address and domain names.
+# Optional settings commented out.
+#
+# BE SURE YOUR DOMAINS ARE POINTED AT YOUR SERVER FIRST.
+# Documentation: https://caddyserver.com/docs/
+#
+# Bonus tip: If your IP address changes, use Caddy's
+# dynamic DNS plugin to update your DNS records to
+# point to your new IP automatically:
+# https://github.com/mholt/caddy-dynamicdns
+#
+
+
+# Global options block
+{
+ # In case there is a problem with your certificates.
+ # email example@example.com
+
+ # Turn off the admin endpoint if you don't need graceful config
+ # changes and/or are running untrusted code on your machine.
+ # admin off
+
+ # Enable this if your clients don't send ServerName in TLS handshakes.
+ # default_sni example.com
+
+ # Enable debug mode for verbose logging.
+ # debug
+
+ # Use Let's Encrypt's staging endpoint for testing.
+ # acme_ca https://acme-staging-v02.api.letsencrypt.org/directory
+
+ # If you're port-forwarding HTTP/HTTPS ports from 80/443 to something
+ # else, enable these and put the alternate port numbers here.
+ # http_port 8080
+ # https_port 8443
+}
+
+# The server name of your matrix homeserver. This example shows
+# "well-known delegation" from the registered domain to a subdomain,
+# which is only needed if your server_name doesn't match your Matrix
+# homeserver URL (i.e. you can show users a vanity domain that looks
+# nice and is easy to remember but still have your Matrix server on
+# its own subdomain or hosted service).
+example.com {
+ header /.well-known/matrix/* Content-Type application/json
+ header /.well-known/matrix/* Access-Control-Allow-Origin *
+ respond /.well-known/matrix/server `{"m.server": "matrix.example.com:443"}`
+ respond /.well-known/matrix/client `{"m.homeserver": {"base_url": "https://matrix.example.com"}}`
+}
+
+# The actual domain name whereby your Matrix server is accessed.
+matrix.example.com {
+ # Set localhost:8008 to the address of your Dendrite server, if different
+ reverse_proxy /_matrix/* localhost:8008
+}
diff --git a/docs/caddy/polylith/Caddyfile b/docs/caddy/polylith/Caddyfile
new file mode 100644
index 000000000..244e50e7e
--- /dev/null
+++ b/docs/caddy/polylith/Caddyfile
@@ -0,0 +1,66 @@
+# Sample Caddyfile for using Caddy in front of Dendrite.
+#
+# Customize email address and domain names.
+# Optional settings commented out.
+#
+# BE SURE YOUR DOMAINS ARE POINTED AT YOUR SERVER FIRST.
+# Documentation: https://caddyserver.com/docs/
+#
+# Bonus tip: If your IP address changes, use Caddy's
+# dynamic DNS plugin to update your DNS records to
+# point to your new IP automatically:
+# https://github.com/mholt/caddy-dynamicdns
+#
+
+
+# Global options block
+{
+ # In case there is a problem with your certificates.
+ # email example@example.com
+
+ # Turn off the admin endpoint if you don't need graceful config
+ # changes and/or are running untrusted code on your machine.
+ # admin off
+
+ # Enable this if your clients don't send ServerName in TLS handshakes.
+ # default_sni example.com
+
+ # Enable debug mode for verbose logging.
+ # debug
+
+ # Use Let's Encrypt's staging endpoint for testing.
+ # acme_ca https://acme-staging-v02.api.letsencrypt.org/directory
+
+ # If you're port-forwarding HTTP/HTTPS ports from 80/443 to something
+ # else, enable these and put the alternate port numbers here.
+ # http_port 8080
+ # https_port 8443
+}
+
+# The server name of your matrix homeserver. This example shows
+# "well-known delegation" from the registered domain to a subdomain,
+# which is only needed if your server_name doesn't match your Matrix
+# homeserver URL (i.e. you can show users a vanity domain that looks
+# nice and is easy to remember but still have your Matrix server on
+# its own subdomain or hosted service).
+example.com {
+ header /.well-known/matrix/* Content-Type application/json
+ header /.well-known/matrix/* Access-Control-Allow-Origin *
+ respond /.well-known/matrix/server `{"m.server": "matrix.example.com:443"}`
+ respond /.well-known/matrix/client `{"m.homeserver": {"base_url": "https://matrix.example.com"}}`
+}
+
+# The actual domain name whereby your Matrix server is accessed.
+matrix.example.com {
+ # Change the end of each reverse_proxy line to the correct
+ # address for your various services.
+ @sync_api {
+ path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/messages)$
+ }
+ reverse_proxy @sync_api sync_api:8073
+
+ reverse_proxy /_matrix/client* client_api:8071
+ reverse_proxy /_matrix/federation* federation_api:8071
+ reverse_proxy /_matrix/key* federation_api:8071
+ reverse_proxy /_matrix/media* media_api:8071
+}
diff --git a/docs/installation/9_starting_polylith.md b/docs/installation/10_starting_polylith.md
similarity index 99%
rename from docs/installation/9_starting_polylith.md
rename to docs/installation/10_starting_polylith.md
index 228e52e85..0c2e2af2b 100644
--- a/docs/installation/9_starting_polylith.md
+++ b/docs/installation/10_starting_polylith.md
@@ -2,7 +2,7 @@
title: Starting the polylith
parent: Installation
has_toc: true
-nav_order: 9
+nav_order: 10
permalink: /installation/start/polylith
---
diff --git a/docs/installation/10_optimisation.md b/docs/installation/11_optimisation.md
similarity index 99%
rename from docs/installation/10_optimisation.md
rename to docs/installation/11_optimisation.md
index c19b7a75e..f2f67c947 100644
--- a/docs/installation/10_optimisation.md
+++ b/docs/installation/11_optimisation.md
@@ -2,7 +2,7 @@
title: Optimise your installation
parent: Installation
has_toc: true
-nav_order: 10
+nav_order: 11
permalink: /installation/start/optimisation
---
diff --git a/docs/installation/1_planning.md b/docs/installation/1_planning.md
index d4f3d7052..3aa5b4d85 100644
--- a/docs/installation/1_planning.md
+++ b/docs/installation/1_planning.md
@@ -95,12 +95,13 @@ enabled.
To do so, follow the [NATS Server installation instructions](https://docs.nats.io/running-a-nats-service/introduction/installation) and then [start your NATS deployment](https://docs.nats.io/running-a-nats-service/introduction/running). JetStream must be enabled, either by passing the `-js` flag to `nats-server`,
or by specifying the `store_dir` option in the the `jetstream` configuration.
-### Reverse proxy (polylith deployments)
+### Reverse proxy
-Polylith deployments require a reverse proxy, such as [NGINX](https://www.nginx.com) or
-[HAProxy](http://www.haproxy.org). Configuring those is not covered in this documentation,
-although a [sample configuration for NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf)
-is provided.
+A reverse proxy such as [Caddy](https://caddyserver.com), [NGINX](https://www.nginx.com) or
+[HAProxy](http://www.haproxy.org) is required for polylith deployments and is useful for monolith
+deployments. Configuring those is not covered in this documentation, although sample configurations
+for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy) and
+[NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx) are provided.
### Windows
diff --git a/docs/installation/2_domainname.md b/docs/installation/2_domainname.md
index 0d4300eca..7d7fc86bd 100644
--- a/docs/installation/2_domainname.md
+++ b/docs/installation/2_domainname.md
@@ -14,27 +14,38 @@ that take the format `@user:example.com`.
For federation to work, the server name must be resolvable by other homeservers on the internet
— that is, the domain must be registered and properly configured with the relevant DNS records.
-Matrix servers discover each other when federating using the following methods:
+Matrix servers usually discover each other when federating using the following methods:
-1. If a well-known delegation exists on `example.com`, use the path server from the
+1. If a well-known delegation exists on `example.com`, use the domain and port from the
well-known file to connect to the remote homeserver;
-2. If a DNS SRV delegation exists on `example.com`, use the hostname and port from the DNS SRV
+2. If a DNS SRV delegation exists on `example.com`, use the IP address and port from the DNS SRV
record to connect to the remote homeserver;
3. If neither well-known or DNS SRV delegation are configured, attempt to connect to the remote
homeserver by connecting to `example.com` port TCP/8448 using HTTPS.
+The exact details of how server name resolution works can be found in
+[the spec](https://spec.matrix.org/v1.3/server-server-api/#resolving-server-names).
+
## TLS certificates
Matrix federation requires that valid TLS certificates are present on the domain. You must
-obtain certificates from a publicly accepted Certificate Authority (CA). [LetsEncrypt](https://letsencrypt.org)
-is an example of such a CA that can be used. Self-signed certificates are not suitable for
-federation and will typically not be accepted by other homeservers.
+obtain certificates from a publicly-trusted certificate authority (CA). [Let's Encrypt](https://letsencrypt.org)
+is a popular choice of CA because the certificates are publicly-trusted, free, and automated
+via the ACME protocol. (Self-signed certificates are not suitable for federation and will typically
+not be accepted by other homeservers.)
-A common practice to help ease the management of certificates is to install a reverse proxy in
-front of Dendrite which manages the TLS certificates and HTTPS proxying itself. Software such as
-[NGINX](https://www.nginx.com) and [HAProxy](http://www.haproxy.org) can be used for the task.
-Although the finer details of configuring these are not described here, you must reverse proxy
-all `/_matrix` paths to your Dendrite server.
+Automating the renewal of TLS certificates is best practice. There are many tools for this,
+but the simplest way to achieve TLS automation is to have your reverse proxy do it for you.
+[Caddy](https://caddyserver.com) is recommended as a production-grade reverse proxy with
+automatic TLS which is commonly used in front of Dendrite. It obtains and renews TLS certificates
+automatically and by default as long as your domain name is pointed at your server first.
+Although the finer details of [configuring Caddy](https://caddyserver.com/docs/) is not described
+here, in general, you must reverse proxy all `/_matrix` paths to your Dendrite server. For example,
+with Caddy:
+
+```
+reverse_proxy /_matrix/* localhost:8008
+```
It is possible for the reverse proxy to listen on the standard HTTPS port TCP/443 so long as your
domain delegation is configured to point to port TCP/443.
@@ -51,17 +62,12 @@ you will be able to delegate from `example.com` to `matrix.example.com` so that
Delegation can be performed in one of two ways:
-* **Well-known delegation**: A well-known text file is served over HTTPS on the domain name
- that you want to use, pointing to your server on `matrix.example.com` port 8448;
-* **DNS SRV delegation**: A DNS SRV record is created on the domain name that you want to
- use, pointing to your server on `matrix.example.com` port TCP/8448.
+* **Well-known delegation (preferred)**: A well-known text file is served over HTTPS on the domain
+ name that you want to use, pointing to your server on `matrix.example.com` port 8448;
+* **DNS SRV delegation (not recommended)**: See the SRV delegation section below for details.
-If you are using a reverse proxy to forward `/_matrix` to Dendrite, your well-known or DNS SRV
-delegation must refer to the hostname and port that the reverse proxy is listening on instead.
-
-Well-known delegation is typically easier to set up and usually preferred. However, you can use
-either or both methods to delegate. If you configure both methods of delegation, it is important
-that they both agree and refer to the same hostname and port.
+If you are using a reverse proxy to forward `/_matrix` to Dendrite, your well-known or delegation
+must refer to the hostname and port that the reverse proxy is listening on instead.
## Well-known delegation
@@ -74,20 +80,46 @@ and contain the following JSON document:
```json
{
- "m.server": "https://matrix.example.com:8448"
+ "m.server": "matrix.example.com:8448"
}
```
+For example, this can be done with the following Caddy config:
+
+```
+handle /.well-known/matrix/client {
+ header Content-Type application/json
+ header Access-Control-Allow-Origin *
+ respond `{"m.homeserver": {"base_url": "https://matrix.example.com:8448"}}`
+}
+```
+
+You can also serve `.well-known` with Dendrite itself by setting the `well_known_server_name` config
+option to the value you want for `m.server`. This is primarily useful if Dendrite is exposed on
+`example.com:443` and you don't want to set up a separate webserver just for serving the `.well-known`
+file.
+
+```yaml
+global:
+...
+ well_known_server_name: "example.com:443"
+```
+
## DNS SRV delegation
-Using DNS SRV delegation requires creating DNS SRV records on the `example.com` zone which
-refer to your Dendrite installation.
+This method is not recommended, as the behavior of SRV records in Matrix is rather unintuitive:
+SRV records will only change the IP address and port that other servers connect to, they won't
+affect the domain name. In technical terms, the `Host` header and TLS SNI of federation requests
+will still be `example.com` even if the SRV record points at `matrix.example.com`.
-Assuming that your Dendrite installation is listening for HTTPS connections at `matrix.example.com`
-port 8448, the DNS SRV record must have the following fields:
+In practice, this means that the server must be configured with valid TLS certificates for
+`example.com`, rather than `matrix.example.com` as one might intuitively expect. If there's a
+reverse proxy in between, the proxy configuration must be written as if it's `example.com`, as the
+proxy will never see the name `matrix.example.com` in incoming requests.
-* Name: `@` (or whichever term your DNS provider uses to signal the root)
-* Service: `_matrix`
-* Protocol: `_tcp`
-* Port: `8448`
-* Target: `matrix.example.com`
+This behavior also means that if `example.com` and `matrix.example.com` point at the same IP
+address, there is no reason to have a SRV record pointing at `matrix.example.com`. It can still
+be used to change the port number, but it won't do anything else.
+
+If you understand how SRV records work and still want to use them, the service name is `_matrix` and
+the protocol is `_tcp`.
diff --git a/docs/installation/3_build.md b/docs/installation/3_build.md
new file mode 100644
index 000000000..aed2080db
--- /dev/null
+++ b/docs/installation/3_build.md
@@ -0,0 +1,38 @@
+---
+title: Building Dendrite
+parent: Installation
+has_toc: true
+nav_order: 3
+permalink: /installation/build
+---
+
+# Build all Dendrite commands
+
+Dendrite has numerous utility commands in addition to the actual server binaries.
+Build them all from the root of the source repo with `build.sh` (Linux/Mac):
+
+```sh
+./build.sh
+```
+
+or `build.cmd` (Windows):
+
+```powershell
+build.cmd
+```
+
+The resulting binaries will be placed in the `bin` subfolder.
+
+# Installing as a monolith
+
+You can install the Dendrite monolith binary into `$GOPATH/bin` by using `go install`:
+
+```sh
+go install ./cmd/dendrite-monolith-server
+```
+
+Alternatively, you can specify a custom path for the binary to be written to using `go build`:
+
+```sh
+go build -o /usr/local/bin/ ./cmd/dendrite-monolith-server
+```
diff --git a/docs/installation/3_database.md b/docs/installation/4_database.md
similarity index 96%
rename from docs/installation/3_database.md
rename to docs/installation/4_database.md
index f64fe9150..f6222a8d2 100644
--- a/docs/installation/3_database.md
+++ b/docs/installation/4_database.md
@@ -17,7 +17,9 @@ filenames in the Dendrite configuration file and start Dendrite. The databases w
and populated automatically.
Note that Dendrite **cannot share a single SQLite database across multiple components**. Each
-component must be configured with its own SQLite database filename.
+component must be configured with its own SQLite database filename. You will have to remove
+the `global.database` section from your Dendrite config and add it to each individual section
+instead in order to use SQLite.
### Connection strings
diff --git a/docs/installation/6_install_polylith.md b/docs/installation/6_install_polylith.md
index 375512f8f..ec4a77628 100644
--- a/docs/installation/6_install_polylith.md
+++ b/docs/installation/6_install_polylith.md
@@ -29,5 +29,6 @@ Polylith deployments require a reverse proxy in order to ensure that requests ar
sent to the correct endpoint. You must ensure that a suitable reverse proxy is installed
and configured.
-A [sample configuration file](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf)
-is provided for [NGINX](https://www.nginx.com).
+Sample configurations are provided
+for [Caddy](https://github.com/matrix-org/dendrite/blob/main/docs/caddy/polylith/Caddyfile)
+and [NGINX](https://github.com/matrix-org/dendrite/blob/main/docs/nginx/polylith-sample.conf).
\ No newline at end of file
diff --git a/docs/installation/7_configuration.md b/docs/installation/7_configuration.md
index e676afbe6..b1c747414 100644
--- a/docs/installation/7_configuration.md
+++ b/docs/installation/7_configuration.md
@@ -1,13 +1,13 @@
---
-title: Populate the configuration
+title: Configuring Dendrite
parent: Installation
nav_order: 7
permalink: /installation/configuration
---
-# Populate the configuration
+# Configuring Dendrite
-The configuration file is used to configure Dendrite. Sample configuration files are
+A YAML configuration file is used to configure Dendrite. Sample configuration files are
present in the top level of the Dendrite repository:
* [`dendrite-sample.monolith.yaml`](https://github.com/matrix-org/dendrite/blob/main/dendrite-sample.monolith.yaml)
diff --git a/docs/installation/4_signingkey.md b/docs/installation/8_signingkey.md
similarity index 99%
rename from docs/installation/4_signingkey.md
rename to docs/installation/8_signingkey.md
index 07dc485ff..323759a88 100644
--- a/docs/installation/4_signingkey.md
+++ b/docs/installation/8_signingkey.md
@@ -1,7 +1,7 @@
---
title: Generating signing keys
parent: Installation
-nav_order: 4
+nav_order: 8
permalink: /installation/signingkeys
---
diff --git a/docs/installation/8_starting_monolith.md b/docs/installation/9_starting_monolith.md
similarity index 83%
rename from docs/installation/8_starting_monolith.md
rename to docs/installation/9_starting_monolith.md
index e0e7309d2..124477e73 100644
--- a/docs/installation/8_starting_monolith.md
+++ b/docs/installation/9_starting_monolith.md
@@ -15,8 +15,9 @@ you can start your Dendrite monolith deployment by starting the `dendrite-monoli
./dendrite-monolith-server -config /path/to/dendrite.yaml
```
-If you want to change the addresses or ports that Dendrite listens on, you
-can use the `-http-bind-address` and `-https-bind-address` command line arguments:
+By default, Dendrite will listen HTTP on port 8008. If you want to change the addresses
+or ports that Dendrite listens on, you can use the `-http-bind-address` and
+`-https-bind-address` command line arguments:
```bash
./dendrite-monolith-server -config /path/to/dendrite.yaml \
diff --git a/federationapi/api/api.go b/federationapi/api/api.go
index 53d4701f3..292ed55ad 100644
--- a/federationapi/api/api.go
+++ b/federationapi/api/api.go
@@ -110,7 +110,7 @@ type FederationClientError struct {
Blacklisted bool
}
-func (e *FederationClientError) Error() string {
+func (e FederationClientError) Error() string {
return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted)
}
diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go
index e50ec66ad..2622ecb3f 100644
--- a/federationapi/consumers/roomserver.go
+++ b/federationapi/consumers/roomserver.go
@@ -208,9 +208,11 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
// joinedHostsAtEvent works out a list of matrix servers that were joined to
// the room at the event (including peeking ones)
// It is important to use the state at the event for sending messages because:
-// 1) We shouldn't send messages to servers that weren't in the room.
-// 2) If a server is kicked from the rooms it should still be told about the
-// kick event,
+//
+// 1. We shouldn't send messages to servers that weren't in the room.
+// 2. If a server is kicked from the rooms it should still be told about the
+// kick event.
+//
// Usually the list can be calculated locally, but sometimes it will need fetch
// events from the room server.
// Returns an error if there was a problem talking to the room server.
diff --git a/federationapi/consumers/sendtodevice.go b/federationapi/consumers/sendtodevice.go
index 84c9f620d..f99a895e0 100644
--- a/federationapi/consumers/sendtodevice.go
+++ b/federationapi/consumers/sendtodevice.go
@@ -95,6 +95,11 @@ func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Ms
return true
}
+ // The SyncAPI is already handling sendToDevice for the local server
+ if destServerName == t.ServerName {
+ return true
+ }
+
// Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MDirectToDevice,
diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go
index 97bcc12a5..ff01b1952 100644
--- a/federationapi/federationapi.go
+++ b/federationapi/federationapi.go
@@ -15,6 +15,8 @@
package federationapi
import (
+ "time"
+
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/api"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
@@ -167,5 +169,16 @@ func NewInternalAPI(
if err = presenceConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start presence consumer")
}
+
+ var cleanExpiredEDUs func()
+ cleanExpiredEDUs = func() {
+ logrus.Infof("Cleaning expired EDUs")
+ if err := federationDB.DeleteExpiredEDUs(base.Context()); err != nil {
+ logrus.WithError(err).Error("Failed to clean expired EDUs")
+ }
+ time.AfterFunc(time.Hour, cleanExpiredEDUs)
+ }
+ time.AfterFunc(time.Minute, cleanExpiredEDUs)
+
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing)
}
diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go
index d1bfe1847..9c3446222 100644
--- a/federationapi/federationapi_keys_test.go
+++ b/federationapi/federationapi_keys_test.go
@@ -6,7 +6,7 @@ import (
"crypto/ed25519"
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
"os"
"testing"
@@ -66,7 +66,7 @@ func TestMain(m *testing.M) {
s.cache = caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
// Create a temporary directory for JetStream.
- d, err := ioutil.TempDir("./", "jetstream*")
+ d, err := os.MkdirTemp("./", "jetstream*")
if err != nil {
panic(err)
}
@@ -136,7 +136,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
// And respond.
res = &http.Response{
StatusCode: 200,
- Body: ioutil.NopCloser(bytes.NewReader(body)),
+ Body: io.NopCloser(bytes.NewReader(body)),
}
return
}
diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go
index ae244c566..bdcb9f57c 100644
--- a/federationapi/federationapi_test.go
+++ b/federationapi/federationapi_test.go
@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"strings"
+ "sync"
"testing"
"time"
@@ -31,11 +32,12 @@ type fedRoomserverAPI struct {
}
// PerformJoin will call this function
-func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
+func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) error {
if f.inputRoomEvents == nil {
- return
+ return nil
}
f.inputRoomEvents(ctx, req, res)
+ return nil
}
// keychange consumer calls this
@@ -48,6 +50,7 @@ func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.Que
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
type fedClient struct {
+ fedClientMutex sync.Mutex
api.FederationClient
allowJoins []*test.Room
keys map[gomatrixserverlib.ServerName]struct {
@@ -59,6 +62,8 @@ type fedClient struct {
}
func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) {
+ f.fedClientMutex.Lock()
+ defer f.fedClientMutex.Unlock()
fmt.Println("GetServerKeys:", matrixServer)
var keys gomatrixserverlib.ServerKeys
var keyID gomatrixserverlib.KeyID
@@ -122,6 +127,8 @@ func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName
return
}
func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) {
+ f.fedClientMutex.Lock()
+ defer f.fedClientMutex.Unlock()
for _, r := range f.allowJoins {
if r.ID == event.RoomID() {
r.InsertEvent(f.t, event.Headered(r.Version))
@@ -134,6 +141,8 @@ func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName
}
func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
+ f.fedClientMutex.Lock()
+ defer f.fedClientMutex.Unlock()
for _, edu := range t.EDUs {
if edu.Type == gomatrixserverlib.MDeviceListUpdate {
f.sentTxn = true
@@ -242,6 +251,8 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(500 * time.Millisecond)
+ fc.fedClientMutex.Lock()
+ defer fc.fedClientMutex.Unlock()
if !fc.sentTxn {
t.Fatalf("did not send device list update")
}
diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go
index 295ddc495..812d3c6da 100644
--- a/federationapi/inthttp/client.go
+++ b/federationapi/inthttp/client.go
@@ -10,7 +10,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
- "github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP API
@@ -48,7 +47,11 @@ func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client,
if httpClient == nil {
return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is ")
}
- return &httpFederationInternalAPI{federationSenderURL, httpClient, cache}, nil
+ return &httpFederationInternalAPI{
+ federationAPIURL: federationSenderURL,
+ httpClient: httpClient,
+ cache: cache,
+ }, nil
}
type httpFederationInternalAPI struct {
@@ -63,11 +66,10 @@ func (h *httpFederationInternalAPI) PerformLeave(
request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeaveRequest")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIPerformLeaveRequestPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformLeave", h.federationAPIURL+FederationAPIPerformLeaveRequestPath,
+ h.httpClient, ctx, request, response,
+ )
}
// Handle sending an invite to a remote server.
@@ -76,11 +78,10 @@ func (h *httpFederationInternalAPI) PerformInvite(
request *api.PerformInviteRequest,
response *api.PerformInviteResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInviteRequest")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIPerformInviteRequestPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformInvite", h.federationAPIURL+FederationAPIPerformInviteRequestPath,
+ h.httpClient, ctx, request, response,
+ )
}
// Handle starting a peek on a remote server.
@@ -89,11 +90,10 @@ func (h *httpFederationInternalAPI) PerformOutboundPeek(
request *api.PerformOutboundPeekRequest,
response *api.PerformOutboundPeekResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOutboundPeekRequest")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIPerformOutboundPeekRequestPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformOutboundPeek", h.federationAPIURL+FederationAPIPerformOutboundPeekRequestPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryJoinedHostServerNamesInRoom implements FederationInternalAPI
@@ -102,11 +102,10 @@ func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom(
request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIQueryJoinedHostServerNamesInRoomPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryJoinedHostServerNamesInRoom", h.federationAPIURL+FederationAPIQueryJoinedHostServerNamesInRoomPath,
+ h.httpClient, ctx, request, response,
+ )
}
// Handle an instruction to make_join & send_join with a remote server.
@@ -115,12 +114,10 @@ func (h *httpFederationInternalAPI) PerformJoin(
request *api.PerformJoinRequest,
response *api.PerformJoinResponse,
) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoinRequest")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIPerformJoinRequestPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
+ if err := httputil.CallInternalRPCAPI(
+ "PerformJoinRequest", h.federationAPIURL+FederationAPIPerformJoinRequestPath,
+ h.httpClient, ctx, request, response,
+ ); err != nil {
response.LastError = &gomatrix.HTTPError{
Message: err.Error(),
Code: 0,
@@ -135,11 +132,10 @@ func (h *httpFederationInternalAPI) PerformDirectoryLookup(
request *api.PerformDirectoryLookupRequest,
response *api.PerformDirectoryLookupResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDirectoryLookup")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIPerformDirectoryLookupRequestPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformDirectoryLookup", h.federationAPIURL+FederationAPIPerformDirectoryLookupRequestPath,
+ h.httpClient, ctx, request, response,
+ )
}
// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to.
@@ -148,101 +144,61 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU(
request *api.PerformBroadcastEDURequest,
response *api.PerformBroadcastEDUResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIPerformBroadcastEDUPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformBroadcastEDU", h.federationAPIURL+FederationAPIPerformBroadcastEDUPath,
+ h.httpClient, ctx, request, response,
+ )
}
type getUserDevices struct {
S gomatrixserverlib.ServerName
UserID string
- Res *gomatrixserverlib.RespUserDevices
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "GetUserDevices")
- defer span.Finish()
-
- var result gomatrixserverlib.RespUserDevices
- request := getUserDevices{
- S: s,
- UserID: userID,
- }
- var response getUserDevices
- apiURL := h.federationAPIURL + FederationAPIGetUserDevicesPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return result, err
- }
- if response.Err != nil {
- return result, response.Err
- }
- return *response.Res, nil
+ return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError](
+ "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient,
+ ctx, &getUserDevices{
+ S: s,
+ UserID: userID,
+ },
+ )
}
type claimKeys struct {
S gomatrixserverlib.ServerName
OneTimeKeys map[string]map[string]string
- Res *gomatrixserverlib.RespClaimKeys
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "ClaimKeys")
- defer span.Finish()
-
- var result gomatrixserverlib.RespClaimKeys
- request := claimKeys{
- S: s,
- OneTimeKeys: oneTimeKeys,
- }
- var response claimKeys
- apiURL := h.federationAPIURL + FederationAPIClaimKeysPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return result, err
- }
- if response.Err != nil {
- return result, response.Err
- }
- return *response.Res, nil
+ return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError](
+ "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient,
+ ctx, &claimKeys{
+ S: s,
+ OneTimeKeys: oneTimeKeys,
+ },
+ )
}
type queryKeys struct {
S gomatrixserverlib.ServerName
Keys map[string][]string
- Res *gomatrixserverlib.RespQueryKeys
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys")
- defer span.Finish()
-
- var result gomatrixserverlib.RespQueryKeys
- request := queryKeys{
- S: s,
- Keys: keys,
- }
- var response queryKeys
- apiURL := h.federationAPIURL + FederationAPIQueryKeysPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return result, err
- }
- if response.Err != nil {
- return result, response.Err
- }
- return *response.Res, nil
+ return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError](
+ "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient,
+ ctx, &queryKeys{
+ S: s,
+ Keys: keys,
+ },
+ )
}
type backfill struct {
@@ -250,32 +206,20 @@ type backfill struct {
RoomID string
Limit int
EventIDs []string
- Res *gomatrixserverlib.Transaction
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (gomatrixserverlib.Transaction, error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill")
- defer span.Finish()
-
- request := backfill{
- S: s,
- RoomID: roomID,
- Limit: limit,
- EventIDs: eventIDs,
- }
- var response backfill
- apiURL := h.federationAPIURL + FederationAPIBackfillPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return gomatrixserverlib.Transaction{}, err
- }
- if response.Err != nil {
- return gomatrixserverlib.Transaction{}, response.Err
- }
- return *response.Res, nil
+ return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError](
+ "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient,
+ ctx, &backfill{
+ S: s,
+ RoomID: roomID,
+ Limit: limit,
+ EventIDs: eventIDs,
+ },
+ )
}
type lookupState struct {
@@ -283,63 +227,39 @@ type lookupState struct {
RoomID string
EventID string
RoomVersion gomatrixserverlib.RoomVersion
- Res *gomatrixserverlib.RespState
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (gomatrixserverlib.RespState, error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState")
- defer span.Finish()
-
- request := lookupState{
- S: s,
- RoomID: roomID,
- EventID: eventID,
- RoomVersion: roomVersion,
- }
- var response lookupState
- apiURL := h.federationAPIURL + FederationAPILookupStatePath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return gomatrixserverlib.RespState{}, err
- }
- if response.Err != nil {
- return gomatrixserverlib.RespState{}, response.Err
- }
- return *response.Res, nil
+ return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError](
+ "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient,
+ ctx, &lookupState{
+ S: s,
+ RoomID: roomID,
+ EventID: eventID,
+ RoomVersion: roomVersion,
+ },
+ )
}
type lookupStateIDs struct {
S gomatrixserverlib.ServerName
RoomID string
EventID string
- Res *gomatrixserverlib.RespStateIDs
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
) (gomatrixserverlib.RespStateIDs, error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs")
- defer span.Finish()
-
- request := lookupStateIDs{
- S: s,
- RoomID: roomID,
- EventID: eventID,
- }
- var response lookupStateIDs
- apiURL := h.federationAPIURL + FederationAPILookupStateIDsPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return gomatrixserverlib.RespStateIDs{}, err
- }
- if response.Err != nil {
- return gomatrixserverlib.RespStateIDs{}, response.Err
- }
- return *response.Res, nil
+ return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError](
+ "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient,
+ ctx, &lookupStateIDs{
+ S: s,
+ RoomID: roomID,
+ EventID: eventID,
+ },
+ )
}
type lookupMissingEvents struct {
@@ -347,64 +267,38 @@ type lookupMissingEvents struct {
RoomID string
Missing gomatrixserverlib.MissingEvents
RoomVersion gomatrixserverlib.RoomVersion
- Res struct {
- Events []gomatrixserverlib.RawJSON `json:"events"`
- }
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupMissingEvents(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string,
missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespMissingEvents, err error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "LookupMissingEvents")
- defer span.Finish()
-
- request := lookupMissingEvents{
- S: s,
- RoomID: roomID,
- Missing: missing,
- RoomVersion: roomVersion,
- }
- apiURL := h.federationAPIURL + FederationAPILookupMissingEventsPath
- err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &request)
- if err != nil {
- return res, err
- }
- if request.Err != nil {
- return res, request.Err
- }
- res.Events = request.Res.Events
- return res, nil
+ return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError](
+ "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient,
+ ctx, &lookupMissingEvents{
+ S: s,
+ RoomID: roomID,
+ Missing: missing,
+ RoomVersion: roomVersion,
+ },
+ )
}
type getEvent struct {
S gomatrixserverlib.ServerName
EventID string
- Res *gomatrixserverlib.Transaction
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
) (gomatrixserverlib.Transaction, error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent")
- defer span.Finish()
-
- request := getEvent{
- S: s,
- EventID: eventID,
- }
- var response getEvent
- apiURL := h.federationAPIURL + FederationAPIGetEventPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return gomatrixserverlib.Transaction{}, err
- }
- if response.Err != nil {
- return gomatrixserverlib.Transaction{}, response.Err
- }
- return *response.Res, nil
+ return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError](
+ "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient,
+ ctx, &getEvent{
+ S: s,
+ EventID: eventID,
+ },
+ )
}
type getEventAuth struct {
@@ -412,135 +306,86 @@ type getEventAuth struct {
RoomVersion gomatrixserverlib.RoomVersion
RoomID string
EventID string
- Res *gomatrixserverlib.RespEventAuth
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetEventAuth(
ctx context.Context, s gomatrixserverlib.ServerName,
roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string,
) (gomatrixserverlib.RespEventAuth, error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "GetEventAuth")
- defer span.Finish()
-
- request := getEventAuth{
- S: s,
- RoomVersion: roomVersion,
- RoomID: roomID,
- EventID: eventID,
- }
- var response getEventAuth
- apiURL := h.federationAPIURL + FederationAPIGetEventAuthPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return gomatrixserverlib.RespEventAuth{}, err
- }
- if response.Err != nil {
- return gomatrixserverlib.RespEventAuth{}, response.Err
- }
- return *response.Res, nil
+ return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError](
+ "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient,
+ ctx, &getEventAuth{
+ S: s,
+ RoomVersion: roomVersion,
+ RoomID: roomID,
+ EventID: eventID,
+ },
+ )
}
func (h *httpFederationInternalAPI) QueryServerKeys(
ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerKeys")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIQueryServerKeysPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+ return httputil.CallInternalRPCAPI(
+ "QueryServerKeys", h.federationAPIURL+FederationAPIQueryServerKeysPath,
+ h.httpClient, ctx, req, res,
+ )
}
type lookupServerKeys struct {
S gomatrixserverlib.ServerName
KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp
- ServerKeys []gomatrixserverlib.ServerKeys
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys")
- defer span.Finish()
-
- request := lookupServerKeys{
- S: s,
- KeyRequests: keyRequests,
- }
- var response lookupServerKeys
- apiURL := h.federationAPIURL + FederationAPILookupServerKeysPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return []gomatrixserverlib.ServerKeys{}, err
- }
- if response.Err != nil {
- return []gomatrixserverlib.ServerKeys{}, response.Err
- }
- return response.ServerKeys, nil
+ return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, *api.FederationClientError](
+ "LookupServerKeys", h.federationAPIURL+FederationAPILookupServerKeysPath, h.httpClient,
+ ctx, &lookupServerKeys{
+ S: s,
+ KeyRequests: keyRequests,
+ },
+ )
}
type eventRelationships struct {
S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion
- Res gomatrixserverlib.MSC2836EventRelationshipsResponse
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships")
- defer span.Finish()
-
- request := eventRelationships{
- S: s,
- Req: r,
- RoomVer: roomVersion,
- }
- var response eventRelationships
- apiURL := h.federationAPIURL + FederationAPIEventRelationshipsPath
- err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return res, err
- }
- if response.Err != nil {
- return res, response.Err
- }
- return response.Res, nil
+ return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError](
+ "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient,
+ ctx, &eventRelationships{
+ S: s,
+ Req: r,
+ RoomVer: roomVersion,
+ },
+ )
}
type spacesReq struct {
S gomatrixserverlib.ServerName
SuggestedOnly bool
RoomID string
- Res gomatrixserverlib.MSC2946SpacesResponse
- Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
- defer span.Finish()
-
- request := spacesReq{
- S: dst,
- SuggestedOnly: suggestedOnly,
- RoomID: roomID,
- }
- var response spacesReq
- apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath
- err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
- if err != nil {
- return res, err
- }
- if response.Err != nil {
- return res, response.Err
- }
- return response.Res, nil
+ return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError](
+ "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient,
+ ctx, &spacesReq{
+ S: dst,
+ SuggestedOnly: suggestedOnly,
+ RoomID: roomID,
+ },
+ )
}
func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing {
@@ -614,11 +459,10 @@ func (h *httpFederationInternalAPI) InputPublicKeys(
request *api.InputPublicKeysRequest,
response *api.InputPublicKeysResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIInputPublicKeyPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "InputPublicKey", h.federationAPIURL+FederationAPIInputPublicKeyPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpFederationInternalAPI) QueryPublicKeys(
@@ -626,9 +470,8 @@ func (h *httpFederationInternalAPI) QueryPublicKeys(
request *api.QueryPublicKeysRequest,
response *api.QueryPublicKeysResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey")
- defer span.Finish()
-
- apiURL := h.federationAPIURL + FederationAPIQueryPublicKeyPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryPublicKeys", h.federationAPIURL+FederationAPIQueryPublicKeyPath,
+ h.httpClient, ctx, request, response,
+ )
}
diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go
index 28e52b32d..a8b829a71 100644
--- a/federationapi/inthttp/server.go
+++ b/federationapi/inthttp/server.go
@@ -1,12 +1,14 @@
package inthttp
import (
+ "context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@@ -15,372 +17,180 @@ import (
func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
FederationAPIQueryJoinedHostServerNamesInRoomPath,
- httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse {
- var request api.QueryJoinedHostServerNamesInRoomRequest
- var response api.QueryJoinedHostServerNamesInRoomResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := intAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(
- FederationAPIPerformJoinRequestPath,
- httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse {
- var request api.PerformJoinRequest
- var response api.PerformJoinResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- intAPI.PerformJoin(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(
- FederationAPIPerformLeaveRequestPath,
- httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse {
- var request api.PerformLeaveRequest
- var response api.PerformLeaveResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := intAPI.PerformLeave(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", intAPI.QueryJoinedHostServerNamesInRoom),
)
+
internalAPIMux.Handle(
FederationAPIPerformInviteRequestPath,
- httputil.MakeInternalAPI("PerformInviteRequest", func(req *http.Request) util.JSONResponse {
- var request api.PerformInviteRequest
- var response api.PerformInviteResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := intAPI.PerformInvite(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", intAPI.PerformInvite),
)
+
+ internalAPIMux.Handle(
+ FederationAPIPerformLeaveRequestPath,
+ httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", intAPI.PerformLeave),
+ )
+
internalAPIMux.Handle(
FederationAPIPerformDirectoryLookupRequestPath,
- httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse {
- var request api.PerformDirectoryLookupRequest
- var response api.PerformDirectoryLookupResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := intAPI.PerformDirectoryLookup(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", intAPI.PerformDirectoryLookup),
)
+
internalAPIMux.Handle(
FederationAPIPerformBroadcastEDUPath,
- httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse {
- var request api.PerformBroadcastEDURequest
- var response api.PerformBroadcastEDUResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU),
)
+
+ internalAPIMux.Handle(
+ FederationAPIPerformJoinRequestPath,
+ httputil.MakeInternalRPCAPI(
+ "FederationAPIPerformJoinRequest",
+ func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error {
+ intAPI.PerformJoin(ctx, req, res)
+ return nil
+ },
+ ),
+ )
+
internalAPIMux.Handle(
FederationAPIGetUserDevicesPath,
- httputil.MakeInternalAPI("GetUserDevices", func(req *http.Request) util.JSONResponse {
- var request getUserDevices
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.GetUserDevices(req.Context(), request.S, request.UserID)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = &res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPIGetUserDevices",
+ func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) {
+ res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPIClaimKeysPath,
- httputil.MakeInternalAPI("ClaimKeys", func(req *http.Request) util.JSONResponse {
- var request claimKeys
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.ClaimKeys(req.Context(), request.S, request.OneTimeKeys)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = &res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPIClaimKeys",
+ func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) {
+ res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPIQueryKeysPath,
- httputil.MakeInternalAPI("QueryKeys", func(req *http.Request) util.JSONResponse {
- var request queryKeys
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.QueryKeys(req.Context(), request.S, request.Keys)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = &res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPIQueryKeys",
+ func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) {
+ res, err := intAPI.QueryKeys(ctx, req.S, req.Keys)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPIBackfillPath,
- httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse {
- var request backfill
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = &res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPIBackfill",
+ func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) {
+ res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPILookupStatePath,
- httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse {
- var request lookupState
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = &res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPILookupState",
+ func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) {
+ res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPILookupStateIDsPath,
- httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse {
- var request lookupStateIDs
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = &res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPILookupStateIDs",
+ func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) {
+ res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPILookupMissingEventsPath,
- httputil.MakeInternalAPI("LookupMissingEvents", func(req *http.Request) util.JSONResponse {
- var request lookupMissingEvents
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.LookupMissingEvents(req.Context(), request.S, request.RoomID, request.Missing, request.RoomVersion)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- for _, event := range res.Events {
- js, err := json.Marshal(event)
- if err != nil {
- return util.MessageResponse(http.StatusInternalServerError, err.Error())
- }
- request.Res.Events = append(request.Res.Events, js)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPILookupMissingEvents",
+ func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) {
+ res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPIGetEventPath,
- httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse {
- var request getEvent
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = &res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPIGetEvent",
+ func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) {
+ res, err := intAPI.GetEvent(ctx, req.S, req.EventID)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPIGetEventAuthPath,
- httputil.MakeInternalAPI("GetEventAuth", func(req *http.Request) util.JSONResponse {
- var request getEventAuth
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.GetEventAuth(req.Context(), request.S, request.RoomVersion, request.RoomID, request.EventID)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = &res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPIGetEventAuth",
+ func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) {
+ res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPIQueryServerKeysPath,
- httputil.MakeInternalAPI("QueryServerKeys", func(req *http.Request) util.JSONResponse {
- var request api.QueryServerKeysRequest
- var response api.QueryServerKeysResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := intAPI.QueryServerKeys(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", intAPI.QueryServerKeys),
)
+
internalAPIMux.Handle(
FederationAPILookupServerKeysPath,
- httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse {
- var request lookupServerKeys
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.ServerKeys = res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPILookupServerKeys",
+ func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) {
+ res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPIEventRelationshipsPath,
- httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse {
- var request eventRelationships
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPIMSC2836EventRelationships",
+ func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
+ res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
internalAPIMux.Handle(
FederationAPISpacesSummaryPath,
- httputil.MakeInternalAPI("MSC2946SpacesSummary", func(req *http.Request) util.JSONResponse {
- var request spacesReq
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly)
- if err != nil {
- ferr, ok := err.(*api.FederationClientError)
- if ok {
- request.Err = ferr
- } else {
- request.Err = &api.FederationClientError{
- Err: err.Error(),
- }
- }
- }
- request.Res = res
- return util.JSONResponse{Code: http.StatusOK, JSON: request}
- }),
+ httputil.MakeInternalProxyAPI(
+ "FederationAPIMSC2946SpacesSummary",
+ func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) {
+ res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly)
+ return &res, federationClientError(err)
+ },
+ ),
)
+
+ // TODO: Look at this shape
internalAPIMux.Handle(FederationAPIQueryPublicKeyPath,
- httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse {
+ httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryPublicKeysRequest{}
response := api.QueryPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@@ -394,8 +204,10 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
+
+ // TODO: Look at this shape
internalAPIMux.Handle(FederationAPIInputPublicKeyPath,
- httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse {
+ httputil.MakeInternalAPI("FederationAPIInputPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.InputPublicKeysRequest{}
response := api.InputPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@@ -408,3 +220,18 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
}),
)
}
+
+func federationClientError(err error) error {
+ switch ferr := err.(type) {
+ case nil:
+ return nil
+ case api.FederationClientError:
+ return &ferr
+ case *api.FederationClientError:
+ return ferr
+ default:
+ return &api.FederationClientError{
+ Err: err.Error(),
+ }
+ }
+}
diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go
index b6edec5da..0d937ffaf 100644
--- a/federationapi/queue/destinationqueue.go
+++ b/federationapi/queue/destinationqueue.go
@@ -127,6 +127,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
event.Type,
+ nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return
diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go
index 4c25c4ce6..88664fcf9 100644
--- a/federationapi/queue/queue.go
+++ b/federationapi/queue/queue.go
@@ -158,7 +158,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock()
oq, ok := oqs.queues[destination]
- if !ok || oq != nil {
+ if !ok || oq == nil {
destinationQueueTotal.Inc()
oq = &destinationQueue{
queues: oqs,
diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go
index 2f9da1f25..ce8b06b70 100644
--- a/federationapi/routing/devices.go
+++ b/federationapi/routing/devices.go
@@ -30,9 +30,11 @@ func GetUserDevices(
userID string,
) util.JSONResponse {
var res keyapi.QueryDeviceMessagesResponse
- keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{
+ if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{
UserID: userID,
- }, &res)
+ }, &res); err != nil {
+ return util.ErrorResponse(err)
+ }
if res.Error != nil {
util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed")
return jsonerror.InternalServerError()
@@ -47,7 +49,9 @@ func GetUserDevices(
for _, dev := range res.Devices {
sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID))
}
- keyAPI.QuerySignatures(req.Context(), sigReq, sigRes)
+ if err := keyAPI.QuerySignatures(req.Context(), sigReq, sigRes); err != nil {
+ return jsonerror.InternalAPIError(req.Context(), err)
+ }
response := gomatrixserverlib.RespUserDevices{
UserID: userID,
diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go
index cde87a0ac..4b795018c 100644
--- a/federationapi/routing/invite.go
+++ b/federationapi/routing/invite.go
@@ -26,7 +26,6 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
- "github.com/sirupsen/logrus"
)
// InviteV2 implements /_matrix/federation/v2/invite/{roomID}/{eventID}
@@ -144,7 +143,6 @@ func processInvite(
// Check that the event is signed by the server sending the request.
redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version())
if err != nil {
- logrus.WithError(err).Errorf("XXX: invite.go")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The event JSON could not be redacted"),
diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go
index 41004cf51..b48eaf78e 100644
--- a/federationapi/routing/join.go
+++ b/federationapi/routing/join.go
@@ -21,13 +21,14 @@ import (
"sort"
"time"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
+
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
- "github.com/sirupsen/logrus"
)
// MakeJoin implements the /make_join API
@@ -202,6 +203,14 @@ func SendJoin(
}
}
+ // Check that the event is from the server sending the request.
+ if event.Origin() != request.Origin() {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"),
+ }
+ }
+
// Check that a state key is provided.
if event.StateKey() == nil || event.StateKeyEquals("") {
return util.JSONResponse{
@@ -216,6 +225,22 @@ func SendJoin(
}
}
+ // Check that the sender belongs to the server that is sending us
+ // the request. By this point we've already asserted that the sender
+ // and the state key are equal so we don't need to check both.
+ var domain gomatrixserverlib.ServerName
+ if _, domain, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: jsonerror.Forbidden("The sender of the join is invalid"),
+ }
+ } else if domain != request.Origin() {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: jsonerror.Forbidden("The sender of the join must belong to the origin server"),
+ }
+ }
+
// Check that the room ID is correct.
if event.RoomID() != roomID {
return util.JSONResponse{
@@ -242,14 +267,6 @@ func SendJoin(
}
}
- // Check that the event is from the server sending the request.
- if event.Origin() != request.Origin() {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"),
- }
- }
-
// Check that this is in fact a join event
membership, err := event.Membership()
if err != nil {
@@ -375,7 +392,7 @@ func SendJoin(
// the room, so set SendAsServer to cfg.Matrix.ServerName
if !alreadyJoined {
var response api.InputRoomEventsResponse
- rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
+ if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
@@ -384,7 +401,9 @@ func SendJoin(
TransactionID: nil,
},
},
- }, &response)
+ }, &response); err != nil {
+ return jsonerror.InternalAPIError(httpReq.Context(), err)
+ }
if response.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed")
if response.NotAllowed {
@@ -419,13 +438,13 @@ func SendJoin(
// a restricted room join. If the room version does not support restricted
// joins then this function returns with no side effects. This returns three
// values:
-// * an optional JSON response body (i.e. M_UNABLE_TO_AUTHORISE_JOIN) which
-// should always be sent back to the client if one is specified
-// * a user ID of an authorising user, typically a user that has power to
-// issue invites in the room, if one has been found
-// * an error if there was a problem finding out if this was allowable,
-// like if the room version isn't known or a problem happened talking to
-// the roomserver
+// - an optional JSON response body (i.e. M_UNABLE_TO_AUTHORISE_JOIN) which
+// should always be sent back to the client if one is specified
+// - a user ID of an authorising user, typically a user that has power to
+// issue invites in the room, if one has been found
+// - an error if there was a problem finding out if this was allowable,
+// like if the room version isn't known or a problem happened talking to
+// the roomserver
func checkRestrictedJoin(
httpReq *http.Request,
rsAPI api.FederationRoomserverAPI,
diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go
index b1a9b6710..b03d4c1d6 100644
--- a/federationapi/routing/keys.go
+++ b/federationapi/routing/keys.go
@@ -19,7 +19,7 @@ import (
"net/http"
"time"
- "github.com/matrix-org/dendrite/clientapi/httputil"
+ clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
@@ -61,9 +61,11 @@ func QueryDeviceKeys(
}
var queryRes api.QueryKeysResponse
- keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
+ if err := keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{
UserToDevices: qkr.DeviceKeys,
- }, &queryRes)
+ }, &queryRes); err != nil {
+ return jsonerror.InternalAPIError(httpReq.Context(), err)
+ }
if queryRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys")
return jsonerror.InternalServerError()
@@ -113,9 +115,11 @@ func ClaimOneTimeKeys(
}
var claimRes api.PerformClaimKeysResponse
- keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
+ if err := keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{
OneTimeKeys: cor.OneTimeKeys,
- }, &claimRes)
+ }, &claimRes); err != nil {
+ return jsonerror.InternalAPIError(httpReq.Context(), err)
+ }
if claimRes.Error != nil {
util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys")
return jsonerror.InternalServerError()
@@ -184,7 +188,7 @@ func NotaryKeys(
) util.JSONResponse {
if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
- if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
+ if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
return *reqErr
}
}
diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go
index dbaf68e5b..8e43ce959 100644
--- a/federationapi/routing/leave.go
+++ b/federationapi/routing/leave.go
@@ -277,7 +277,7 @@ func SendLeave(
// We are responsible for notifying other servers that the user has left
// the room, so set SendAsServer to cfg.Matrix.ServerName
var response api.InputRoomEventsResponse
- rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
+ if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
@@ -286,7 +286,9 @@ func SendLeave(
TransactionID: nil,
},
},
- }, &response)
+ }, &response); err != nil {
+ return jsonerror.InternalAPIError(httpReq.Context(), err)
+ }
if response.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).WithField("not_allowed", response.NotAllowed).Error("producer.SendEvents failed")
diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go
index 54e14d014..a55e7ce6d 100644
--- a/federationapi/routing/send.go
+++ b/federationapi/routing/send.go
@@ -469,7 +469,9 @@ func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverli
UserID: updatePayload.UserID,
}
uploadRes := &keyapi.PerformUploadDeviceKeysResponse{}
- t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
+ if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
+ return err
+ }
if uploadRes.Error != nil {
return uploadRes.Error
}
diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go
index a111580c7..1c796f542 100644
--- a/federationapi/routing/send_test.go
+++ b/federationapi/routing/send_test.go
@@ -64,11 +64,12 @@ func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
-) {
+) error {
t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...)
for _, ire := range request.InputRoomEvents {
fmt.Println("InputRoomEvents: ", ire.Event.EventID())
}
+ return nil
}
// Query the latest events and state for a room from the room server.
diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go
index 29254948b..b8109b432 100644
--- a/federationapi/storage/interface.go
+++ b/federationapi/storage/interface.go
@@ -16,6 +16,7 @@ package storage
import (
"context"
+ "time"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/types"
@@ -38,7 +39,7 @@ type Database interface {
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
- AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string) error
+ AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
@@ -70,4 +71,6 @@ type Database interface {
// Query the notary for the server keys for the given server. If `optKeyIDs` is not empty, multiple server keys may be returned (between 1 - len(optKeyIDs))
// such that the combination of all server keys will include all the `optKeyIDs`.
GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error)
+ // DeleteExpiredEDUs cleans up expired EDUs
+ DeleteExpiredEDUs(ctx context.Context) error
}
diff --git a/federationapi/storage/postgres/deltas/2021020411080000_rooms.go b/federationapi/storage/postgres/deltas/2021020411080000_rooms.go
index cc4bdadfd..fc894846d 100644
--- a/federationapi/storage/postgres/deltas/2021020411080000_rooms.go
+++ b/federationapi/storage/postgres/deltas/2021020411080000_rooms.go
@@ -15,23 +15,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
-}
-
-func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
- m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
-}
-
-func UpRemoveRoomsTable(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpRemoveRoomsTable(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
DROP TABLE IF EXISTS federationsender_rooms;
`)
if err != nil {
diff --git a/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go
new file mode 100644
index 000000000..53a7a025e
--- /dev/null
+++ b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go
@@ -0,0 +1,44 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus ADD COLUMN IF NOT EXISTS expires_at BIGINT NOT NULL DEFAULT 0;")
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24)))
+ if err != nil {
+ return fmt.Errorf("failed to update queue_edus: %w", err)
+ }
+ return nil
+}
+
+func DownAddexpiresat(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus DROP COLUMN expires_at;")
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go
index 1fedf0ef1..d6507e13b 100644
--- a/federationapi/storage/postgres/queue_edus_table.go
+++ b/federationapi/storage/postgres/queue_edus_table.go
@@ -19,9 +19,11 @@ import (
"database/sql"
"github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+
+ "github.com/matrix-org/dendrite/federationapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/gomatrixserverlib"
)
const queueEDUsSchema = `
@@ -31,7 +33,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
-- The domain part of the user ID the EDU event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_edus_json table.
- json_nid BIGINT NOT NULL
+ json_nid BIGINT NOT NULL,
+ -- The expiry time of this edu, if any.
+ expires_at BIGINT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
@@ -43,8 +47,8 @@ CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx
`
const insertQueueEDUSQL = "" +
- "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
- " VALUES ($1, $2, $3)"
+ "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid, expires_at)" +
+ " VALUES ($1, $2, $3, $4)"
const deleteQueueEDUSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid = ANY($2)"
@@ -65,6 +69,12 @@ const selectQueueEDUCountSQL = "" +
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
+const selectExpiredEDUsSQL = "" +
+ "SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1"
+
+const deleteExpiredEDUsSQL = "" +
+ "DELETE FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1"
+
type queueEDUsStatements struct {
db *sql.DB
insertQueueEDUStmt *sql.Stmt
@@ -73,6 +83,8 @@ type queueEDUsStatements struct {
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
+ selectExpiredEDUsStmt *sql.Stmt
+ deleteExpiredEDUsStmt *sql.Stmt
}
func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
@@ -81,27 +93,34 @@ func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
}
_, err = s.db.Exec(queueEDUsSchema)
if err != nil {
- return
+ return s, err
}
- if s.insertQueueEDUStmt, err = s.db.Prepare(insertQueueEDUSQL); err != nil {
- return
+
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(
+ sqlutil.Migration{
+ Version: "federationapi: add expiresat column",
+ Up: deltas.UpAddexpiresat,
+ },
+ )
+ if err := m.Up(context.Background()); err != nil {
+ return s, err
}
- if s.deleteQueueEDUStmt, err = s.db.Prepare(deleteQueueEDUSQL); err != nil {
- return
- }
- if s.selectQueueEDUStmt, err = s.db.Prepare(selectQueueEDUSQL); err != nil {
- return
- }
- if s.selectQueueEDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
- return
- }
- if s.selectQueueEDUCountStmt, err = s.db.Prepare(selectQueueEDUCountSQL); err != nil {
- return
- }
- if s.selectQueueEDUServerNamesStmt, err = s.db.Prepare(selectQueueServerNamesSQL); err != nil {
- return
- }
- return
+
+ return s, nil
+}
+
+func (s *queueEDUsStatements) Prepare() error {
+ return sqlutil.StatementList{
+ {&s.insertQueueEDUStmt, insertQueueEDUSQL},
+ {&s.deleteQueueEDUStmt, deleteQueueEDUSQL},
+ {&s.selectQueueEDUStmt, selectQueueEDUSQL},
+ {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
+ {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
+ {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
+ {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
+ {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
+ }.Prepare(s.db)
}
func (s *queueEDUsStatements) InsertQueueEDU(
@@ -110,6 +129,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
eduType string,
serverName gomatrixserverlib.ServerName,
nid int64,
+ expiresAt gomatrixserverlib.Timestamp,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext(
@@ -117,6 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
eduType, // the EDU type
serverName, // destination server name
nid, // JSON blob NID
+ expiresAt, // timestamp of expiry
)
return err
}
@@ -150,7 +171,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs(
}
result = append(result, nid)
}
- return result, nil
+ return result, rows.Err()
}
func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
@@ -200,3 +221,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames(
return result, rows.Err()
}
+
+func (s *queueEDUsStatements) SelectExpiredEDUs(
+ ctx context.Context, txn *sql.Tx,
+ expiredBefore gomatrixserverlib.Timestamp,
+) ([]int64, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt)
+ rows, err := stmt.QueryContext(ctx, expiredBefore)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed")
+ var result []int64
+ var nid int64
+ for rows.Next() {
+ if err = rows.Scan(&nid); err != nil {
+ return nil, err
+ }
+ result = append(result, nid)
+ }
+ return result, rows.Err()
+}
+
+func (s *queueEDUsStatements) DeleteExpiredEDUs(
+ ctx context.Context, txn *sql.Tx,
+ expiredBefore gomatrixserverlib.Timestamp,
+) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt)
+ _, err := stmt.ExecContext(ctx, expiredBefore)
+ return err
+}
diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go
index 9863afb2b..6e208d096 100644
--- a/federationapi/storage/postgres/storage.go
+++ b/federationapi/storage/postgres/storage.go
@@ -82,9 +82,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
- m := sqlutil.NewMigrations()
- deltas.LoadRemoveRoomsTable(m)
- if err = m.RunDeltas(d.db, dbProperties); err != nil {
+ m := sqlutil.NewMigrator(d.db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "federationsender: drop federationsender_rooms",
+ Up: deltas.UpRemoveRoomsTable,
+ })
+ err = m.Up(base.Context())
+ if err != nil {
+ return nil, err
+ }
+ if err = queueEDUs.Prepare(); err != nil {
return nil, err
}
d.Database = shared.Database{
diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go
index 02a23338f..b62e5d9c5 100644
--- a/federationapi/storage/shared/storage_edus.go
+++ b/federationapi/storage/shared/storage_edus.go
@@ -20,10 +20,21 @@ import (
"encoding/json"
"errors"
"fmt"
+ "time"
"github.com/matrix-org/gomatrixserverlib"
)
+// defaultExpiry for EDUs if not listed below
+var defaultExpiry = time.Hour * 24
+
+// defaultExpireEDUTypes contains EDUs which can/should be expired after a given time
+// if the target server isn't reachable for some reason.
+var defaultExpireEDUTypes = map[string]time.Duration{
+ gomatrixserverlib.MTyping: time.Minute,
+ gomatrixserverlib.MPresence: time.Minute * 10,
+}
+
// AssociateEDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send
// to which servers.
@@ -32,7 +43,21 @@ func (d *Database) AssociateEDUWithDestination(
serverName gomatrixserverlib.ServerName,
receipt *Receipt,
eduType string,
+ expireEDUTypes map[string]time.Duration,
) error {
+ if expireEDUTypes == nil {
+ expireEDUTypes = defaultExpireEDUTypes
+ }
+ expiresAt := gomatrixserverlib.AsTimestamp(time.Now().Add(defaultExpiry))
+ if duration, ok := expireEDUTypes[eduType]; ok {
+ // Keep EDUs for at least x minutes before deleting them
+ expiresAt = gomatrixserverlib.AsTimestamp(time.Now().Add(duration))
+ }
+ // We forcibly set m.direct_to_device events to 0, as we always want them
+ // to be delivered. (required for E2EE)
+ if eduType == gomatrixserverlib.MDirectToDevice {
+ expiresAt = 0
+ }
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueueEDUs.InsertQueueEDU(
ctx, // context
@@ -40,6 +65,7 @@ func (d *Database) AssociateEDUWithDestination(
eduType, // EDU type for coalescing
serverName, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
+ expiresAt, // The timestamp this EDU will expire
); err != nil {
return fmt.Errorf("InsertQueueEDU: %w", err)
}
@@ -150,3 +176,26 @@ func (d *Database) GetPendingEDUServerNames(
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil)
}
+
+// DeleteExpiredEDUs deletes expired EDUs
+func (d *Database) DeleteExpiredEDUs(ctx context.Context) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ expiredBefore := gomatrixserverlib.AsTimestamp(time.Now())
+ jsonNIDs, err := d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore)
+ if err != nil {
+ return err
+ }
+ if len(jsonNIDs) == 0 {
+ return nil
+ }
+ for i := range jsonNIDs {
+ d.Cache.EvictFederationQueuedEDU(jsonNIDs[i])
+ }
+
+ if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil {
+ return err
+ }
+
+ return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore)
+ })
+}
diff --git a/federationapi/storage/sqlite3/deltas/2021020411080000_rooms.go b/federationapi/storage/sqlite3/deltas/2021020411080000_rooms.go
index cc4bdadfd..fc894846d 100644
--- a/federationapi/storage/sqlite3/deltas/2021020411080000_rooms.go
+++ b/federationapi/storage/sqlite3/deltas/2021020411080000_rooms.go
@@ -15,23 +15,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
-}
-
-func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
- m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
-}
-
-func UpRemoveRoomsTable(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpRemoveRoomsTable(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
DROP TABLE IF EXISTS federationsender_rooms;
`)
if err != nil {
diff --git a/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go
new file mode 100644
index 000000000..c5030163b
--- /dev/null
+++ b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go
@@ -0,0 +1,68 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus RENAME TO federationsender_queue_edus_old;")
+ if err != nil {
+ return fmt.Errorf("failed to rename table: %w", err)
+ }
+
+ _, err = tx.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
+ edu_type TEXT NOT NULL,
+ server_name TEXT NOT NULL,
+ json_nid BIGINT NOT NULL,
+ expires_at BIGINT NOT NULL DEFAULT 0
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
+ ON federationsender_queue_edus (json_nid, server_name);
+`)
+ if err != nil {
+ return fmt.Errorf("failed to create new table: %w", err)
+ }
+ _, err = tx.ExecContext(ctx, `
+INSERT
+ INTO federationsender_queue_edus (
+ edu_type, server_name, json_nid, expires_at
+ ) SELECT edu_type, server_name, json_nid, 0 FROM federationsender_queue_edus_old;
+`)
+ if err != nil {
+ return fmt.Errorf("failed to update queue_edus: %w", err)
+ }
+ _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24)))
+ if err != nil {
+ return fmt.Errorf("failed to update queue_edus: %w", err)
+ }
+ return nil
+}
+
+func DownAddexpiresat(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus DROP COLUMN expires_at;")
+ if err != nil {
+ return fmt.Errorf("failed to rename table: %w", err)
+ }
+ return nil
+}
diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go
index f4c84f094..8e7e7901f 100644
--- a/federationapi/storage/sqlite3/queue_edus_table.go
+++ b/federationapi/storage/sqlite3/queue_edus_table.go
@@ -20,9 +20,11 @@ import (
"fmt"
"strings"
+ "github.com/matrix-org/gomatrixserverlib"
+
+ "github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/gomatrixserverlib"
)
const queueEDUsSchema = `
@@ -32,7 +34,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
-- The domain part of the user ID the EDU event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_edus_json table.
- json_nid BIGINT NOT NULL
+ json_nid BIGINT NOT NULL,
+ -- The expiry time of this edu, if any.
+ expires_at BIGINT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
@@ -44,8 +48,8 @@ CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx
`
const insertQueueEDUSQL = "" +
- "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
- " VALUES ($1, $2, $3)"
+ "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid, expires_at)" +
+ " VALUES ($1, $2, $3, $4)"
const deleteQueueEDUsSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
@@ -66,13 +70,22 @@ const selectQueueEDUCountSQL = "" +
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
+const selectExpiredEDUsSQL = "" +
+ "SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1"
+
+const deleteExpiredEDUsSQL = "" +
+ "DELETE FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1"
+
type queueEDUsStatements struct {
- db *sql.DB
- insertQueueEDUStmt *sql.Stmt
+ db *sql.DB
+ insertQueueEDUStmt *sql.Stmt
+ // deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
+ selectExpiredEDUsStmt *sql.Stmt
+ deleteExpiredEDUsStmt *sql.Stmt
}
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
@@ -81,24 +94,33 @@ func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
}
_, err = db.Exec(queueEDUsSchema)
if err != nil {
- return
+ return s, err
}
- if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil {
- return
+
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(
+ sqlutil.Migration{
+ Version: "federationapi: add expiresat column",
+ Up: deltas.UpAddexpiresat,
+ },
+ )
+ if err := m.Up(context.Background()); err != nil {
+ return s, err
}
- if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil {
- return
- }
- if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
- return
- }
- if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil {
- return
- }
- if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil {
- return
- }
- return
+
+ return s, nil
+}
+
+func (s *queueEDUsStatements) Prepare() error {
+ return sqlutil.StatementList{
+ {&s.insertQueueEDUStmt, insertQueueEDUSQL},
+ {&s.selectQueueEDUStmt, selectQueueEDUSQL},
+ {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
+ {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
+ {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
+ {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
+ {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
+ }.Prepare(s.db)
}
func (s *queueEDUsStatements) InsertQueueEDU(
@@ -107,6 +129,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
eduType string,
serverName gomatrixserverlib.ServerName,
nid int64,
+ expiresAt gomatrixserverlib.Timestamp,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext(
@@ -114,6 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
eduType, // the EDU type
serverName, // destination server name
nid, // JSON blob NID
+ expiresAt, // timestamp of expiry
)
return err
}
@@ -159,7 +183,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs(
}
result = append(result, nid)
}
- return result, nil
+ return result, rows.Err()
}
func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
@@ -209,3 +233,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames(
return result, rows.Err()
}
+
+func (s *queueEDUsStatements) SelectExpiredEDUs(
+ ctx context.Context, txn *sql.Tx,
+ expiredBefore gomatrixserverlib.Timestamp,
+) ([]int64, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt)
+ rows, err := stmt.QueryContext(ctx, expiredBefore)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed")
+ var result []int64
+ var nid int64
+ for rows.Next() {
+ if err = rows.Scan(&nid); err != nil {
+ return nil, err
+ }
+ result = append(result, nid)
+ }
+ return result, rows.Err()
+}
+
+func (s *queueEDUsStatements) DeleteExpiredEDUs(
+ ctx context.Context, txn *sql.Tx,
+ expiredBefore gomatrixserverlib.Timestamp,
+) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt)
+ _, err := stmt.ExecContext(ctx, expiredBefore)
+ return err
+}
diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go
index 7d0cee90e..c89cb6bea 100644
--- a/federationapi/storage/sqlite3/storage.go
+++ b/federationapi/storage/sqlite3/storage.go
@@ -81,9 +81,16 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
- m := sqlutil.NewMigrations()
- deltas.LoadRemoveRoomsTable(m)
- if err = m.RunDeltas(d.db, dbProperties); err != nil {
+ m := sqlutil.NewMigrator(d.db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "federationsender: drop federationsender_rooms",
+ Up: deltas.UpRemoveRoomsTable,
+ })
+ err = m.Up(base.Context())
+ if err != nil {
+ return nil, err
+ }
+ if err = queueEDUs.Prepare(); err != nil {
return nil, err
}
d.Database = shared.Database{
diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go
new file mode 100644
index 000000000..7eba2cbee
--- /dev/null
+++ b/federationapi/storage/storage_test.go
@@ -0,0 +1,81 @@
+package storage_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/stretchr/testify/assert"
+
+ "github.com/matrix-org/dendrite/federationapi/storage"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+)
+
+func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+ b, baseClose := testrig.CreateBaseDendrite(t, dbType)
+ connStr, dbClose := test.PrepareDBConnectionString(t, dbType)
+ db, err := storage.NewDatabase(b, &config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, b.Caches, b.Cfg.Global.ServerName)
+ if err != nil {
+ t.Fatalf("NewDatabase returned %s", err)
+ }
+ return db, func() {
+ dbClose()
+ baseClose()
+ }
+}
+
+func TestExpireEDUs(t *testing.T) {
+ var expireEDUTypes = map[string]time.Duration{
+ gomatrixserverlib.MReceipt: time.Millisecond,
+ }
+
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateFederationDatabase(t, dbType)
+ defer close()
+ // insert some data
+ for i := 0; i < 100; i++ {
+ receipt, err := db.StoreJSON(ctx, "{}")
+ assert.NoError(t, err)
+
+ err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
+ assert.NoError(t, err)
+ }
+ // add data without expiry
+ receipt, err := db.StoreJSON(ctx, "{}")
+ assert.NoError(t, err)
+
+ // m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test
+ err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes)
+ assert.NoError(t, err)
+
+ // Delete expired EDUs
+ err = db.DeleteExpiredEDUs(ctx)
+ assert.NoError(t, err)
+
+ // verify the data is gone
+ data, err := db.GetPendingEDUs(ctx, "localhost", 100)
+ assert.NoError(t, err)
+ assert.Equal(t, 1, len(data))
+
+ // check that m.direct_to_device is never expired
+ receipt, err = db.StoreJSON(ctx, "{}")
+ assert.NoError(t, err)
+
+ err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
+ assert.NoError(t, err)
+
+ err = db.DeleteExpiredEDUs(ctx)
+ assert.NoError(t, err)
+
+ // We should get two EDUs, the m.read_marker and the m.direct_to_device
+ data, err = db.GetPendingEDUs(ctx, "localhost", 100)
+ assert.NoError(t, err)
+ assert.Equal(t, 2, len(data))
+ })
+}
diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go
index 19357393d..3c116a1d0 100644
--- a/federationapi/storage/tables/interface.go
+++ b/federationapi/storage/tables/interface.go
@@ -34,12 +34,15 @@ type FederationQueuePDUs interface {
}
type FederationQueueEDUs interface {
- InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64) error
+ InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64, expiresAt gomatrixserverlib.Timestamp) error
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
+ SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error)
+ DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error
+ Prepare() error
}
type FederationQueueJSON interface {
diff --git a/go.mod b/go.mod
index 2a2a037c1..79ef5c3e8 100644
--- a/go.mod
+++ b/go.mod
@@ -1,9 +1,5 @@
module github.com/matrix-org/dendrite
-replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.8.3-0.20220513095553-73a9a246d34f
-
-replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.13.1-0.20220621084451-ac518c356673
-
require (
github.com/Arceliar/ironwood v0.0.0-20220306165321-319147a02d98
github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979
@@ -25,21 +21,20 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
- github.com/matrix-org/gomatrixserverlib v0.0.0-20220711125303-3bb2e997a44c
- github.com/matrix-org/pinecone v0.0.0-20220708135211-1ce778fcde6a
+ github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839
+ github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.13
- github.com/nats-io/nats-server/v2 v2.7.4-0.20220309205833-773636c1c5bb
- github.com/nats-io/nats.go v1.14.0
+ github.com/nats-io/nats-server/v2 v2.8.5-0.20220731184415-903a06a5b4ee
+ github.com/nats-io/nats.go v1.16.1-0.20220731182438-87bbea85922b
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79
github.com/opentracing/opentracing-go v1.2.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
- github.com/pressly/goose v2.7.0+incompatible
github.com/prometheus/client_golang v1.12.2
- github.com/sirupsen/logrus v1.8.1
+ github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.7.1
github.com/tidwall/gjson v1.14.1
github.com/tidwall/sjson v1.2.4
@@ -47,10 +42,10 @@ require (
github.com/uber/jaeger-lib v2.4.1+incompatible
github.com/yggdrasil-network/yggdrasil-go v0.4.3
go.uber.org/atomic v1.9.0
- golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
+ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
golang.org/x/mobile v0.0.0-20220518205345-8578da9835fd
- golang.org/x/net v0.0.0-20220524220425-1d687d428aca
+ golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0
@@ -77,17 +72,18 @@ require (
github.com/h2non/filetype v1.1.3 // indirect
github.com/juju/errors v0.0.0-20220203013757-bd733f3c86b9 // indirect
github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect
- github.com/klauspost/compress v1.14.4 // indirect
- github.com/lucas-clemente/quic-go v0.26.0 // indirect
+ github.com/klauspost/compress v1.15.9 // indirect
+ github.com/lucas-clemente/quic-go v0.28.1 // indirect
github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect
- github.com/marten-seemann/qtls-go1-17 v0.1.1 // indirect
- github.com/marten-seemann/qtls-go1-18 v0.1.1 // indirect
+ github.com/marten-seemann/qtls-go1-17 v0.1.2 // indirect
+ github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect
+ github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
github.com/miekg/dns v1.1.49 // indirect
github.com/minio/highwayhash v1.0.2 // indirect
github.com/moby/term v0.0.0-20210610120745-9d4ed1856297 // indirect
github.com/morikuni/aec v1.0.0 // indirect
- github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a // indirect
+ github.com/nats-io/jwt/v2 v2.3.0 // indirect
github.com/nats-io/nkeys v0.3.0 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/nxadm/tail v1.4.8 // indirect
@@ -103,9 +99,9 @@ require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
- golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
+ golang.org/x/sys v0.0.0-20220731174439-a90be440212d // indirect
golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect
- golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect
+ golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect
golang.org/x/tools v0.1.10 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.27.1 // indirect
diff --git a/go.sum b/go.sum
index 98549f702..9d5c50d2c 100644
--- a/go.sum
+++ b/go.sum
@@ -302,8 +302,8 @@ github.com/kardianos/minwinsvc v1.0.0/go.mod h1:Bgd0oc+D0Qo3bBytmNtyRKVlp85dAloL
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
-github.com/klauspost/compress v1.14.4 h1:eijASRJcobkVtSt81Olfh7JX43osYLwy5krOJo6YEu4=
-github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
+github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY=
+github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@@ -321,8 +321,8 @@ github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/lib/pq v1.10.5 h1:J+gdV2cUmX7ZqL2B0lFcW0m+egaHC2V3lpO8nWxyYiQ=
github.com/lib/pq v1.10.5/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
-github.com/lucas-clemente/quic-go v0.26.0 h1:ALBQXr9UJ8A1LyzvceX4jd9QFsHvlI0RR6BkV16o00A=
-github.com/lucas-clemente/quic-go v0.26.0/go.mod h1:AzgQoPda7N+3IqMMMkywBKggIFo2KT6pfnlrQ2QieeI=
+github.com/lucas-clemente/quic-go v0.28.1 h1:Uo0lvVxWg5la9gflIF9lwa39ONq85Xq2D91YNEIslzU=
+github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0=
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
github.com/lxn/walk v0.0.0-20210112085537-c389da54e794/go.mod h1:E23UucZGqpuUANJooIbHWCufXvOcT6E7Stq81gU+CSQ=
github.com/lxn/win v0.0.0-20210218163916-a377121e959e/go.mod h1:KxxjdtRkfNoYDCUP5ryK7XJJNTnpC8atvtmTheChOtk=
@@ -330,10 +330,12 @@ github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN
github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc=
github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ=
github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk=
-github.com/marten-seemann/qtls-go1-17 v0.1.1 h1:DQjHPq+aOzUeh9/lixAGunn6rIOQyWChPSI4+hgW7jc=
-github.com/marten-seemann/qtls-go1-17 v0.1.1/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s=
-github.com/marten-seemann/qtls-go1-18 v0.1.1 h1:qp7p7XXUFL7fpBvSS1sWD+uSqPvzNQK43DH+/qEkj0Y=
-github.com/marten-seemann/qtls-go1-18 v0.1.1/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4=
+github.com/marten-seemann/qtls-go1-17 v0.1.2 h1:JADBlm0LYiVbuSySCHeY863dNkcpMmDR7s0bLKJeYlQ=
+github.com/marten-seemann/qtls-go1-17 v0.1.2/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s=
+github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM=
+github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4=
+github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1 h1:7m/WlWcSROrcK5NxuXaxYD32BZqe/LEEnBrWcH/cOqQ=
+github.com/marten-seemann/qtls-go1-19 v0.1.0-beta.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI=
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e h1:DP5RC0Z3XdyBEW5dKt8YPeN6vZbm6OzVaGVp7f1BQRM=
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e/go.mod h1:NgPCr+UavRGH6n5jmdX8DuqFZ4JiCWIJoZiuhTRLSUg=
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw2QV3YD/fRrzEDPNGgTlJlvXY0EHHnT87wF3OA=
@@ -341,10 +343,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20220711125303-3bb2e997a44c h1:mt30TDK8kXKV+nCmVfnqoXsh842N+74kvZw7DXuS/JQ=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20220711125303-3bb2e997a44c/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
-github.com/matrix-org/pinecone v0.0.0-20220708135211-1ce778fcde6a h1:DdG8vXMlZ65EAtc4V+3t7zHZ2Gqs24pSnyXS+4BRHUs=
-github.com/matrix-org/pinecone v0.0.0-20220708135211-1ce778fcde6a/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839 h1:QEFxKWH8PlEt3ZQKl31yJNAm8lvpNUwT51IMNTl9v1k=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk=
+github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY=
+github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
@@ -381,8 +383,12 @@ github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7P
github.com/mschoch/smat v0.0.0-20160514031455-90eadee771ae/go.mod h1:qAyveg+e4CE+eKJXWVjKXM4ck2QobLqTDytGJbLLhJg=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
-github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I=
-github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
+github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI=
+github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k=
+github.com/nats-io/nats-server/v2 v2.8.5-0.20220731184415-903a06a5b4ee h1:vAtoZ+LW6eIUjkCWWwO1DZ6o16UGrVOG+ot/AkwejO8=
+github.com/nats-io/nats-server/v2 v2.8.5-0.20220731184415-903a06a5b4ee/go.mod h1:3Yg3ApyQxPlAs1KKHKV5pobV5VtZk+TtOiUJx/iqkkg=
+github.com/nats-io/nats.go v1.16.1-0.20220731182438-87bbea85922b h1:CE9wSYLvwq8aC/0+6zH8lhhtZYvJ9p8PzwvZeYgdBc0=
+github.com/nats-io/nats.go v1.16.1-0.20220731182438-87bbea85922b/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8=
github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
@@ -390,10 +396,6 @@ github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OS
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
-github.com/neilalexander/nats-server/v2 v2.8.3-0.20220513095553-73a9a246d34f h1:Fc+TjdV1mOy0oISSzfoxNWdTqjg7tN/Vdgf+B2cwvdo=
-github.com/neilalexander/nats-server/v2 v2.8.3-0.20220513095553-73a9a246d34f/go.mod h1:vIdpKz3OG+DCg4q/xVPdXHoztEyKDWRtykQ4N7hd7C4=
-github.com/neilalexander/nats.go v1.13.1-0.20220621084451-ac518c356673 h1:TcKfa3Tf0qwUotv63PQVu2d1bBoLi2iEA4RHVMGDh5M=
-github.com/neilalexander/nats.go v1.13.1-0.20220621084451-ac518c356673/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w=
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 h1:lrVQzBtkeQEGGYUHwSX1XPe1E5GL6U3KYCNe2G4bncQ=
github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
@@ -432,8 +434,6 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/pressly/goose v2.7.0+incompatible h1:PWejVEv07LCerQEzMMeAtjuyCKbyprZ/LBa6K5P0OCQ=
-github.com/pressly/goose v2.7.0+incompatible/go.mod h1:m+QHWCqxR3k8D9l7qfzuC/djtlfzxr34mozWDYEu1z8=
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
@@ -493,8 +493,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
-github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
-github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
+github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
+github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE=
@@ -551,7 +551,6 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
-go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE=
@@ -570,8 +569,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM=
-golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
+golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -661,8 +660,8 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20210927181540-4e4d966f7476/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211011170408-caeb26a5c8c0/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.0.0-20220524220425-1d687d428aca h1:xTaFYiPROfpPhqrfTIDXj0ri1SpfueYT951s4bAuDO8=
-golang.org/x/net v0.0.0-20220524220425-1d687d428aca/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
+golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e h1:TsQ7F31D3bUCLeqPT0u+yjp1guoArKaNKmCr22PYgTQ=
+golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@@ -749,9 +748,12 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80=
+golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM=
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -760,14 +762,15 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b h1:NXqSWXSRUSCaFuvitrWtU169I3876zRTalMRbfd6LL0=
golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b/go.mod h1:EFNZuWvGYxIRUEX+K8UmCFwYmZjqcrnq15ZuVldZkZ0=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M=
-golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.0.0-20220411224347-583f2d630306 h1:+gHMid33q6pen7kv9xvT+JRinntgeXO2AeZVd0AWD3w=
+golang.org/x/time v0.0.0-20220411224347-583f2d630306/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
diff --git a/internal/caching/cache_eventstatekeys.go b/internal/caching/cache_eventstatekeys.go
new file mode 100644
index 000000000..05580ab05
--- /dev/null
+++ b/internal/caching/cache_eventstatekeys.go
@@ -0,0 +1,18 @@
+package caching
+
+import "github.com/matrix-org/dendrite/roomserver/types"
+
+// EventStateKeyCache contains the subset of functions needed for
+// a room event state key cache.
+type EventStateKeyCache interface {
+ GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool)
+ StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string)
+}
+
+func (c Caches) GetEventStateKey(eventStateKeyNID types.EventStateKeyNID) (string, bool) {
+ return c.RoomServerStateKeys.Get(eventStateKeyNID)
+}
+
+func (c Caches) StoreEventStateKey(eventStateKeyNID types.EventStateKeyNID, eventStateKey string) {
+ c.RoomServerStateKeys.Set(eventStateKeyNID, eventStateKey)
+}
diff --git a/internal/caching/cache_lazy_load_members.go b/internal/caching/cache_lazy_load_members.go
index 0d7009c94..390334da7 100644
--- a/internal/caching/cache_lazy_load_members.go
+++ b/internal/caching/cache_lazy_load_members.go
@@ -14,6 +14,7 @@ type lazyLoadingCacheKey struct {
type LazyLoadCache interface {
StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string)
IsLazyLoadedUserCached(device *userapi.Device, roomID, userID string) (string, bool)
+ InvalidateLazyLoadedUser(device *userapi.Device, roomID, userID string)
}
func (c Caches) StoreLazyLoadedUser(device *userapi.Device, roomID, userID, eventID string) {
@@ -33,3 +34,12 @@ func (c Caches) IsLazyLoadedUserCached(device *userapi.Device, roomID, userID st
TargetUserID: userID,
})
}
+
+func (c Caches) InvalidateLazyLoadedUser(device *userapi.Device, roomID, userID string) {
+ c.LazyLoading.Unset(lazyLoadingCacheKey{
+ UserID: device.UserID,
+ DeviceID: device.ID,
+ RoomID: roomID,
+ TargetUserID: userID,
+ })
+}
diff --git a/internal/caching/cache_roominfo.go b/internal/caching/cache_roominfo.go
deleted file mode 100644
index d03a61077..000000000
--- a/internal/caching/cache_roominfo.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package caching
-
-import (
- "github.com/matrix-org/dendrite/roomserver/types"
-)
-
-// WARNING: This cache is mutable because it's entirely possible that
-// the IsStub or StateSnaphotNID fields can change, even though the
-// room version and room NID fields will not. This is only safe because
-// the RoomInfoCache is used ONLY within the roomserver and because it
-// will be kept up-to-date by the latest events updater. It MUST NOT be
-// used from other components as we currently have no way to invalidate
-// the cache in downstream components.
-
-// RoomInfosCache contains the subset of functions needed for
-// a room Info cache. It must only be used from the roomserver only
-// It is not safe for use from other components.
-type RoomInfoCache interface {
- GetRoomInfo(roomID string) (roomInfo types.RoomInfo, ok bool)
- StoreRoomInfo(roomID string, roomInfo types.RoomInfo)
-}
-
-// GetRoomInfo must only be called from the roomserver only. It is not
-// safe for use from other components.
-func (c Caches) GetRoomInfo(roomID string) (types.RoomInfo, bool) {
- return c.RoomInfos.Get(roomID)
-}
-
-// StoreRoomInfo must only be called from the roomserver only. It is not
-// safe for use from other components.
-func (c Caches) StoreRoomInfo(roomID string, roomInfo types.RoomInfo) {
- c.RoomInfos.Set(roomID, roomInfo)
-}
diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go
index b409aeef2..88a5b28bc 100644
--- a/internal/caching/cache_roomservernids.go
+++ b/internal/caching/cache_roomservernids.go
@@ -7,8 +7,8 @@ import (
type RoomServerCaches interface {
RoomServerNIDsCache
RoomVersionCache
- RoomInfoCache
RoomServerEventsCache
+ EventStateKeyCache
}
// RoomServerNIDsCache contains the subset of functions needed for
@@ -19,9 +19,9 @@ type RoomServerNIDsCache interface {
}
func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) {
- return c.RoomServerRoomIDs.Get(int64(roomNID))
+ return c.RoomServerRoomIDs.Get(roomNID)
}
func (c Caches) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) {
- c.RoomServerRoomIDs.Set(int64(roomNID), roomID)
+ c.RoomServerRoomIDs.Set(roomNID, roomID)
}
diff --git a/internal/caching/caches.go b/internal/caching/caches.go
index 14b232dd0..78c9ab7ee 100644
--- a/internal/caching/caches.go
+++ b/internal/caching/caches.go
@@ -23,16 +23,16 @@ import (
// different implementations as long as they satisfy the Cache
// interface.
type Caches struct {
- RoomVersions Cache[string, gomatrixserverlib.RoomVersion] // room ID -> room version
- ServerKeys Cache[string, gomatrixserverlib.PublicKeyLookupResult] // server name -> server keys
- RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID
- RoomServerRoomIDs Cache[int64, string] // room NID -> room ID
- RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event
- RoomInfos Cache[string, types.RoomInfo] // room ID -> room info
- FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU
- FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU
- SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response
- LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID
+ RoomVersions Cache[string, gomatrixserverlib.RoomVersion] // room ID -> room version
+ ServerKeys Cache[string, gomatrixserverlib.PublicKeyLookupResult] // server name -> server keys
+ RoomServerRoomNIDs Cache[string, types.RoomNID] // room ID -> room NID
+ RoomServerRoomIDs Cache[types.RoomNID, string] // room NID -> room ID
+ RoomServerEvents Cache[int64, *gomatrixserverlib.Event] // event NID -> event
+ RoomServerStateKeys Cache[types.EventStateKeyNID, string] // event NID -> event state key
+ FederationPDUs Cache[int64, *gomatrixserverlib.HeaderedEvent] // queue NID -> PDU
+ FederationEDUs Cache[int64, *gomatrixserverlib.EDU] // queue NID -> EDU
+ SpaceSummaryRooms Cache[string, gomatrixserverlib.MSC2946SpacesResponse] // room ID -> space response
+ LazyLoading Cache[lazyLoadingCacheKey, string] // composite key -> event ID
}
// Cache is the interface that an implementation must satisfy.
@@ -44,7 +44,7 @@ type Cache[K keyable, T any] interface {
type keyable interface {
// from https://github.com/dgraph-io/ristretto/blob/8e850b710d6df0383c375ec6a7beae4ce48fc8d5/z/z.go#L34
- uint64 | string | []byte | byte | int | int32 | uint32 | int64 | lazyLoadingCacheKey
+ ~uint64 | ~string | []byte | byte | ~int | ~int32 | ~uint32 | ~int64 | lazyLoadingCacheKey
}
type costable interface {
diff --git a/internal/caching/impl_ristretto.go b/internal/caching/impl_ristretto.go
index 6d625b552..fc0c8cc0f 100644
--- a/internal/caching/impl_ristretto.go
+++ b/internal/caching/impl_ristretto.go
@@ -35,18 +35,18 @@ const (
roomNIDsCache
roomIDsCache
roomEventsCache
- roomInfosCache
federationPDUsCache
federationEDUsCache
spaceSummaryRoomsCache
lazyLoadingCache
+ eventStateKeyCache
)
func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enablePrometheus bool) *Caches {
cache, err := ristretto.NewCache(&ristretto.Config{
- NumCounters: 1e5, // 10x number of expected cache items, affects bloom filter size, gives us room for 10,000 currently
- BufferItems: 64, // recommended by the ristretto godocs as a sane buffer size value
- MaxCost: int64(maxCost),
+ NumCounters: int64((maxCost / 1024) * 10), // 10 counters per 1KB data, affects bloom filter size
+ BufferItems: 64, // recommended by the ristretto godocs as a sane buffer size value
+ MaxCost: int64(maxCost), // max cost is in bytes, as per the Dendrite config
Metrics: true,
KeyToHash: func(key interface{}) (uint64, uint64) {
return z.KeyToHash(key)
@@ -88,7 +88,7 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm
Prefix: roomNIDsCache,
MaxAge: maxAge,
},
- RoomServerRoomIDs: &RistrettoCachePartition[int64, string]{ // room NID -> room ID
+ RoomServerRoomIDs: &RistrettoCachePartition[types.RoomNID, string]{ // room NID -> room ID
cache: cache,
Prefix: roomIDsCache,
MaxAge: maxAge,
@@ -100,11 +100,10 @@ func NewRistrettoCache(maxCost config.DataUnit, maxAge time.Duration, enableProm
MaxAge: maxAge,
},
},
- RoomInfos: &RistrettoCachePartition[string, types.RoomInfo]{ // room ID -> room info
- cache: cache,
- Prefix: roomInfosCache,
- Mutable: true,
- MaxAge: maxAge,
+ RoomServerStateKeys: &RistrettoCachePartition[types.EventStateKeyNID, string]{ // event NID -> event state key
+ cache: cache,
+ Prefix: eventStateKeyCache,
+ MaxAge: maxAge,
},
FederationPDUs: &RistrettoCostedCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{ // queue NID -> PDU
&RistrettoCachePartition[int64, *gomatrixserverlib.HeaderedEvent]{
diff --git a/internal/httputil/http.go b/internal/httputil/http.go
index 4527e2b95..1e07ee33c 100644
--- a/internal/httputil/http.go
+++ b/internal/httputil/http.go
@@ -19,19 +19,21 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
"net/http"
"net/url"
"strings"
- "github.com/matrix-org/dendrite/userapi/api"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
-// PostJSON performs a POST request with JSON on an internal HTTP API
-func PostJSON(
+// PostJSON performs a POST request with JSON on an internal HTTP API.
+// The error will match the errtype if returned from the remote API, or
+// will be a different type if there was a problem reaching the API.
+func PostJSON[reqtype, restype any, errtype error](
ctx context.Context, span opentracing.Span, httpClient *http.Client,
- apiURL string, request, response interface{},
+ apiURL string, request *reqtype, response *restype,
) error {
jsonBytes, err := json.Marshal(request)
if err != nil {
@@ -69,17 +71,23 @@ func PostJSON(
if err != nil {
return err
}
- if res.StatusCode != http.StatusOK {
- var errorBody struct {
- Message string `json:"message"`
- }
- if _, ok := response.(*api.PerformKeyBackupResponse); ok { // TODO: remove this, once cross-boundary errors are a thing
- return nil
- }
- if msgerr := json.NewDecoder(res.Body).Decode(&errorBody); msgerr == nil {
- return fmt.Errorf("internal API: %d from %s: %s", res.StatusCode, apiURL, errorBody.Message)
- }
- return fmt.Errorf("internal API: %d from %s", res.StatusCode, apiURL)
+ var body []byte
+ body, err = io.ReadAll(res.Body)
+ if err != nil {
+ return err
}
- return json.NewDecoder(res.Body).Decode(response)
+ if res.StatusCode != http.StatusOK {
+ if len(body) == 0 {
+ return fmt.Errorf("HTTP %d from %s (no response body)", res.StatusCode, apiURL)
+ }
+ var reserr errtype
+ if err = json.Unmarshal(body, reserr); err != nil {
+ return fmt.Errorf("HTTP %d from %s", res.StatusCode, apiURL)
+ }
+ return reserr
+ }
+ if err = json.Unmarshal(body, response); err != nil {
+ return fmt.Errorf("json.Unmarshal: %w", err)
+ }
+ return nil
}
diff --git a/internal/httputil/internalapi.go b/internal/httputil/internalapi.go
new file mode 100644
index 000000000..385092d9c
--- /dev/null
+++ b/internal/httputil/internalapi.go
@@ -0,0 +1,93 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package httputil
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "reflect"
+
+ "github.com/matrix-org/util"
+ opentracing "github.com/opentracing/opentracing-go"
+)
+
+type InternalAPIError struct {
+ Type string
+ Message string
+}
+
+func (e InternalAPIError) Error() string {
+ return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message)
+}
+
+func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype, *restype) error) http.Handler {
+ return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse {
+ var request reqtype
+ var response restype
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := f(req.Context(), &request, &response); err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: &InternalAPIError{
+ Type: reflect.TypeOf(err).String(),
+ Message: fmt.Sprintf("%s", err),
+ },
+ }
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: &response,
+ }
+ })
+}
+
+func MakeInternalProxyAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype) (*restype, error)) http.Handler {
+ return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse {
+ var request reqtype
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ response, err := f(req.Context(), &request)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: err,
+ }
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: response,
+ }
+ })
+}
+
+func CallInternalRPCAPI[reqtype, restype any](name, url string, client *http.Client, ctx context.Context, request *reqtype, response *restype) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, name)
+ defer span.Finish()
+
+ return PostJSON[reqtype, restype, InternalAPIError](ctx, span, client, url, request, response)
+}
+
+func CallInternalProxyAPI[reqtype, restype any, errtype error](name, url string, client *http.Client, ctx context.Context, request *reqtype) (restype, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, name)
+ defer span.Finish()
+
+ var response restype
+ return response, PostJSON[reqtype, restype, errtype](ctx, span, client, url, request, &response)
+}
diff --git a/internal/log.go b/internal/log.go
index bba0ac6e6..a171555ab 100644
--- a/internal/log.go
+++ b/internal/log.go
@@ -27,9 +27,10 @@ import (
"github.com/matrix-org/util"
- "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dugong"
"github.com/sirupsen/logrus"
+
+ "github.com/matrix-org/dendrite/setup/config"
)
type utcFormatter struct {
@@ -145,7 +146,7 @@ func setupFileHook(hook config.LogrusHook, level logrus.Level, componentName str
})
}
-//CloseAndLogIfError Closes io.Closer and logs the error if any
+// CloseAndLogIfError Closes io.Closer and logs the error if any
func CloseAndLogIfError(ctx context.Context, closer io.Closer, message string) {
if closer == nil {
return
diff --git a/internal/log_unix.go b/internal/log_unix.go
index 1e1094f23..75332af73 100644
--- a/internal/log_unix.go
+++ b/internal/log_unix.go
@@ -18,7 +18,7 @@
package internal
import (
- "io/ioutil"
+ "io"
"log/syslog"
"github.com/MFAshby/stdemuxerhook"
@@ -63,7 +63,7 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) {
setupStdLogHook(logrus.InfoLevel)
}
// Hooks are now configured for stdout/err, so throw away the default logger output
- logrus.SetOutput(ioutil.Discard)
+ logrus.SetOutput(io.Discard)
}
func checkSyslogHookParams(params map[string]interface{}) {
diff --git a/internal/sqlutil/migrate.go b/internal/sqlutil/migrate.go
index 7518df3c8..18020a902 100644
--- a/internal/sqlutil/migrate.go
+++ b/internal/sqlutil/migrate.go
@@ -1,130 +1,142 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package sqlutil
import (
+ "context"
"database/sql"
"fmt"
- "runtime"
- "sort"
+ "sync"
+ "time"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/pressly/goose"
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/sirupsen/logrus"
)
-type Migrations struct {
- registeredGoMigrations map[int64]*goose.Migration
+const createDBMigrationsSQL = "" +
+ "CREATE TABLE IF NOT EXISTS db_migrations (" +
+ " version TEXT PRIMARY KEY NOT NULL," +
+ " time TEXT NOT NULL," +
+ " dendrite_version TEXT NOT NULL" +
+ ");"
+
+const insertVersionSQL = "" +
+ "INSERT INTO db_migrations (version, time, dendrite_version)" +
+ " VALUES ($1, $2, $3)"
+
+const selectDBMigrationsSQL = "SELECT version FROM db_migrations"
+
+// Migration defines a migration to be run.
+type Migration struct {
+ // Version is a simple description/name of this migration.
+ Version string
+ // Up defines the function to execute for an upgrade.
+ Up func(ctx context.Context, txn *sql.Tx) error
+ // Down defines the function to execute for a downgrade (not implemented yet).
+ Down func(ctx context.Context, txn *sql.Tx) error
}
-func NewMigrations() *Migrations {
- return &Migrations{
- registeredGoMigrations: make(map[int64]*goose.Migration),
+// Migrator
+type Migrator struct {
+ db *sql.DB
+ migrations []Migration
+ knownMigrations map[string]struct{}
+ mutex *sync.Mutex
+}
+
+// NewMigrator creates a new DB migrator.
+func NewMigrator(db *sql.DB) *Migrator {
+ return &Migrator{
+ db: db,
+ migrations: []Migration{},
+ knownMigrations: make(map[string]struct{}),
+ mutex: &sync.Mutex{},
}
}
-// Copy-pasted from goose directly to store migrations into a map we control
-
-// AddMigration adds a migration.
-func (m *Migrations) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
- _, filename, _, _ := runtime.Caller(1)
- m.AddNamedMigration(filename, up, down)
-}
-
-// AddNamedMigration : Add a named migration.
-func (m *Migrations) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
- v, _ := goose.NumericComponent(filename)
- migration := &goose.Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
-
- if existing, ok := m.registeredGoMigrations[v]; ok {
- panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
+// AddMigrations appends migrations to the list of migrations. Migrations are executed
+// in the order they are added to the list. De-duplicates migrations using their Version field.
+func (m *Migrator) AddMigrations(migrations ...Migration) {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+ for _, mig := range migrations {
+ if _, ok := m.knownMigrations[mig.Version]; !ok {
+ m.migrations = append(m.migrations, mig)
+ m.knownMigrations[mig.Version] = struct{}{}
+ }
}
-
- m.registeredGoMigrations[v] = migration
}
-// RunDeltas up to the latest version.
-func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error {
- maxVer := goose.MaxVersion
- minVer := int64(0)
- migrations, err := m.collect(minVer, maxVer)
+// Up executes all migrations in order they were added.
+func (m *Migrator) Up(ctx context.Context) error {
+ var (
+ err error
+ dendriteVersion = internal.VersionString()
+ )
+ // ensure there is a table for known migrations
+ executedMigrations, err := m.ExecutedMigrations(ctx)
if err != nil {
- return fmt.Errorf("runDeltas: Failed to collect migrations: %w", err)
+ return fmt.Errorf("unable to create/get migrations: %w", err)
}
- if props.ConnectionString.IsPostgres() {
- if err = goose.SetDialect("postgres"); err != nil {
- return err
- }
- } else if props.ConnectionString.IsSQLite() {
- if err = goose.SetDialect("sqlite3"); err != nil {
- return err
- }
- } else {
- return fmt.Errorf("unknown connection string: %s", props.ConnectionString)
- }
- for {
- current, err := goose.EnsureDBVersion(db)
- if err != nil {
- return fmt.Errorf("runDeltas: Failed to EnsureDBVersion: %w", err)
- }
- next, err := migrations.Next(current)
- if err != nil {
- if err == goose.ErrNoNextVersion {
- return nil
+ return WithTransaction(m.db, func(txn *sql.Tx) error {
+ for i := range m.migrations {
+ now := time.Now().UTC().Format(time.RFC3339)
+ migration := m.migrations[i]
+ logrus.Debugf("Executing database migration '%s'", migration.Version)
+ // Skip migration if it was already executed
+ if _, ok := executedMigrations[migration.Version]; ok {
+ continue
+ }
+ err = migration.Up(ctx, txn)
+ if err != nil {
+ return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
+ }
+ _, err = txn.ExecContext(ctx, insertVersionSQL,
+ migration.Version,
+ now,
+ dendriteVersion,
+ )
+ if err != nil {
+ return fmt.Errorf("unable to insert executed migrations: %w", err)
}
-
- return fmt.Errorf("runDeltas: Failed to load next migration to %+v : %w", next, err)
}
-
- if err = next.Up(db); err != nil {
- return fmt.Errorf("runDeltas: Failed run migration: %w", err)
- }
- }
+ return nil
+ })
}
-func (m *Migrations) collect(current, target int64) (goose.Migrations, error) {
- var migrations goose.Migrations
-
- // Go migrations registered via goose.AddMigration().
- for _, migration := range m.registeredGoMigrations {
- v, err := goose.NumericComponent(migration.Source)
- if err != nil {
- return nil, err
- }
- if versionFilter(v, current, target) {
- migrations = append(migrations, migration)
+// ExecutedMigrations returns a map with already executed migrations in addition to creating the
+// migrations table, if it doesn't exist.
+func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, error) {
+ result := make(map[string]struct{})
+ _, err := m.db.ExecContext(ctx, createDBMigrationsSQL)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create db_migrations: %w", err)
+ }
+ rows, err := m.db.QueryContext(ctx, selectDBMigrationsSQL)
+ if err != nil {
+ return nil, fmt.Errorf("unable to query db_migrations: %w", err)
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "ExecutedMigrations: rows.close() failed")
+ var version string
+ for rows.Next() {
+ if err = rows.Scan(&version); err != nil {
+ return nil, fmt.Errorf("unable to scan version: %w", err)
}
+ result[version] = struct{}{}
}
- migrations = sortAndConnectMigrations(migrations)
-
- return migrations, nil
-}
-
-func sortAndConnectMigrations(migrations goose.Migrations) goose.Migrations {
- sort.Sort(migrations)
-
- // now that we're sorted in the appropriate direction,
- // populate next and previous for each migration
- for i, m := range migrations {
- prev := int64(-1)
- if i > 0 {
- prev = migrations[i-1].Version
- migrations[i-1].Next = m.Version
- }
- migrations[i].Previous = prev
- }
-
- return migrations
-}
-
-func versionFilter(v, current, target int64) bool {
-
- if target > current {
- return v > current && v <= target
- }
-
- if target < current {
- return v <= current && v > target
- }
-
- return false
+ return result, rows.Err()
}
diff --git a/internal/sqlutil/migrate_test.go b/internal/sqlutil/migrate_test.go
new file mode 100644
index 000000000..d8bcae196
--- /dev/null
+++ b/internal/sqlutil/migrate_test.go
@@ -0,0 +1,112 @@
+package sqlutil_test
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/test"
+ _ "github.com/mattn/go-sqlite3"
+)
+
+var dummyMigrations = []sqlutil.Migration{
+ {
+ Version: "init",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ _, err := txn.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS dummy ( test TEXT );")
+ return err
+ },
+ },
+ {
+ Version: "v2",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ _, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
+ return err
+ },
+ },
+ {
+ Version: "v2", // duplicate, this migration will be skipped
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ _, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
+ return err
+ },
+ },
+ {
+ Version: "multiple execs",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ _, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test3 TEXT;")
+ if err != nil {
+ return err
+ }
+ _, err = txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test4 TEXT;")
+ return err
+ },
+ },
+}
+
+var failMigration = sqlutil.Migration{
+ Version: "iFail",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ return fmt.Errorf("iFail")
+ },
+ Down: nil,
+}
+
+func Test_migrations_Up(t *testing.T) {
+ withFail := append(dummyMigrations, failMigration)
+
+ tests := []struct {
+ name string
+ migrations []sqlutil.Migration
+ wantResult map[string]struct{}
+ wantErr bool
+ }{
+ {
+ name: "dummy migration",
+ migrations: dummyMigrations,
+ wantResult: map[string]struct{}{
+ "init": {},
+ "v2": {},
+ "multiple execs": {},
+ },
+ },
+ {
+ name: "with fail",
+ migrations: withFail,
+ wantErr: true,
+ },
+ }
+
+ ctx := context.Background()
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ conStr, close := test.PrepareDBConnectionString(t, dbType)
+ defer close()
+ driverName := "sqlite3"
+ if dbType == test.DBTypePostgres {
+ driverName = "postgres"
+ }
+ db, err := sql.Open(driverName, conStr)
+ if err != nil {
+ t.Errorf("unable to open database: %v", err)
+ }
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(tt.migrations...)
+ if err = m.Up(ctx); (err != nil) != tt.wantErr {
+ t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ result, err := m.ExecutedMigrations(ctx)
+ if err != nil {
+ t.Errorf("unable to get executed migrations: %v", err)
+ }
+ if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
+ t.Errorf("expected: %+v, got %v", tt.wantResult, result)
+ }
+ })
+ })
+ }
+}
diff --git a/internal/version.go b/internal/version.go
index 9568f08cb..38d0864e7 100644
--- a/internal/version.go
+++ b/internal/version.go
@@ -16,8 +16,8 @@ var build string
const (
VersionMajor = 0
- VersionMinor = 8
- VersionPatch = 9
+ VersionMinor = 9
+ VersionPatch = 1
VersionTag = "" // example: "rc1"
)
diff --git a/keyserver/api/api.go b/keyserver/api/api.go
index c0a1eedbb..9ba3988b9 100644
--- a/keyserver/api/api.go
+++ b/keyserver/api/api.go
@@ -38,32 +38,32 @@ type KeyInternalAPI interface {
// API functions required by the clientapi
type ClientKeyAPI interface {
- QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
- PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
- PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse)
- PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse)
+ QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
+ PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
+ PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
+ PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
// PerformClaimKeys claims one-time keys for use in pre-key messages
- PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
+ PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// API functions required by the userapi
type UserKeyAPI interface {
- PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
- PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse)
+ PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
+ PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error
}
// API functions required by the syncapi
type SyncKeyAPI interface {
- QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
- QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse)
+ QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
+ QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
}
type FederationKeyAPI interface {
- QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
- QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse)
- QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse)
- PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse)
- PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
+ QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
+ QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
+ QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
+ PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
+ PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// KeyError is returned if there was a problem performing/querying the server
diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go
index 08bbfedb8..99859dff6 100644
--- a/keyserver/internal/cross_signing.go
+++ b/keyserver/internal/cross_signing.go
@@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos
}
// nolint:gocyclo
-func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
+func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
// Find the keys to store.
byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
toStore := types.CrossSigningKeyMap{}
@@ -115,7 +115,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "Master key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
- return
+ return nil
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey
@@ -131,7 +131,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "Self-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
- return
+ return nil
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey
@@ -146,7 +146,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "User-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
- return
+ return nil
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey
@@ -161,7 +161,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "No keys were supplied in the request",
IsMissingParam: true,
}
- return
+ return nil
}
// We can't have a self-signing or user-signing key without a master
@@ -174,7 +174,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
}
- return
+ return nil
}
// If we still can't find a master key for the user then stop the upload.
@@ -185,7 +185,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
Err: "No master key was found",
IsMissingParam: true,
}
- return
+ return nil
}
}
@@ -212,7 +212,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
}
}
if !changed {
- return
+ return nil
}
// Store the keys.
@@ -220,7 +220,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
}
- return
+ return nil
}
// Now upload any signatures that were included with the keys.
@@ -238,7 +238,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
}
- return
+ return nil
}
}
}
@@ -255,17 +255,18 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
update.SelfSigningKey = &ssk
}
if update.MasterKey == nil && update.SelfSigningKey == nil {
- return
+ return nil
}
if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
- return
+ return nil
}
+ return nil
}
-func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) {
+func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
// Before we do anything, we need the master and self-signing keys for this user.
// Then we can verify the signatures make sense.
queryReq := &api.QueryKeysRequest{
@@ -276,7 +277,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
for userID := range req.Signatures {
queryReq.UserToDevices[userID] = []string{}
}
- a.QueryKeys(ctx, queryReq, queryRes)
+ _ = a.QueryKeys(ctx, queryReq, queryRes)
selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
@@ -322,14 +323,14 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processSelfSignatures: %s", err),
}
- return
+ return nil
}
if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processOtherSignatures: %s", err),
}
- return
+ return nil
}
// Finally, generate a notification that we updated the signatures.
@@ -345,9 +346,10 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
- return
+ return nil
}
}
+ return nil
}
func (a *KeyInternalAPI) processSelfSignatures(
@@ -520,7 +522,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
}
}
-func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) {
+func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
for targetUserID, forTargetUser := range req.TargetIDs {
keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil && err != sql.ErrNoRows {
@@ -559,7 +561,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
}
- return
+ return nil
}
for sourceUserID, forSourceUser := range sigMap {
@@ -581,4 +583,5 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
}
}
}
+ return nil
}
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
index acbcd5b8f..80efbec51 100644
--- a/keyserver/internal/device_list_update.go
+++ b/keyserver/internal/device_list_update.go
@@ -22,12 +22,13 @@ import (
"sync"
"time"
- fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
- "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
+
+ fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
+ "github.com/matrix-org/dendrite/keyserver/api"
)
var (
@@ -66,12 +67,14 @@ func init() {
// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where
// we have many many servers)
// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers.
+//
// The downsides are that:
// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free
// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts)
// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse
// than being stuck behind foo.bar
+//
// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
type DeviceListUpdater struct {
@@ -116,7 +119,7 @@ type DeviceListUpdaterDatabase interface {
}
type DeviceListUpdaterAPI interface {
- PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse)
+ PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error
}
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
@@ -418,7 +421,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
uploadReq.SelfSigningKey = *res.SelfSigningKey
}
}
- u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
+ _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
}
err = u.updateDeviceList(&res)
if err != nil {
diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go
index 0033a5086..0520a9e66 100644
--- a/keyserver/internal/device_list_update_test.go
+++ b/keyserver/internal/device_list_update_test.go
@@ -18,7 +18,7 @@ import (
"context"
"crypto/ed25519"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
"net/url"
"reflect"
@@ -27,8 +27,9 @@ import (
"testing"
"time"
- "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib"
+
+ "github.com/matrix-org/dendrite/keyserver/api"
)
var (
@@ -112,8 +113,8 @@ func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys
type mockDeviceListUpdaterAPI struct {
}
-func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
-
+func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
+ return nil
}
type roundTripper struct {
@@ -202,7 +203,7 @@ func TestUpdateNoPrevID(t *testing.T) {
}
return &http.Response{
StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`
+ Body: io.NopCloser(strings.NewReader(`
{
"user_id": "` + remoteUserID + `",
"stream_id": 5,
@@ -317,7 +318,7 @@ func TestDebounce(t *testing.T) {
// now send the response over federation
fedCh <- &http.Response{
StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`
+ Body: io.NopCloser(strings.NewReader(`
{
"user_id": "` + userID + `",
"stream_id": 5,
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index c146b2aa0..41b4d44a4 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -18,6 +18,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
"sync"
"time"
@@ -47,18 +48,20 @@ func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
a.UserAPI = i
}
-func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
+func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
if err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
+ return nil
}
res.Offset = latest
res.UserIDs = userIDs
+ return nil
}
-func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
res.KeyErrors = make(map[string]map[string]*api.KeyError)
if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res)
@@ -66,9 +69,10 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
}
+ return nil
}
-func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
+func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
// wrap request map in a top-level by-domain map
@@ -112,6 +116,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
}
+ return nil
}
func (a *KeyInternalAPI) claimRemoteKeys(
@@ -171,32 +176,34 @@ func (a *KeyInternalAPI) claimRemoteKeys(
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
}
-func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) {
+func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
}
}
+ return nil
}
-func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) {
+func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
}
- return
+ return nil
}
res.Count = *count
+ return nil
}
-func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
+func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
}
- return
+ return nil
}
maxStreamID := int64(0)
for _, m := range msgs {
@@ -214,10 +221,11 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
}
res.Devices = result
res.StreamID = maxStreamID
+ return nil
}
// nolint:gocyclo
-func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
+func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@@ -243,7 +251,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
}
- return
+ return nil
}
// pull out display names after we have the keys so we handle wildcards correctly
@@ -314,6 +322,11 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
for targetKeyID := range masterKey.Keys {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
if err != nil {
+ // Stop executing the function if the context was canceled/the deadline was exceeded,
+ // as we can't continue without a valid context.
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return nil
+ }
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
}
@@ -335,6 +348,11 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
for targetKeyID, key := range forUserID {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
if err != nil {
+ // Stop executing the function if the context was canceled/the deadline was exceeded,
+ // as we can't continue without a valid context.
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return nil
+ }
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
}
@@ -361,6 +379,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
}
}
+ return nil
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(
diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go
index dac61d1ea..7a7131145 100644
--- a/keyserver/inthttp/client.go
+++ b/keyserver/inthttp/client.go
@@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP APIs
@@ -68,168 +67,108 @@ func (h *httpKeyInternalAPI) PerformClaimKeys(
ctx context.Context,
request *api.PerformClaimKeysRequest,
response *api.PerformClaimKeysResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformClaimKeysPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformClaimKeys", h.apiURL+PerformClaimKeysPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpKeyInternalAPI) PerformDeleteKeys(
ctx context.Context,
request *api.PerformDeleteKeysRequest,
response *api.PerformDeleteKeysResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformClaimKeysPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformDeleteKeys", h.apiURL+PerformDeleteKeysPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpKeyInternalAPI) PerformUploadKeys(
ctx context.Context,
request *api.PerformUploadKeysRequest,
response *api.PerformUploadKeysResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadKeys")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformUploadKeysPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformUploadKeys", h.apiURL+PerformUploadKeysPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpKeyInternalAPI) QueryKeys(
ctx context.Context,
request *api.QueryKeysRequest,
response *api.QueryKeysResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryKeysPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryKeys", h.apiURL+QueryKeysPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpKeyInternalAPI) QueryOneTimeKeys(
ctx context.Context,
request *api.QueryOneTimeKeysRequest,
response *api.QueryOneTimeKeysResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryOneTimeKeysPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryOneTimeKeys", h.apiURL+QueryOneTimeKeysPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpKeyInternalAPI) QueryDeviceMessages(
ctx context.Context,
request *api.QueryDeviceMessagesRequest,
response *api.QueryDeviceMessagesResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceMessages")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryDeviceMessagesPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryDeviceMessages", h.apiURL+QueryDeviceMessagesPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpKeyInternalAPI) QueryKeyChanges(
ctx context.Context,
request *api.QueryKeyChangesRequest,
response *api.QueryKeyChangesResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryKeyChangesPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryKeyChanges", h.apiURL+QueryKeyChangesPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpKeyInternalAPI) PerformUploadDeviceKeys(
ctx context.Context,
request *api.PerformUploadDeviceKeysRequest,
response *api.PerformUploadDeviceKeysResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceKeys")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformUploadDeviceKeysPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformUploadDeviceKeys", h.apiURL+PerformUploadDeviceKeysPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures(
ctx context.Context,
request *api.PerformUploadDeviceSignaturesRequest,
response *api.PerformUploadDeviceSignaturesResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceSignatures")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformUploadDeviceSignaturesPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformUploadDeviceSignatures", h.apiURL+PerformUploadDeviceSignaturesPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpKeyInternalAPI) QuerySignatures(
ctx context.Context,
request *api.QuerySignaturesRequest,
response *api.QuerySignaturesResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySignatures")
- defer span.Finish()
-
- apiURL := h.apiURL + QuerySignaturesPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.KeyError{
- Err: err.Error(),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QuerySignatures", h.apiURL+QuerySignaturesPath,
+ h.httpClient, ctx, request, response,
+ )
}
diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go
index 5bf5976a8..4e5f9fba4 100644
--- a/keyserver/inthttp/server.go
+++ b/keyserver/inthttp/server.go
@@ -15,124 +15,59 @@
package inthttp
import (
- "encoding/json"
- "net/http"
-
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/util"
)
func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
- internalAPIMux.Handle(PerformClaimKeysPath,
- httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse {
- request := api.PerformClaimKeysRequest{}
- response := api.PerformClaimKeysResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.PerformClaimKeys(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ internalAPIMux.Handle(
+ PerformClaimKeysPath,
+ httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", s.PerformClaimKeys),
)
- internalAPIMux.Handle(PerformDeleteKeysPath,
- httputil.MakeInternalAPI("performDeleteKeys", func(req *http.Request) util.JSONResponse {
- request := api.PerformDeleteKeysRequest{}
- response := api.PerformDeleteKeysResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.PerformDeleteKeys(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ PerformDeleteKeysPath,
+ httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", s.PerformDeleteKeys),
)
- internalAPIMux.Handle(PerformUploadKeysPath,
- httputil.MakeInternalAPI("performUploadKeys", func(req *http.Request) util.JSONResponse {
- request := api.PerformUploadKeysRequest{}
- response := api.PerformUploadKeysResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.PerformUploadKeys(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ PerformUploadKeysPath,
+ httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", s.PerformUploadKeys),
)
- internalAPIMux.Handle(PerformUploadDeviceKeysPath,
- httputil.MakeInternalAPI("performUploadDeviceKeys", func(req *http.Request) util.JSONResponse {
- request := api.PerformUploadDeviceKeysRequest{}
- response := api.PerformUploadDeviceKeysResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.PerformUploadDeviceKeys(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ PerformUploadDeviceKeysPath,
+ httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", s.PerformUploadDeviceKeys),
)
- internalAPIMux.Handle(PerformUploadDeviceSignaturesPath,
- httputil.MakeInternalAPI("performUploadDeviceSignatures", func(req *http.Request) util.JSONResponse {
- request := api.PerformUploadDeviceSignaturesRequest{}
- response := api.PerformUploadDeviceSignaturesResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.PerformUploadDeviceSignatures(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ PerformUploadDeviceSignaturesPath,
+ httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", s.PerformUploadDeviceSignatures),
)
- internalAPIMux.Handle(QueryKeysPath,
- httputil.MakeInternalAPI("queryKeys", func(req *http.Request) util.JSONResponse {
- request := api.QueryKeysRequest{}
- response := api.QueryKeysResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.QueryKeys(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QueryKeysPath,
+ httputil.MakeInternalRPCAPI("KeyserverQueryKeys", s.QueryKeys),
)
- internalAPIMux.Handle(QueryOneTimeKeysPath,
- httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse {
- request := api.QueryOneTimeKeysRequest{}
- response := api.QueryOneTimeKeysResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.QueryOneTimeKeys(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QueryOneTimeKeysPath,
+ httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", s.QueryOneTimeKeys),
)
- internalAPIMux.Handle(QueryDeviceMessagesPath,
- httputil.MakeInternalAPI("queryDeviceMessages", func(req *http.Request) util.JSONResponse {
- request := api.QueryDeviceMessagesRequest{}
- response := api.QueryDeviceMessagesResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.QueryDeviceMessages(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QueryDeviceMessagesPath,
+ httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", s.QueryDeviceMessages),
)
- internalAPIMux.Handle(QueryKeyChangesPath,
- httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
- request := api.QueryKeyChangesRequest{}
- response := api.QueryKeyChangesResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.QueryKeyChanges(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QueryKeyChangesPath,
+ httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", s.QueryKeyChanges),
)
- internalAPIMux.Handle(QuerySignaturesPath,
- httputil.MakeInternalAPI("querySignatures", func(req *http.Request) util.JSONResponse {
- request := api.QuerySignaturesRequest{}
- response := api.QuerySignaturesResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.QuerySignatures(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QuerySignaturesPath,
+ httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures),
)
}
diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go
index b101e7ce5..8b2a865b9 100644
--- a/keyserver/storage/postgres/cross_signing_sigs_table.go
+++ b/keyserver/storage/postgres/cross_signing_sigs_table.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -66,6 +67,16 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro
if err != nil {
return nil, err
}
+
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "keyserver: cross signing signature indexes",
+ Up: deltas.UpFixCrossSigningSignatureIndexes,
+ })
+ if err = m.Up(context.Background()); err != nil {
+ return nil, err
+ }
+
return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go
index e5bcf08d1..0cfe9e791 100644
--- a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go
+++ b/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go
@@ -15,37 +15,27 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
-}
-
-func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
- m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
-}
-
-func UpRefactorKeyChanges(tx *sql.Tx) error {
+func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
// start counting from the last max offset, else 0. We need to do a count(*) first to see if there
// even are entries in this table to know if we can query for log_offset. Without the count then
// the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
// exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
var count int
- _ = tx.QueryRow(`SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
+ _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
if count > 0 {
var maxOffset int64
- _ = tx.QueryRow(`SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
- if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
+ _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
+ if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
}
}
- _, err := tx.Exec(`
+ _, err := tx.ExecContext(ctx, `
-- make the new table
DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
@@ -60,8 +50,8 @@ func UpRefactorKeyChanges(tx *sql.Tx) error {
return nil
}
-func DownRefactorKeyChanges(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
DROP TABLE IF EXISTS keyserver_key_changes;
diff --git a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go
index 12956e3b4..1a3d4fee9 100644
--- a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go
+++ b/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go
@@ -15,18 +15,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) {
- m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes)
-}
-
-func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
@@ -38,8 +33,8 @@ func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
return nil
}
-func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go
index f93a94bd3..004f15d82 100644
--- a/keyserver/storage/postgres/key_changes_table.go
+++ b/keyserver/storage/postgres/key_changes_table.go
@@ -18,7 +18,11 @@ import (
"context"
"database/sql"
+ "github.com/lib/pq"
+
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -55,7 +59,34 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
db: db,
}
_, err := db.Exec(keyChangesSchema)
- return s, err
+ if err != nil {
+ return s, err
+ }
+
+ // TODO: Remove when we are sure we are not having goose artefacts in the db
+ // This forces an error, which indicates the migration is already applied, since the
+ // column partition was removed from the table
+ var count int
+ err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count)
+ if err == nil {
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "keyserver: refactor key changes",
+ Up: deltas.UpRefactorKeyChanges,
+ })
+ return s, m.Up(context.Background())
+ } else {
+ switch e := err.(type) {
+ case *pq.Error:
+ // ignore undefined_column (42703) errors, as this is expected at this point
+ if e.Code != "42703" {
+ return nil, err
+ }
+ default:
+ return nil, err
+ }
+ }
+ return s, nil
}
func (s *keyChangesStatements) Prepare() (err error) {
diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go
index b8f70acf8..35e630559 100644
--- a/keyserver/storage/postgres/storage.go
+++ b/keyserver/storage/postgres/storage.go
@@ -16,7 +16,6 @@ package postgres
import (
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
@@ -53,12 +52,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil {
return nil, err
}
- m := sqlutil.NewMigrations()
- deltas.LoadRefactorKeyChanges(m)
- deltas.LoadFixCrossSigningSignatureIndexes(m)
- if err = m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
- }
if err = kc.Prepare(); err != nil {
return nil, err
}
diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go
index 36d562b8a..ea431151e 100644
--- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go
+++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -65,6 +66,15 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error)
if err != nil {
return nil, err
}
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "keyserver: cross signing signature indexes",
+ Up: deltas.UpFixCrossSigningSignatureIndexes,
+ })
+ if err = m.Up(context.Background()); err != nil {
+ return nil, err
+ }
+
return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go
index fbc548c38..cd0f19df9 100644
--- a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go
+++ b/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go
@@ -15,28 +15,18 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
-}
-
-func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
- m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
-}
-
-func UpRefactorKeyChanges(tx *sql.Tx) error {
+func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
// start counting from the last max offset, else 0.
var maxOffset int64
var userID string
- _ = tx.QueryRow(`SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
+ _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
- _, err := tx.Exec(`
+ _, err := tx.ExecContext(ctx, `
-- make the new table
DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
@@ -51,14 +41,14 @@ func UpRefactorKeyChanges(tx *sql.Tx) error {
}
// to start counting from maxOffset, insert a row with that value
if userID != "" {
- _, err = tx.Exec(`INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
+ _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
return err
}
return nil
}
-func DownRefactorKeyChanges(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
diff --git a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
index 230e39fef..d4e38dea5 100644
--- a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
+++ b/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
@@ -15,18 +15,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) {
- m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes)
-}
-
-func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
@@ -50,8 +45,8 @@ func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
return nil
}
-func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go
index e035e8c9c..217fa7a5d 100644
--- a/keyserver/storage/sqlite3/key_changes_table.go
+++ b/keyserver/storage/sqlite3/key_changes_table.go
@@ -19,6 +19,8 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -53,7 +55,24 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
db: db,
}
_, err := db.Exec(keyChangesSchema)
- return s, err
+ if err != nil {
+ return s, err
+ }
+ // TODO: Remove when we are sure we are not having goose artefacts in the db
+ // This forces an error, which indicates the migration is already applied, since the
+ // column partition was removed from the table
+ var count int
+ err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan(&count)
+ if err == nil {
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "keyserver: refactor key changes",
+ Up: deltas.UpRefactorKeyChanges,
+ })
+ return s, m.Up(context.Background())
+ }
+
+ return s, nil
}
func (s *keyChangesStatements) Prepare() (err error) {
diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go
index aeea9eac6..873fe3e24 100644
--- a/keyserver/storage/sqlite3/storage.go
+++ b/keyserver/storage/sqlite3/storage.go
@@ -17,7 +17,6 @@ package sqlite3
import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/shared"
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
)
@@ -52,12 +51,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
return nil, err
}
- m := sqlutil.NewMigrations()
- deltas.LoadRefactorKeyChanges(m)
- deltas.LoadFixCrossSigningSignatureIndexes(m)
- if err = m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
- }
if err = kc.Prepare(); err != nil {
return nil, err
}
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
index 44cfb5f2a..e7a2af7c2 100644
--- a/keyserver/storage/storage_test.go
+++ b/keyserver/storage/storage_test.go
@@ -3,6 +3,7 @@ package storage_test
import (
"context"
"reflect"
+ "sync"
"testing"
"github.com/matrix-org/dendrite/keyserver/api"
@@ -103,6 +104,9 @@ func TestKeyChangesUpperLimit(t *testing.T) {
})
}
+var dbLock sync.Mutex
+var deviceArray = []string{"AAA", "another_device"}
+
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
@@ -169,8 +173,11 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
+ dbLock.Lock()
+ defer dbLock.Unlock()
// Querying for device keys returns the latest stream IDs
- msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
+ msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
+
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
diff --git a/mediaapi/fileutils/fileutils.go b/mediaapi/fileutils/fileutils.go
index 754e4644b..2e719dc82 100644
--- a/mediaapi/fileutils/fileutils.go
+++ b/mediaapi/fileutils/fileutils.go
@@ -21,7 +21,6 @@ import (
"encoding/base64"
"fmt"
"io"
- "io/ioutil"
"os"
"path/filepath"
"strings"
@@ -180,7 +179,7 @@ func createTempDir(baseDirectory config.Path) (types.Path, error) {
if err := os.MkdirAll(baseTmpDir, 0770); err != nil {
return "", fmt.Errorf("failed to create base temp dir: %w", err)
}
- tmpDir, err := ioutil.TempDir(baseTmpDir, "")
+ tmpDir, err := os.MkdirTemp(baseTmpDir, "")
if err != nil {
return "", fmt.Errorf("failed to create temp dir: %w", err)
}
diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go
index 10b25a5cd..c9299b1fc 100644
--- a/mediaapi/routing/download.go
+++ b/mediaapi/routing/download.go
@@ -19,7 +19,6 @@ import (
"encoding/json"
"fmt"
"io"
- "io/ioutil"
"mime"
"net/http"
"net/url"
@@ -695,7 +694,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// We successfully parsed the Content-Length, so we'll return a limited
// reader that restricts us to reading only up to this size.
- reader = ioutil.NopCloser(io.LimitReader(*body, parsedLength))
+ reader = io.NopCloser(io.LimitReader(*body, parsedLength))
contentLength = parsedLength
} else {
// Content-Length header is missing. If we have a maximum file size
@@ -704,7 +703,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// ultimately it will get rewritten later when the temp file is written
// to disk.
if maxFileSizeBytes > 0 {
- reader = ioutil.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
+ reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
}
contentLength = 0
}
diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go
index 196908184..9dcfa955f 100644
--- a/mediaapi/routing/routing.go
+++ b/mediaapi/routing/routing.go
@@ -149,6 +149,9 @@ func makeDownloadAPI(
}
}
+ // Cache media for at least one day.
+ w.Header().Set("Cache-Control", "public,max-age=86400,s-maxage=86400")
+
Download(
w,
req,
diff --git a/roomserver/api/api.go b/roomserver/api/api.go
index 38baa617f..baf63aa31 100644
--- a/roomserver/api/api.go
+++ b/roomserver/api/api.go
@@ -40,7 +40,7 @@ type InputRoomEventsAPI interface {
ctx context.Context,
req *InputRoomEventsRequest,
res *InputRoomEventsResponse,
- )
+ ) error
}
// Query the latest events and state for a room from the room server.
@@ -97,6 +97,14 @@ type SyncRoomserverAPI interface {
req *PerformBackfillRequest,
res *PerformBackfillResponse,
) error
+
+ // QueryMembershipAtEvent queries the memberships at the given events.
+ // Returns a map from eventID to a slice of gomatrixserverlib.HeaderedEvent.
+ QueryMembershipAtEvent(
+ ctx context.Context,
+ request *QueryMembershipAtEventRequest,
+ response *QueryMembershipAtEventResponse,
+ ) error
}
type AppserviceRoomserverAPI interface {
@@ -139,15 +147,15 @@ type ClientRoomserverAPI interface {
GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error
// PerformRoomUpgrade upgrades a room to a newer version
- PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse)
- PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse)
- PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse)
- PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse)
- PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse)
+ PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error
+ PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error
+ PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
+ PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error
+ PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error
- PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse)
+ PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error
PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error
- PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse)
+ PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error
// PerformForget forgets a rooms history for a specific user
PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error
SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error
@@ -158,7 +166,7 @@ type UserRoomserverAPI interface {
QueryLatestEventsAndStateAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
- PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse)
+ PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error
}
type FederationRoomserverAPI interface {
diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go
index 211f320ff..8bef35379 100644
--- a/roomserver/api/api_trace.go
+++ b/roomserver/api/api_trace.go
@@ -35,9 +35,10 @@ func (t *RoomserverInternalAPITrace) InputRoomEvents(
ctx context.Context,
req *InputRoomEventsRequest,
res *InputRoomEventsResponse,
-) {
- t.Impl.InputRoomEvents(ctx, req, res)
- util.GetLogger(ctx).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res))
+) error {
+ err := t.Impl.InputRoomEvents(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *RoomserverInternalAPITrace) PerformInvite(
@@ -45,44 +46,49 @@ func (t *RoomserverInternalAPITrace) PerformInvite(
req *PerformInviteRequest,
res *PerformInviteResponse,
) error {
- util.GetLogger(ctx).Infof("PerformInvite req=%+v res=%+v", js(req), js(res))
- return t.Impl.PerformInvite(ctx, req, res)
+ err := t.Impl.PerformInvite(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("PerformInvite req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *RoomserverInternalAPITrace) PerformPeek(
ctx context.Context,
req *PerformPeekRequest,
res *PerformPeekResponse,
-) {
- t.Impl.PerformPeek(ctx, req, res)
- util.GetLogger(ctx).Infof("PerformPeek req=%+v res=%+v", js(req), js(res))
+) error {
+ err := t.Impl.PerformPeek(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("PerformPeek req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *RoomserverInternalAPITrace) PerformUnpeek(
ctx context.Context,
req *PerformUnpeekRequest,
res *PerformUnpeekResponse,
-) {
- t.Impl.PerformUnpeek(ctx, req, res)
- util.GetLogger(ctx).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res))
+) error {
+ err := t.Impl.PerformUnpeek(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *RoomserverInternalAPITrace) PerformRoomUpgrade(
ctx context.Context,
req *PerformRoomUpgradeRequest,
res *PerformRoomUpgradeResponse,
-) {
- t.Impl.PerformRoomUpgrade(ctx, req, res)
- util.GetLogger(ctx).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res))
+) error {
+ err := t.Impl.PerformRoomUpgrade(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *RoomserverInternalAPITrace) PerformJoin(
ctx context.Context,
req *PerformJoinRequest,
res *PerformJoinResponse,
-) {
- t.Impl.PerformJoin(ctx, req, res)
- util.GetLogger(ctx).Infof("PerformJoin req=%+v res=%+v", js(req), js(res))
+) error {
+ err := t.Impl.PerformJoin(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("PerformJoin req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *RoomserverInternalAPITrace) PerformLeave(
@@ -99,27 +105,30 @@ func (t *RoomserverInternalAPITrace) PerformPublish(
ctx context.Context,
req *PerformPublishRequest,
res *PerformPublishResponse,
-) {
- t.Impl.PerformPublish(ctx, req, res)
- util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res))
+) error {
+ err := t.Impl.PerformPublish(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("PerformPublish req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom(
ctx context.Context,
req *PerformAdminEvacuateRoomRequest,
res *PerformAdminEvacuateRoomResponse,
-) {
- t.Impl.PerformAdminEvacuateRoom(ctx, req, res)
- util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res))
+) error {
+ err := t.Impl.PerformAdminEvacuateRoom(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser(
ctx context.Context,
req *PerformAdminEvacuateUserRequest,
res *PerformAdminEvacuateUserResponse,
-) {
- t.Impl.PerformAdminEvacuateUser(ctx, req, res)
- util.GetLogger(ctx).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res))
+) error {
+ err := t.Impl.PerformAdminEvacuateUser(ctx, req, res)
+ util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *RoomserverInternalAPITrace) PerformInboundPeek(
@@ -128,7 +137,7 @@ func (t *RoomserverInternalAPITrace) PerformInboundPeek(
res *PerformInboundPeekResponse,
) error {
err := t.Impl.PerformInboundPeek(ctx, req, res)
- util.GetLogger(ctx).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res))
+ util.GetLogger(ctx).WithError(err).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res))
return err
}
@@ -373,6 +382,16 @@ func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed(
return err
}
+func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent(
+ ctx context.Context,
+ request *QueryMembershipAtEventRequest,
+ response *QueryMembershipAtEventResponse,
+) error {
+ err := t.Impl.QueryMembershipAtEvent(ctx, request, response)
+ util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response))
+ return err
+}
+
func js(thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {
diff --git a/roomserver/api/query.go b/roomserver/api/query.go
index f157a9025..c8e6f9dc6 100644
--- a/roomserver/api/query.go
+++ b/roomserver/api/query.go
@@ -427,3 +427,17 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
}
return nil
}
+
+// QueryMembershipAtEventRequest requests the membership events for a user
+// for a list of eventIDs.
+type QueryMembershipAtEventRequest struct {
+ RoomID string
+ EventIDs []string
+ UserID string
+}
+
+// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest.
+type QueryMembershipAtEventResponse struct {
+ // Memberships is a map from eventID to a list of events (if any).
+ Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
+}
diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go
index 344e9b079..bc2f28176 100644
--- a/roomserver/api/wrapper.go
+++ b/roomserver/api/wrapper.go
@@ -90,7 +90,9 @@ func SendInputRoomEvents(
Asynchronous: async,
}
var response InputRoomEventsResponse
- rsAPI.InputRoomEvents(ctx, &request, &response)
+ if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil {
+ return err
+ }
return response.Err()
}
diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go
index 0229f822f..648c50cf6 100644
--- a/roomserver/internal/helpers/auth.go
+++ b/roomserver/internal/helpers/auth.go
@@ -50,14 +50,14 @@ func CheckForSoftFail(
if err != nil {
return false, fmt.Errorf("db.RoomNID: %w", err)
}
- if roomInfo == nil || roomInfo.IsStub {
+ if roomInfo == nil || roomInfo.IsStub() {
return false, nil
}
// Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now.
roomState := state.NewStateResolution(db, roomInfo)
- authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
+ authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID())
if err != nil {
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
}
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index e67bbfcaa..6091f8ec2 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -12,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -21,14 +22,14 @@ import (
// Move these to a more sensible place.
func UpdateToInviteMembership(
- mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
+ mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
roomVersion gomatrixserverlib.RoomVersion,
) ([]api.OutputEvent, error) {
// We may have already sent the invite to the user, either because we are
// reprocessing this event, or because the we received this invite from a
// remote server via the federation invite API. In those cases we don't need
// to send the event.
- needsSending, err := mu.SetToInvite(add)
+ needsSending, retired, err := mu.Update(tables.MembershipStateInvite, add)
if err != nil {
return nil, err
}
@@ -38,13 +39,23 @@ func UpdateToInviteMembership(
// room event stream. This ensures that the consumers only have to
// consider a single stream of events when determining whether a user
// is invited, rather than having to combine multiple streams themselves.
- onie := api.OutputNewInviteEvent{
- Event: add.Headered(roomVersion),
- RoomVersion: roomVersion,
- }
updates = append(updates, api.OutputEvent{
- Type: api.OutputTypeNewInviteEvent,
- NewInviteEvent: &onie,
+ Type: api.OutputTypeNewInviteEvent,
+ NewInviteEvent: &api.OutputNewInviteEvent{
+ Event: add.Headered(roomVersion),
+ RoomVersion: roomVersion,
+ },
+ })
+ }
+ for _, eventID := range retired {
+ updates = append(updates, api.OutputEvent{
+ Type: api.OutputTypeRetireInviteEvent,
+ RetireInviteEvent: &api.OutputRetireInviteEvent{
+ EventID: eventID,
+ Membership: gomatrixserverlib.Join,
+ RetiredByEventID: add.EventID(),
+ TargetUserID: *add.StateKey(),
+ },
})
}
return updates, nil
@@ -197,6 +208,12 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
+func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
+ roomState := state.NewStateResolution(db, info)
+ // Fetch the state as it was when this event was fired
+ return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
+}
+
func LoadEvents(
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) {
@@ -225,13 +242,34 @@ func LoadStateEvents(
func CheckServerAllowedToSeeEvent(
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) {
+ stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
+ switch err {
+ case nil:
+ // No error, so continue normally
+ case tables.OptimisationNotSupportedError:
+ // The database engine didn't support this optimisation, so fall back to using
+ // the old and slow method
+ stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName)
+ if err != nil {
+ return false, err
+ }
+ default:
+ // Something else went wrong
+ return false, err
+ }
+ return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
+}
+
+func slowGetHistoryVisibilityState(
+ ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName,
+) ([]*gomatrixserverlib.Event, error) {
roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
- return false, nil
+ return nil, nil
}
- return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err)
+ return nil, fmt.Errorf("roomState.LoadStateAtEvent: %w", err)
}
// Extract all of the event state key NIDs from the room state.
@@ -243,7 +281,7 @@ func CheckServerAllowedToSeeEvent(
// Then request those state key NIDs from the database.
stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs)
if err != nil {
- return false, fmt.Errorf("db.EventStateKeys: %w", err)
+ return nil, fmt.Errorf("db.EventStateKeys: %w", err)
}
// If the event state key doesn't match the given servername
@@ -266,15 +304,10 @@ func CheckServerAllowedToSeeEvent(
}
if len(filteredEntries) == 0 {
- return false, nil
+ return nil, nil
}
- stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries)
- if err != nil {
- return false, err
- }
-
- return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
+ return LoadStateEvents(ctx, db, filteredEntries)
}
// TODO: Remove this when we have tests to assert correctness of this function
@@ -382,7 +415,7 @@ func QueryLatestEventsAndState(
if err != nil {
return err
}
- if roomInfo == nil || roomInfo.IsStub {
+ if roomInfo == nil || roomInfo.IsStub() {
response.RoomExists = false
return nil
}
diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go
index ecd4ecbb5..8d24f3c59 100644
--- a/roomserver/internal/input/input.go
+++ b/roomserver/internal/input/input.go
@@ -25,6 +25,11 @@ import (
"github.com/Arceliar/phony"
"github.com/getsentry/sentry-go"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/sirupsen/logrus"
+
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/roomserver/acls"
"github.com/matrix-org/dendrite/roomserver/api"
@@ -35,10 +40,6 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/nats-io/nats.go"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/sirupsen/logrus"
)
// Inputer is responsible for consuming from the roomserver input
@@ -60,9 +61,9 @@ import (
// per-room durable consumers will only progress through the stream
// as events are processed.
//
-// A BC * -> positions of each consumer (* = ephemeral)
-// ⌄ ⌄⌄ ⌄
-// ABAABCAABCAA -> newest (letter = subject for each message)
+// A BC * -> positions of each consumer (* = ephemeral)
+// ⌄ ⌄⌄ ⌄
+// ABAABCAABCAA -> newest (letter = subject for each message)
//
// In this example, A is still processing an event but has two
// pending events to process afterwards. Both B and C are caught
@@ -336,18 +337,18 @@ func (r *Inputer) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
-) {
+) error {
// Queue up the event into the roomserver.
replySub, err := r.queueInputRoomEvents(ctx, request)
if err != nil {
response.ErrMsg = err.Error()
- return
+ return nil
}
// If we aren't waiting for synchronous responses then we can
// give up here, there is nothing further to do.
if replySub == nil {
- return
+ return nil
}
// Otherwise, we'll want to sit and wait for the responses
@@ -359,12 +360,14 @@ func (r *Inputer) InputRoomEvents(
msg, err := replySub.NextMsgWithContext(ctx)
if err != nil {
response.ErrMsg = err.Error()
- return
+ return nil
}
if len(msg.Data) > 0 {
response.ErrMsg = string(msg.Data)
}
}
+
+ return nil
}
var roomserverInputBackpressure = prometheus.NewGaugeVec(
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 866670d7a..81541260c 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -299,7 +299,7 @@ func (r *Inputer) processRoomEvent(
// allowed at the time, and also to get the history visibility. We won't
// bother doing this if the event was already rejected as it just ends up
// burning CPU time.
- historyVisibility := gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive.
+ historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
if rejectionErr == nil && !isRejected && !softfail {
var err error
historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
@@ -429,7 +429,7 @@ func (r *Inputer) processStateBefore(
input *api.InputRoomEvent,
missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
- historyVisibility = gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive.
+ historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared.
event := input.Event.Unwrap()
isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("")
var stateBeforeEvent []*gomatrixserverlib.Event
diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go
index f7d15fdb5..d6efad79d 100644
--- a/roomserver/internal/input/input_latest_events.go
+++ b/roomserver/internal/input/input_latest_events.go
@@ -20,32 +20,32 @@ import (
"context"
"fmt"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/opentracing/opentracing-go"
+ "github.com/sirupsen/logrus"
+
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
- "github.com/opentracing/opentracing-go"
- "github.com/sirupsen/logrus"
)
// updateLatestEvents updates the list of latest events for this room in the database and writes the
// event to the output log.
// The latest events are the events that aren't referenced by another event in the database:
//
-// Time goes down the page. 1 is the m.room.create event (root).
-//
-// 1 After storing 1 the latest events are {1}
-// | After storing 2 the latest events are {2}
-// 2 After storing 3 the latest events are {3}
-// / \ After storing 4 the latest events are {3,4}
-// 3 4 After storing 5 the latest events are {5,4}
-// | | After storing 6 the latest events are {5,6}
-// 5 6 <--- latest After storing 7 the latest events are {6,7}
-// |
-// 7 <----- latest
+// Time goes down the page. 1 is the m.room.create event (root).
+// 1 After storing 1 the latest events are {1}
+// | After storing 2 the latest events are {2}
+// 2 After storing 3 the latest events are {3}
+// / \ After storing 4 the latest events are {3,4}
+// 3 4 After storing 5 the latest events are {5,4}
+// | | After storing 6 the latest events are {5,6}
+// 5 6 <--- latest After storing 7 the latest events are {6,7}
+// |
+// 7 <----- latest
//
// Can only be called once at a time
func (r *Inputer) updateLatestEvents(
diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go
index 3ce8791a3..28a54623b 100644
--- a/roomserver/internal/input/input_membership.go
+++ b/roomserver/internal/input/input_membership.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
@@ -60,20 +61,14 @@ func (r *Inputer) updateMemberships(
var updates []api.OutputEvent
for _, change := range changes {
- var ae *gomatrixserverlib.Event
- var re *gomatrixserverlib.Event
+ var ae *types.Event
+ var re *types.Event
targetUserNID := change.EventStateKeyNID
if change.removedEventNID != 0 {
- ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID)
- if ev != nil {
- re = ev.Event
- }
+ re, _ = helpers.EventMap(events).Lookup(change.removedEventNID)
}
if change.addedEventNID != 0 {
- ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID)
- if ev != nil {
- ae = ev.Event
- }
+ ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID)
}
if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
return nil, err
@@ -85,30 +80,27 @@ func (r *Inputer) updateMemberships(
func (r *Inputer) updateMembership(
updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID,
- remove, add *gomatrixserverlib.Event,
+ remove, add *types.Event,
updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
var err error
// Default the membership to Leave if no event was added or removed.
- oldMembership := gomatrixserverlib.Leave
newMembership := gomatrixserverlib.Leave
-
- if remove != nil {
- oldMembership, err = remove.Membership()
- if err != nil {
- return nil, err
- }
- }
if add != nil {
newMembership, err = add.Membership()
if err != nil {
return nil, err
}
}
- if oldMembership == newMembership && newMembership != gomatrixserverlib.Join {
- // If the membership is the same then nothing changed and we can return
- // immediately, unless it's a Join update (e.g. profile update).
- return updates, nil
+
+ var targetLocal bool
+ if add != nil {
+ targetLocal = r.isLocalTarget(add)
+ }
+
+ mu, err := updater.MembershipUpdater(targetUserNID, targetLocal)
+ if err != nil {
+ return nil, err
}
// In an ideal world, we shouldn't ever have "add" be nil and "remove" be
@@ -120,17 +112,10 @@ func (r *Inputer) updateMembership(
// after a state reset, often thinking that the user was still joined to
// the room even though the room state said otherwise, and this would prevent
// the user from being able to attempt to rejoin the room without modifying
- // the database. So instead what we'll do is we'll just update the membership
- // table to say that the user is "leave" and we'll use the old event to
- // avoid nil pointer exceptions on the code path that follows.
- if add == nil {
- add = remove
- newMembership = gomatrixserverlib.Leave
- }
-
- mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add))
- if err != nil {
- return nil, err
+ // the database. So instead we're going to remove the membership from the
+ // database altogether, so that it doesn't create future problems.
+ if add == nil && remove != nil {
+ return nil, mu.Delete()
}
switch newMembership {
@@ -149,7 +134,7 @@ func (r *Inputer) updateMembership(
}
}
-func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool {
+func (r *Inputer) isLocalTarget(event *types.Event) bool {
isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil {
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
@@ -159,81 +144,61 @@ func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool {
}
func updateToJoinMembership(
- mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
+ mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
- // If the user is already marked as being joined, we call SetToJoin to update
- // the event ID then we can return immediately. Retired is ignored as there
- // is no invite event to retire.
- if mu.IsJoin() {
- _, err := mu.SetToJoin(add.Sender(), add.EventID(), true)
- if err != nil {
- return nil, err
- }
- return updates, nil
- }
// When we mark a user as being joined we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream.
- retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false)
+ _, retired, err := mu.Update(tables.MembershipStateJoin, add)
if err != nil {
return nil, err
}
for _, eventID := range retired {
- orie := api.OutputRetireInviteEvent{
- EventID: eventID,
- Membership: gomatrixserverlib.Join,
- RetiredByEventID: add.EventID(),
- TargetUserID: *add.StateKey(),
- }
updates = append(updates, api.OutputEvent{
- Type: api.OutputTypeRetireInviteEvent,
- RetireInviteEvent: &orie,
+ Type: api.OutputTypeRetireInviteEvent,
+ RetireInviteEvent: &api.OutputRetireInviteEvent{
+ EventID: eventID,
+ Membership: gomatrixserverlib.Join,
+ RetiredByEventID: add.EventID(),
+ TargetUserID: *add.StateKey(),
+ },
})
}
return updates, nil
}
func updateToLeaveMembership(
- mu *shared.MembershipUpdater, add *gomatrixserverlib.Event,
+ mu *shared.MembershipUpdater, add *types.Event,
newMembership string, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
- // If the user is already neither joined, nor invited to the room then we
- // can return immediately.
- if mu.IsLeave() {
- return updates, nil
- }
// When we mark a user as having left we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream.
- retired, err := mu.SetToLeave(add.Sender(), add.EventID())
+ _, retired, err := mu.Update(tables.MembershipStateLeaveOrBan, add)
if err != nil {
return nil, err
}
for _, eventID := range retired {
- orie := api.OutputRetireInviteEvent{
- EventID: eventID,
- Membership: newMembership,
- RetiredByEventID: add.EventID(),
- TargetUserID: *add.StateKey(),
- }
updates = append(updates, api.OutputEvent{
- Type: api.OutputTypeRetireInviteEvent,
- RetireInviteEvent: &orie,
+ Type: api.OutputTypeRetireInviteEvent,
+ RetireInviteEvent: &api.OutputRetireInviteEvent{
+ EventID: eventID,
+ Membership: newMembership,
+ RetiredByEventID: add.EventID(),
+ TargetUserID: *add.StateKey(),
+ },
})
}
return updates, nil
}
func updateToKnockMembership(
- mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
+ mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
- if mu.IsLeave() {
- _, err := mu.SetToKnock(add)
- if err != nil {
- return nil, err
- }
+ if _, _, err := mu.Update(tables.MembershipStateKnock, add); err != nil {
+ return nil, err
}
return updates, nil
}
diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go
index edc153b7f..0dd2b64c0 100644
--- a/roomserver/internal/input/input_missing.go
+++ b/roomserver/internal/input/input_missing.go
@@ -326,8 +326,10 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion
return respState, true, nil
}
+ logrus.WithContext(ctx).Warnf("State for event %s not available locally, falling back to federation (via %d servers)", eventID, len(t.servers))
respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID)
if err != nil {
+ logrus.WithContext(ctx).WithError(err).Errorf("Failed to look up state before event %s", eventID)
return nil, false, fmt.Errorf("t.lookupStateBeforeEvent: %w", err)
}
@@ -339,6 +341,7 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion
case nil:
// do nothing
default:
+ logrus.WithContext(ctx).WithError(err).Errorf("Failed to look up event %s", eventID)
return nil, false, fmt.Errorf("t.lookupEvent: %w", err)
}
h = t.cacheAndReturn(h)
@@ -375,11 +378,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room
defer span.Finish()
var res parsedRespState
- roomInfo, err := t.db.RoomInfo(ctx, roomID)
- if err != nil {
- return nil
- }
- roomState := state.NewStateResolution(t.db, roomInfo)
+ roomState := state.NewStateResolution(t.db, t.roomInfo)
stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID})
if err != nil {
util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to get state after %s locally", eventID)
@@ -666,9 +665,22 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID)
// fetch the state event IDs at the time of the event
- stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID)
+ var stateIDs gomatrixserverlib.RespStateIDs
+ var err error
+ count := 0
+ totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5)
+ for _, serverName := range t.servers {
+ reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20)
+ stateIDs, err = t.federation.LookupStateIDs(reqctx, serverName, roomID, eventID)
+ reqcancel()
+ if err == nil {
+ break
+ }
+ count++
+ }
+ totalcancel()
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("t.federation.LookupStateIDs tried %d server(s), last error: %w", count, err)
}
// work out which auth/state IDs are missing
wantIDs := append(stateIDs.StateEventIDs, stateIDs.AuthEventIDs...)
@@ -754,9 +766,8 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
// Define what we'll do in order to fetch the missing event ID.
fetch := func(missingEventID string) {
- var h *gomatrixserverlib.Event
- h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false)
- switch err.(type) {
+ h, herr := t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false)
+ switch herr.(type) {
case verifySigError:
return
case nil:
@@ -765,7 +776,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": missingEventID,
"room_id": roomID,
- }).Warn("Failed to fetch missing event")
+ }).WithError(herr).Warn("Failed to fetch missing event")
return
}
haveEventsMutex.Lock()
diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go
index 1cb52966a..cb6b22d32 100644
--- a/roomserver/internal/perform/perform_admin.go
+++ b/roomserver/internal/perform/perform_admin.go
@@ -43,21 +43,21 @@ func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse,
-) {
+) error {
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
- return
+ return nil
}
- if roomInfo == nil || roomInfo.IsStub {
+ if roomInfo == nil || roomInfo.IsStub() {
res.Error = &api.PerformError{
Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room %s not found", req.RoomID),
}
- return
+ return nil
}
memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
@@ -66,7 +66,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err),
}
- return
+ return nil
}
memberEvents, err := r.DB.Events(ctx, memberNIDs)
@@ -75,7 +75,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.Events: %s", err),
}
- return
+ return nil
}
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
@@ -89,7 +89,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err),
}
- return
+ return nil
}
prevEvents := latestRes.LatestEvents
@@ -104,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Unmarshal: %s", err),
}
- return
+ return nil
}
memberContent.Membership = gomatrixserverlib.Leave
@@ -122,7 +122,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Marshal: %s", err),
}
- return
+ return nil
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
@@ -131,7 +131,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
- return
+ return nil
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes)
@@ -140,7 +140,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
- return
+ return nil
}
inputEvents = append(inputEvents, api.InputRoomEvent{
@@ -160,28 +160,28 @@ func (r *Admin) PerformAdminEvacuateRoom(
Asynchronous: true,
}
inputRes := &api.InputRoomEventsResponse{}
- r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
+ return r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
}
func (r *Admin) PerformAdminEvacuateUser(
ctx context.Context,
req *api.PerformAdminEvacuateUserRequest,
res *api.PerformAdminEvacuateUserResponse,
-) {
+) error {
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Malformed user ID: %s", err),
}
- return
+ return nil
}
if domain != r.Cfg.Matrix.ServerName {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: "Can only evacuate local users using this endpoint",
}
- return
+ return nil
}
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Join)
@@ -190,7 +190,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
- return
+ return nil
}
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Invite)
@@ -199,7 +199,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err),
}
- return
+ return nil
}
for _, roomID := range append(roomIDs, inviteRoomIDs...) {
@@ -214,7 +214,7 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err),
}
- return
+ return nil
}
if len(outputEvents) == 0 {
continue
@@ -224,9 +224,10 @@ func (r *Admin) PerformAdminEvacuateUser(
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err),
}
- return
+ return nil
}
res.Affected = append(res.Affected, roomID)
}
+ return nil
}
diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go
index 3f98fbc24..298ba04f6 100644
--- a/roomserver/internal/perform/perform_backfill.go
+++ b/roomserver/internal/perform/perform_backfill.go
@@ -19,6 +19,10 @@ import (
"fmt"
"github.com/getsentry/sentry-go"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
+
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
@@ -26,9 +30,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
- "github.com/sirupsen/logrus"
)
// the max number of servers to backfill from per request. If this is too low we may fail to backfill when
@@ -73,7 +74,7 @@ func (r *Backfiller) PerformBackfill(
if err != nil {
return err
}
- if info == nil || info.IsStub {
+ if info == nil || info.IsStub() {
return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID)
}
@@ -106,7 +107,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
if err != nil {
return err
}
- if info == nil || info.IsStub {
+ if info == nil || info.IsStub() {
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
}
requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities, r.PreferServers)
@@ -434,7 +435,7 @@ FindSuccessor:
logrus.WithError(err).WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room")
return nil
}
- if info == nil || info.IsStub {
+ if info == nil || info.IsStub() {
logrus.WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room, room is missing")
return nil
}
@@ -522,8 +523,9 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
}
// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if our server can read the room history
+//
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
-// pull all events and then filter by that table.
+// pull all events and then filter by that table.
func joinEventsFromHistoryVisibility(
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry,
thisServer gomatrixserverlib.ServerName) ([]types.Event, error) {
diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go
index 32c81e849..29decd363 100644
--- a/roomserver/internal/perform/perform_inbound_peek.go
+++ b/roomserver/internal/perform/perform_inbound_peek.go
@@ -50,7 +50,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil {
return err
}
- if info == nil || info.IsStub {
+ if info == nil || info.IsStub() {
return nil
}
response.RoomExists = true
diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index 644c954b6..483e78c3f 100644
--- a/roomserver/internal/perform/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -39,11 +39,13 @@ type Inviter struct {
Inputer *input.Inputer
}
+// nolint:gocyclo
func (r *Inviter) PerformInvite(
ctx context.Context,
req *api.PerformInviteRequest,
res *api.PerformInviteResponse,
) ([]api.OutputEvent, error) {
+ var outputUpdates []api.OutputEvent
event := req.Event
if event.StateKey() == nil {
return nil, fmt.Errorf("invite must be a state event")
@@ -66,6 +68,13 @@ func (r *Inviter) PerformInvite(
}
isTargetLocal := domain == r.Cfg.Matrix.ServerName
isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName
+ if !isOriginLocal && !isTargetLocal {
+ res.Error = &api.PerformError{
+ Code: api.PerformErrorBadRequest,
+ Msg: "The invite must be either from or to a local user",
+ }
+ return nil, nil
+ }
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
"inviter": event.Sender(),
@@ -97,6 +106,34 @@ func (r *Inviter) PerformInvite(
}
}
+ updateMembershipTableManually := func() ([]api.OutputEvent, error) {
+ var updater *shared.MembershipUpdater
+ if updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion); err != nil {
+ return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
+ }
+ outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{
+ EventNID: 0,
+ Event: event.Unwrap(),
+ }, outputUpdates, req.Event.RoomVersion)
+ if err != nil {
+ return nil, fmt.Errorf("updateToInviteMembership: %w", err)
+ }
+ if err = updater.Commit(); err != nil {
+ return nil, fmt.Errorf("updater.Commit: %w", err)
+ }
+ logger.Debugf("updated membership to invite and sending invite OutputEvent")
+ return outputUpdates, nil
+ }
+
+ if (info == nil || info.IsStub()) && !isOriginLocal && isTargetLocal {
+ // The invite came in over federation for a room that we don't know about
+ // yet. We need to handle this a bit differently to most invites because
+ // we don't know the room state, therefore the roomserver can't process
+ // an input event. Instead we will update the membership table with the
+ // new invite and generate an output event.
+ return updateMembershipTableManually()
+ }
+
var isAlreadyJoined bool
if info != nil {
_, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
@@ -140,31 +177,13 @@ func (r *Inviter) PerformInvite(
return nil, nil
}
+ // If the invite originated remotely then we can't send an
+ // InputRoomEvent for the invite as it will never pass auth checks
+ // due to lacking room state, but we still need to tell the client
+ // about the invite so we can accept it, hence we return an output
+ // event to send to the Sync API.
if !isOriginLocal {
- // The invite originated over federation. Process the membership
- // update, which will notify the sync API etc about the incoming
- // invite. We do NOT send an InputRoomEvent for the invite as it
- // will never pass auth checks due to lacking room state, but we
- // still need to tell the client about the invite so we can accept
- // it, hence we return an output event to send to the sync api.
- var updater *shared.MembershipUpdater
- updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
- if err != nil {
- return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
- }
-
- unwrapped := event.Unwrap()
- var outputUpdates []api.OutputEvent
- outputUpdates, err = helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion)
- if err != nil {
- return nil, fmt.Errorf("updateToInviteMembership: %w", err)
- }
-
- if err = updater.Commit(); err != nil {
- return nil, fmt.Errorf("updater.Commit: %w", err)
- }
- logger.Debugf("updated membership to invite and sending invite OutputEvent")
- return outputUpdates, nil
+ return updateMembershipTableManually()
}
// The invite originated locally. Therefore we have a responsibility to
@@ -222,19 +241,20 @@ func (r *Inviter) PerformInvite(
},
}
inputRes := &api.InputRoomEventsResponse{}
- r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
+ if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil {
+ return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
+ }
if err = inputRes.Err(); err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),
Code: api.PerformErrorNotAllowed,
}
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
- return nil, nil
}
// Don't notify the sync api of this event in the same way as a federated invite so the invitee
// gets the invite, as the roomserver will do this when it processes the m.room.member invite.
- return nil, nil
+ return outputUpdates, nil
}
func buildInviteStrippedState(
@@ -258,7 +278,7 @@ func buildInviteStrippedState(
}
roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
- ctx, info.StateSnapshotNID, stateWanted,
+ ctx, info.StateSnapshotNID(), stateWanted,
)
if err != nil {
return nil, err
diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go
index c9e839198..43be54beb 100644
--- a/roomserver/internal/perform/perform_join.go
+++ b/roomserver/internal/perform/perform_join.go
@@ -52,7 +52,7 @@ func (r *Joiner) PerformJoin(
ctx context.Context,
req *rsAPI.PerformJoinRequest,
res *rsAPI.PerformJoinResponse,
-) {
+) error {
logger := logrus.WithContext(ctx).WithFields(logrus.Fields{
"room_id": req.RoomIDOrAlias,
"user_id": req.UserID,
@@ -71,11 +71,12 @@ func (r *Joiner) PerformJoin(
Msg: err.Error(),
}
}
- return
+ return nil
}
logger.Info("User joined room successfully")
res.RoomID = roomID
res.JoinedVia = joinedVia
+ return nil
}
func (r *Joiner) performJoin(
@@ -268,21 +269,19 @@ func (r *Joiner) performJoinRoomByID(
case nil:
// The room join is local. Send the new join event into the
// roomserver. First of all check that the user isn't already
- // a member of the room.
- alreadyJoined := false
- for _, se := range buildRes.StateEvents {
- if !se.StateKeyEquals(userID) {
- continue
- }
- if membership, merr := se.Membership(); merr == nil {
- alreadyJoined = (membership == gomatrixserverlib.Join)
- break
- }
+ // a member of the room. This is best-effort (as in we won't
+ // fail if we can't find the existing membership) because there
+ // is really no harm in just sending another membership event.
+ membershipReq := &api.QueryMembershipForUserRequest{
+ RoomID: req.RoomIDOrAlias,
+ UserID: userID,
}
+ membershipRes := &api.QueryMembershipForUserResponse{}
+ _ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes)
// If we haven't already joined the room then send an event
// into the room changing our membership status.
- if !alreadyJoined {
+ if !membershipRes.RoomExists || !membershipRes.IsInRoom {
inputReq := rsAPI.InputRoomEventsRequest{
InputRoomEvents: []rsAPI.InputRoomEvent{
{
@@ -293,7 +292,12 @@ func (r *Joiner) performJoinRoomByID(
},
}
inputRes := rsAPI.InputRoomEventsResponse{}
- r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
+ if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
+ return "", "", &rsAPI.PerformError{
+ Code: rsAPI.PerformErrorNoOperation,
+ Msg: fmt.Sprintf("InputRoomEvents failed: %s", err),
+ }
+ }
if err = inputRes.Err(); err != nil {
return "", "", &rsAPI.PerformError{
Code: rsAPI.PerformErrorNotAllowed,
diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go
index c5b62ac00..036404cd2 100644
--- a/roomserver/internal/perform/perform_leave.go
+++ b/roomserver/internal/perform/perform_leave.go
@@ -186,7 +186,9 @@ func (r *Leaver) performLeaveRoomByID(
},
}
inputRes := api.InputRoomEventsResponse{}
- r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
+ if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
+ return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err)
+ }
if err = inputRes.Err(); err != nil {
return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
}
@@ -228,14 +230,14 @@ func (r *Leaver) performFederatedRejectInvite(
util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event")
}
if updater != nil {
- if _, err = updater.SetToLeave(req.UserID, eventID); err != nil {
- util.GetLogger(ctx).WithError(err).Errorf("failed to set membership to leave, still retiring invite event")
+ if err = updater.Delete(); err != nil {
+ util.GetLogger(ctx).WithError(err).Errorf("failed to delete membership, still retiring invite event")
if err = updater.Rollback(); err != nil {
- util.GetLogger(ctx).WithError(err).Errorf("failed to rollback membership leave, still retiring invite event")
+ util.GetLogger(ctx).WithError(err).Errorf("failed to rollback deleting membership, still retiring invite event")
}
} else {
if err = updater.Commit(); err != nil {
- util.GetLogger(ctx).WithError(err).Errorf("failed to commit membership update, still retiring invite event")
+ util.GetLogger(ctx).WithError(err).Errorf("failed to commit deleting membership, still retiring invite event")
}
}
}
diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go
index 5560916b2..74d87a5b4 100644
--- a/roomserver/internal/perform/perform_peek.go
+++ b/roomserver/internal/perform/perform_peek.go
@@ -44,7 +44,7 @@ func (r *Peeker) PerformPeek(
ctx context.Context,
req *api.PerformPeekRequest,
res *api.PerformPeekResponse,
-) {
+) error {
roomID, err := r.performPeek(ctx, req)
if err != nil {
perr, ok := err.(*api.PerformError)
@@ -57,6 +57,7 @@ func (r *Peeker) PerformPeek(
}
}
res.RoomID = roomID
+ return nil
}
func (r *Peeker) performPeek(
diff --git a/roomserver/internal/perform/perform_publish.go b/roomserver/internal/perform/perform_publish.go
index 6ff42ac1a..1631fc657 100644
--- a/roomserver/internal/perform/perform_publish.go
+++ b/roomserver/internal/perform/perform_publish.go
@@ -29,11 +29,12 @@ func (r *Publisher) PerformPublish(
ctx context.Context,
req *api.PerformPublishRequest,
res *api.PerformPublishResponse,
-) {
+) error {
err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public")
if err != nil {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
+ return nil
}
diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go
index 1fe8d5a0f..49e9067c9 100644
--- a/roomserver/internal/perform/perform_unpeek.go
+++ b/roomserver/internal/perform/perform_unpeek.go
@@ -41,7 +41,7 @@ func (r *Unpeeker) PerformUnpeek(
ctx context.Context,
req *api.PerformUnpeekRequest,
res *api.PerformUnpeekResponse,
-) {
+) error {
if err := r.performUnpeek(ctx, req); err != nil {
perr, ok := err.(*api.PerformError)
if ok {
@@ -52,6 +52,7 @@ func (r *Unpeeker) PerformUnpeek(
}
}
}
+ return nil
}
func (r *Unpeeker) performUnpeek(
diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go
index 393d7dd14..d6dc9708c 100644
--- a/roomserver/internal/perform/perform_upgrade.go
+++ b/roomserver/internal/perform/perform_upgrade.go
@@ -45,12 +45,13 @@ func (r *Upgrader) PerformRoomUpgrade(
ctx context.Context,
req *api.PerformRoomUpgradeRequest,
res *api.PerformRoomUpgradeResponse,
-) {
+) error {
res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req)
if res.Error != nil {
res.NewRoomID = ""
logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed")
}
+ return nil
}
func (r *Upgrader) performRoomUpgrade(
@@ -286,22 +287,24 @@ func publishNewRoomAndUnpublishOldRoom(
) {
// expose this room in the published room list
var pubNewRoomRes api.PerformPublishResponse
- URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
+ if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: newRoomID,
Visibility: "public",
- }, &pubNewRoomRes)
- if pubNewRoomRes.Error != nil {
+ }, &pubNewRoomRes); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
+ } else if pubNewRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public")
}
var unpubOldRoomRes api.PerformPublishResponse
// remove the old room from the published room list
- URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
+ if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{
RoomID: oldRoomID,
Visibility: "private",
- }, &unpubOldRoomRes)
- if unpubOldRoomRes.Error != nil {
+ }, &unpubOldRoomRes); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("failed to reach internal API")
+ } else if unpubOldRoomRes.Error != nil {
// treat as non-fatal since the room is already made by this point
util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private")
}
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index da1b32530..c41e1ea67 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -16,6 +16,7 @@ package query
import (
"context"
+ "database/sql"
"encoding/json"
"errors"
"fmt"
@@ -60,7 +61,7 @@ func (r *Queryer) QueryStateAfterEvents(
if err != nil {
return err
}
- if info == nil || info.IsStub {
+ if info == nil || info.IsStub() {
return nil
}
@@ -203,6 +204,54 @@ func (r *Queryer) QueryMembershipForUser(
return err
}
+func (r *Queryer) QueryMembershipAtEvent(
+ ctx context.Context,
+ request *api.QueryMembershipAtEventRequest,
+ response *api.QueryMembershipAtEventResponse,
+) error {
+ response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent)
+ info, err := r.DB.RoomInfo(ctx, request.RoomID)
+ if err != nil {
+ return fmt.Errorf("unable to get roomInfo: %w", err)
+ }
+ if info == nil {
+ return fmt.Errorf("no roomInfo found")
+ }
+
+ // get the users stateKeyNID
+ stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID})
+ if err != nil {
+ return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err)
+ }
+ if _, ok := stateKeyNIDs[request.UserID]; !ok {
+ return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
+ }
+
+ stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID])
+ if err != nil {
+ return fmt.Errorf("unable to get state before event: %w", err)
+ }
+
+ for _, eventID := range request.EventIDs {
+ stateEntry := stateEntries[eventID]
+ memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false)
+ if err != nil {
+ return fmt.Errorf("unable to get memberships at state: %w", err)
+ }
+ res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
+
+ for i := range memberships {
+ ev := memberships[i]
+ if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) {
+ res = append(res, ev.Headered(info.RoomVersion))
+ }
+ }
+ response.Memberships[eventID] = res
+ }
+
+ return nil
+}
+
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipsForRoom(
ctx context.Context,
@@ -225,6 +274,9 @@ func (r *Queryer) QueryMembershipsForRoom(
var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly)
if err != nil {
+ if err == sql.ErrNoRows {
+ return nil
+ }
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
}
events, err = r.DB.Events(ctx, eventNIDs)
@@ -260,6 +312,9 @@ func (r *Queryer) QueryMembershipsForRoom(
var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, false)
if err != nil {
+ if err == sql.ErrNoRows {
+ return nil
+ }
return err
}
@@ -295,7 +350,7 @@ func (r *Queryer) QueryServerJoinedToRoom(
if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err)
}
- if info == nil || info.IsStub {
+ if info == nil || info.IsStub() {
return nil
}
response.RoomExists = true
@@ -344,8 +399,8 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
if err != nil {
return err
}
- if info == nil {
- return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
+ if info == nil || info.IsStub() {
+ return nil
}
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom,
@@ -383,7 +438,7 @@ func (r *Queryer) QueryMissingEvents(
if err != nil {
return err
}
- if info == nil || info.IsStub {
+ if info == nil || info.IsStub() {
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
}
@@ -422,7 +477,7 @@ func (r *Queryer) QueryStateAndAuthChain(
if err != nil {
return err
}
- if info == nil || info.IsStub {
+ if info == nil || info.IsStub() {
return nil
}
response.RoomExists = true
@@ -767,7 +822,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
if err != nil {
return fmt.Errorf("r.DB.RoomInfo: %w", err)
}
- if roomInfo == nil || roomInfo.IsStub {
+ if roomInfo == nil || roomInfo.IsStub() {
return nil // fmt.Errorf("room %q doesn't exist or is stub room", req.RoomID)
}
// If the room version doesn't allow restricted joins then don't
@@ -830,7 +885,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// See if the room exists. If it doesn't exist or if it's a stub
// room entry then we can't check memberships.
targetRoomInfo, err := r.DB.RoomInfo(ctx, rule.RoomID)
- if err != nil || targetRoomInfo == nil || targetRoomInfo.IsStub {
+ if err != nil || targetRoomInfo == nil || targetRoomInfo.IsStub() {
res.Resident = false
continue
}
diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go
index 2fa8afc49..a1dfc6aac 100644
--- a/roomserver/inthttp/client.go
+++ b/roomserver/inthttp/client.go
@@ -3,18 +3,16 @@ package inthttp
import (
"context"
"errors"
- "fmt"
"net/http"
+ "github.com/matrix-org/gomatrixserverlib"
+
asAPI "github.com/matrix-org/dendrite/appservice/api"
fsInputAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
-
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/opentracing/opentracing-go"
)
const (
@@ -63,6 +61,7 @@ const (
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
+ RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
)
type httpRoomserverInternalAPI struct {
@@ -106,11 +105,10 @@ func (h *httpRoomserverInternalAPI) SetRoomAlias(
request *api.SetRoomAliasRequest,
response *api.SetRoomAliasResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverSetRoomAliasPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "SetRoomAlias", h.roomserverURL+RoomserverSetRoomAliasPath,
+ h.httpClient, ctx, request, response,
+ )
}
// GetRoomIDForAlias implements RoomserverAliasAPI
@@ -119,11 +117,10 @@ func (h *httpRoomserverInternalAPI) GetRoomIDForAlias(
request *api.GetRoomIDForAliasRequest,
response *api.GetRoomIDForAliasResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "GetRoomIDForAlias", h.roomserverURL+RoomserverGetRoomIDForAliasPath,
+ h.httpClient, ctx, request, response,
+ )
}
// GetAliasesForRoomID implements RoomserverAliasAPI
@@ -132,11 +129,10 @@ func (h *httpRoomserverInternalAPI) GetAliasesForRoomID(
request *api.GetAliasesForRoomIDRequest,
response *api.GetAliasesForRoomIDResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "GetAliasesForRoomID", h.roomserverURL+RoomserverGetAliasesForRoomIDPath,
+ h.httpClient, ctx, request, response,
+ )
}
// RemoveRoomAlias implements RoomserverAliasAPI
@@ -145,11 +141,10 @@ func (h *httpRoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "RemoveRoomAlias", h.roomserverURL+RoomserverRemoveRoomAliasPath,
+ h.httpClient, ctx, request, response,
+ )
}
// InputRoomEvents implements RoomserverInputAPI
@@ -157,15 +152,14 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverInputRoomEventsPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
+) error {
+ if err := httputil.CallInternalRPCAPI(
+ "InputRoomEvents", h.roomserverURL+RoomserverInputRoomEventsPath,
+ h.httpClient, ctx, request, response,
+ ); err != nil {
response.ErrMsg = err.Error()
}
+ return nil
}
func (h *httpRoomserverInternalAPI) PerformInvite(
@@ -173,45 +167,32 @@ func (h *httpRoomserverInternalAPI) PerformInvite(
request *api.PerformInviteRequest,
response *api.PerformInviteResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInvite")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformInvitePath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformInvite", h.roomserverURL+RoomserverPerformInvitePath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) PerformJoin(
ctx context.Context,
request *api.PerformJoinRequest,
response *api.PerformJoinResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoin")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformJoinPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.PerformError{
- Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformJoin", h.roomserverURL+RoomserverPerformJoinPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) PerformPeek(
ctx context.Context,
request *api.PerformPeekRequest,
response *api.PerformPeekResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPeek")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformPeekPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.PerformError{
- Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformPeek", h.roomserverURL+RoomserverPerformPeekPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) PerformInboundPeek(
@@ -219,45 +200,32 @@ func (h *httpRoomserverInternalAPI) PerformInboundPeek(
request *api.PerformInboundPeekRequest,
response *api.PerformInboundPeekResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInboundPeek")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformInboundPeekPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformInboundPeek", h.roomserverURL+RoomserverPerformInboundPeekPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) PerformUnpeek(
ctx context.Context,
request *api.PerformUnpeekRequest,
response *api.PerformUnpeekResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUnpeek")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformUnpeekPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.PerformError{
- Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformUnpeek", h.roomserverURL+RoomserverPerformUnpeekPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) PerformRoomUpgrade(
ctx context.Context,
request *api.PerformRoomUpgradeRequest,
response *api.PerformRoomUpgradeResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformRoomUpgrade")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformRoomUpgradePath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
- if err != nil {
- response.Error = &api.PerformError{
- Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
- }
- }
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformRoomUpgrade", h.roomserverURL+RoomserverPerformRoomUpgradePath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) PerformLeave(
@@ -265,62 +233,43 @@ func (h *httpRoomserverInternalAPI) PerformLeave(
request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeave")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformLeavePath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformLeave", h.roomserverURL+RoomserverPerformLeavePath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) PerformPublish(
ctx context.Context,
- req *api.PerformPublishRequest,
- res *api.PerformPublishResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPublish")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformPublishPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
- if err != nil {
- res.Error = &api.PerformError{
- Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
- }
- }
+ request *api.PerformPublishRequest,
+ response *api.PerformPublishResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformPublish", h.roomserverURL+RoomserverPerformPublishPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom(
ctx context.Context,
- req *api.PerformAdminEvacuateRoomRequest,
- res *api.PerformAdminEvacuateRoomResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
- if err != nil {
- res.Error = &api.PerformError{
- Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
- }
- }
+ request *api.PerformAdminEvacuateRoomRequest,
+ response *api.PerformAdminEvacuateRoomResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformAdminEvacuateRoom", h.roomserverURL+RoomserverPerformAdminEvacuateRoomPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser(
ctx context.Context,
- req *api.PerformAdminEvacuateUserRequest,
- res *api.PerformAdminEvacuateUserResponse,
-) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateUser")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateUserPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
- if err != nil {
- res.Error = &api.PerformError{
- Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
- }
- }
+ request *api.PerformAdminEvacuateUserRequest,
+ response *api.PerformAdminEvacuateUserResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformAdminEvacuateUser", h.roomserverURL+RoomserverPerformAdminEvacuateUserPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryLatestEventsAndState implements RoomserverQueryAPI
@@ -329,11 +278,10 @@ func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryLatestEventsAndState", h.roomserverURL+RoomserverQueryLatestEventsAndStatePath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryStateAfterEvents implements RoomserverQueryAPI
@@ -342,11 +290,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents(
request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryStateAfterEvents", h.roomserverURL+RoomserverQueryStateAfterEventsPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryEventsByID implements RoomserverQueryAPI
@@ -355,11 +302,10 @@ func (h *httpRoomserverInternalAPI) QueryEventsByID(
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryEventsByID", h.roomserverURL+RoomserverQueryEventsByIDPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) QueryPublishedRooms(
@@ -367,11 +313,10 @@ func (h *httpRoomserverInternalAPI) QueryPublishedRooms(
request *api.QueryPublishedRoomsRequest,
response *api.QueryPublishedRoomsResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublishedRooms")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryPublishedRoomsPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryPublishedRooms", h.roomserverURL+RoomserverQueryPublishedRoomsPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryMembershipForUser implements RoomserverQueryAPI
@@ -380,11 +325,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipForUser(
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryMembershipForUser", h.roomserverURL+RoomserverQueryMembershipForUserPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryMembershipsForRoom implements RoomserverQueryAPI
@@ -393,11 +337,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom(
request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryMembershipsForRoom", h.roomserverURL+RoomserverQueryMembershipsForRoomPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryMembershipsForRoom implements RoomserverQueryAPI
@@ -406,11 +349,10 @@ func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom(
request *api.QueryServerJoinedToRoomRequest,
response *api.QueryServerJoinedToRoomResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerJoinedToRoom")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryServerJoinedToRoomPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryServerJoinedToRoom", h.roomserverURL+RoomserverQueryServerJoinedToRoomPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
@@ -419,11 +361,10 @@ func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent(
request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse,
) (err error) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryServerAllowedToSeeEvent", h.roomserverURL+RoomserverQueryServerAllowedToSeeEventPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryMissingEvents implements RoomServerQueryAPI
@@ -432,11 +373,10 @@ func (h *httpRoomserverInternalAPI) QueryMissingEvents(
request *api.QueryMissingEventsRequest,
response *api.QueryMissingEventsResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryMissingEvents", h.roomserverURL+RoomserverQueryMissingEventsPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryStateAndAuthChain implements RoomserverQueryAPI
@@ -445,11 +385,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain(
request *api.QueryStateAndAuthChainRequest,
response *api.QueryStateAndAuthChainResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryStateAndAuthChain", h.roomserverURL+RoomserverQueryStateAndAuthChainPath,
+ h.httpClient, ctx, request, response,
+ )
}
// PerformBackfill implements RoomServerQueryAPI
@@ -458,11 +397,10 @@ func (h *httpRoomserverInternalAPI) PerformBackfill(
request *api.PerformBackfillRequest,
response *api.PerformBackfillResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBackfill")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformBackfillPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformBackfill", h.roomserverURL+RoomserverPerformBackfillPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryRoomVersionCapabilities implements RoomServerQueryAPI
@@ -471,11 +409,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities(
request *api.QueryRoomVersionCapabilitiesRequest,
response *api.QueryRoomVersionCapabilitiesResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionCapabilities")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryRoomVersionCapabilities", h.roomserverURL+RoomserverQueryRoomVersionCapabilitiesPath,
+ h.httpClient, ctx, request, response,
+ )
}
// QueryRoomVersionForRoom implements RoomServerQueryAPI
@@ -488,12 +425,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom(
response.RoomVersion = roomVersion
return nil
}
-
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionForRoom")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ err := httputil.CallInternalRPCAPI(
+ "QueryRoomVersionForRoom", h.roomserverURL+RoomserverQueryRoomVersionForRoomPath,
+ h.httpClient, ctx, request, response,
+ )
if err == nil {
h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
}
@@ -505,11 +440,10 @@ func (h *httpRoomserverInternalAPI) QueryCurrentState(
request *api.QueryCurrentStateRequest,
response *api.QueryCurrentStateResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryCurrentState", h.roomserverURL+RoomserverQueryCurrentStatePath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
@@ -517,11 +451,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
request *api.QueryRoomsForUserRequest,
response *api.QueryRoomsForUserResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryRoomsForUser", h.roomserverURL+RoomserverQueryRoomsForUserPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
@@ -529,68 +462,82 @@ func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
request *api.QueryBulkStateContentRequest,
response *api.QueryBulkStateContentResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryBulkStateContent", h.roomserverURL+RoomserverQueryBulkStateContentPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) QuerySharedUsers(
- ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
+ ctx context.Context,
+ request *api.QuerySharedUsersRequest,
+ response *api.QuerySharedUsersResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+ return httputil.CallInternalRPCAPI(
+ "QuerySharedUsers", h.roomserverURL+RoomserverQuerySharedUsersPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) QueryKnownUsers(
- ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse,
+ ctx context.Context,
+ request *api.QueryKnownUsersRequest,
+ response *api.QueryKnownUsersResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+ return httputil.CallInternalRPCAPI(
+ "QueryKnownUsers", h.roomserverURL+RoomserverQueryKnownUsersPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) QueryAuthChain(
- ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse,
+ ctx context.Context,
+ request *api.QueryAuthChainRequest,
+ response *api.QueryAuthChainResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAuthChain")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryAuthChainPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+ return httputil.CallInternalRPCAPI(
+ "QueryAuthChain", h.roomserverURL+RoomserverQueryAuthChainPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
- ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
+ ctx context.Context,
+ request *api.QueryServerBannedFromRoomRequest,
+ response *api.QueryServerBannedFromRoomResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+ return httputil.CallInternalRPCAPI(
+ "QueryServerBannedFromRoom", h.roomserverURL+RoomserverQueryServerBannedFromRoomPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed(
- ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse,
+ ctx context.Context,
+ request *api.QueryRestrictedJoinAllowedRequest,
+ response *api.QueryRestrictedJoinAllowedResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRestrictedJoinAllowed")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverQueryRestrictedJoinAllowed
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+ return httputil.CallInternalRPCAPI(
+ "QueryRestrictedJoinAllowed", h.roomserverURL+RoomserverQueryRestrictedJoinAllowed,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget")
- defer span.Finish()
-
- apiURL := h.roomserverURL + RoomserverPerformForgetPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpRoomserverInternalAPI) PerformForget(
+ ctx context.Context,
+ request *api.PerformForgetRequest,
+ response *api.PerformForgetResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformForget", h.roomserverURL+RoomserverPerformForgetPath,
+ h.httpClient, ctx, request, response,
+ )
}
+
+func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryMembershiptAtEvent", h.roomserverURL+RoomserverQueryMembershipAtEventPath,
+ h.httpClient, ctx, request, response,
+ )
+}
diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go
index 993381585..3b688174a 100644
--- a/roomserver/inthttp/server.go
+++ b/roomserver/inthttp/server.go
@@ -1,499 +1,201 @@
package inthttp
import (
- "encoding/json"
- "net/http"
-
"github.com/gorilla/mux"
+
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/util"
)
// AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux.
// nolint: gocyclo
func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
- internalAPIMux.Handle(RoomserverInputRoomEventsPath,
- httputil.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse {
- var request api.InputRoomEventsRequest
- var response api.InputRoomEventsResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- r.InputRoomEvents(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ internalAPIMux.Handle(
+ RoomserverInputRoomEventsPath,
+ httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", r.InputRoomEvents),
)
- internalAPIMux.Handle(RoomserverPerformInvitePath,
- httputil.MakeInternalAPI("performInvite", func(req *http.Request) util.JSONResponse {
- var request api.PerformInviteRequest
- var response api.PerformInviteResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.PerformInvite(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformInvitePath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformInvite", r.PerformInvite),
)
- internalAPIMux.Handle(RoomserverPerformJoinPath,
- httputil.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse {
- var request api.PerformJoinRequest
- var response api.PerformJoinResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- r.PerformJoin(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformJoinPath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformJoin", r.PerformJoin),
)
- internalAPIMux.Handle(RoomserverPerformLeavePath,
- httputil.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse {
- var request api.PerformLeaveRequest
- var response api.PerformLeaveResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.PerformLeave(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformLeavePath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformLeave", r.PerformLeave),
)
- internalAPIMux.Handle(RoomserverPerformPeekPath,
- httputil.MakeInternalAPI("performPeek", func(req *http.Request) util.JSONResponse {
- var request api.PerformPeekRequest
- var response api.PerformPeekResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- r.PerformPeek(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformPeekPath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformPeek", r.PerformPeek),
)
- internalAPIMux.Handle(RoomserverPerformInboundPeekPath,
- httputil.MakeInternalAPI("performInboundPeek", func(req *http.Request) util.JSONResponse {
- var request api.PerformInboundPeekRequest
- var response api.PerformInboundPeekResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.PerformInboundPeek(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformInboundPeekPath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", r.PerformInboundPeek),
)
- internalAPIMux.Handle(RoomserverPerformPeekPath,
- httputil.MakeInternalAPI("performUnpeek", func(req *http.Request) util.JSONResponse {
- var request api.PerformUnpeekRequest
- var response api.PerformUnpeekResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- r.PerformUnpeek(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformUnpeekPath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", r.PerformUnpeek),
)
- internalAPIMux.Handle(RoomserverPerformRoomUpgradePath,
- httputil.MakeInternalAPI("performRoomUpgrade", func(req *http.Request) util.JSONResponse {
- var request api.PerformRoomUpgradeRequest
- var response api.PerformRoomUpgradeResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- r.PerformRoomUpgrade(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformRoomUpgradePath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", r.PerformRoomUpgrade),
)
- internalAPIMux.Handle(RoomserverPerformPublishPath,
- httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse {
- var request api.PerformPublishRequest
- var response api.PerformPublishResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- r.PerformPublish(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformPublishPath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformPublish", r.PerformPublish),
)
- internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath,
- httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse {
- var request api.PerformAdminEvacuateRoomRequest
- var response api.PerformAdminEvacuateRoomResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- r.PerformAdminEvacuateRoom(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformAdminEvacuateRoomPath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", r.PerformAdminEvacuateRoom),
)
- internalAPIMux.Handle(RoomserverPerformAdminEvacuateUserPath,
- httputil.MakeInternalAPI("performAdminEvacuateUser", func(req *http.Request) util.JSONResponse {
- var request api.PerformAdminEvacuateUserRequest
- var response api.PerformAdminEvacuateUserResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- r.PerformAdminEvacuateUser(req.Context(), &request, &response)
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverPerformAdminEvacuateUserPath,
+ httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser),
)
+
internalAPIMux.Handle(
RoomserverQueryPublishedRoomsPath,
- httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse {
- var request api.QueryPublishedRoomsRequest
- var response api.QueryPublishedRoomsResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryPublishedRooms(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms),
)
+
internalAPIMux.Handle(
RoomserverQueryLatestEventsAndStatePath,
- httputil.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
- var request api.QueryLatestEventsAndStateRequest
- var response api.QueryLatestEventsAndStateResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", r.QueryLatestEventsAndState),
)
+
internalAPIMux.Handle(
RoomserverQueryStateAfterEventsPath,
- httputil.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
- var request api.QueryStateAfterEventsRequest
- var response api.QueryStateAfterEventsResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", r.QueryStateAfterEvents),
)
+
internalAPIMux.Handle(
RoomserverQueryEventsByIDPath,
- httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
- var request api.QueryEventsByIDRequest
- var response api.QueryEventsByIDResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", r.QueryEventsByID),
)
+
internalAPIMux.Handle(
RoomserverQueryMembershipForUserPath,
- httputil.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse {
- var request api.QueryMembershipForUserRequest
- var response api.QueryMembershipForUserResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", r.QueryMembershipForUser),
)
+
internalAPIMux.Handle(
RoomserverQueryMembershipsForRoomPath,
- httputil.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
- var request api.QueryMembershipsForRoomRequest
- var response api.QueryMembershipsForRoomResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", r.QueryMembershipsForRoom),
)
+
internalAPIMux.Handle(
RoomserverQueryServerJoinedToRoomPath,
- httputil.MakeInternalAPI("queryServerJoinedToRoom", func(req *http.Request) util.JSONResponse {
- var request api.QueryServerJoinedToRoomRequest
- var response api.QueryServerJoinedToRoomResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryServerJoinedToRoom(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", r.QueryServerJoinedToRoom),
)
+
internalAPIMux.Handle(
RoomserverQueryServerAllowedToSeeEventPath,
- httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
- var request api.QueryServerAllowedToSeeEventRequest
- var response api.QueryServerAllowedToSeeEventResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", r.QueryServerAllowedToSeeEvent),
)
+
internalAPIMux.Handle(
RoomserverQueryMissingEventsPath,
- httputil.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse {
- var request api.QueryMissingEventsRequest
- var response api.QueryMissingEventsResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", r.QueryMissingEvents),
)
+
internalAPIMux.Handle(
RoomserverQueryStateAndAuthChainPath,
- httputil.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
- var request api.QueryStateAndAuthChainRequest
- var response api.QueryStateAndAuthChainResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", r.QueryStateAndAuthChain),
)
+
internalAPIMux.Handle(
RoomserverPerformBackfillPath,
- httputil.MakeInternalAPI("PerformBackfill", func(req *http.Request) util.JSONResponse {
- var request api.PerformBackfillRequest
- var response api.PerformBackfillResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.PerformBackfill(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", r.PerformBackfill),
)
+
internalAPIMux.Handle(
RoomserverPerformForgetPath,
- httputil.MakeInternalAPI("PerformForget", func(req *http.Request) util.JSONResponse {
- var request api.PerformForgetRequest
- var response api.PerformForgetResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.PerformForget(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverPerformForget", r.PerformForget),
)
+
internalAPIMux.Handle(
RoomserverQueryRoomVersionCapabilitiesPath,
- httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse {
- var request api.QueryRoomVersionCapabilitiesRequest
- var response api.QueryRoomVersionCapabilitiesResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", r.QueryRoomVersionCapabilities),
)
+
internalAPIMux.Handle(
RoomserverQueryRoomVersionForRoomPath,
- httputil.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse {
- var request api.QueryRoomVersionForRoomRequest
- var response api.QueryRoomVersionForRoomResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.QueryRoomVersionForRoom(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", r.QueryRoomVersionForRoom),
)
+
internalAPIMux.Handle(
RoomserverSetRoomAliasPath,
- httputil.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse {
- var request api.SetRoomAliasRequest
- var response api.SetRoomAliasResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", r.SetRoomAlias),
)
+
internalAPIMux.Handle(
RoomserverGetRoomIDForAliasPath,
- httputil.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse {
- var request api.GetRoomIDForAliasRequest
- var response api.GetRoomIDForAliasResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", r.GetRoomIDForAlias),
)
+
internalAPIMux.Handle(
RoomserverGetAliasesForRoomIDPath,
- httputil.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse {
- var request api.GetAliasesForRoomIDRequest
- var response api.GetAliasesForRoomIDResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", r.GetAliasesForRoomID),
)
+
internalAPIMux.Handle(
RoomserverRemoveRoomAliasPath,
- httputil.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse {
- var request api.RemoveRoomAliasRequest
- var response api.RemoveRoomAliasResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.ErrorResponse(err)
- }
- if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", r.RemoveRoomAlias),
)
- internalAPIMux.Handle(RoomserverQueryCurrentStatePath,
- httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse {
- request := api.QueryCurrentStateRequest{}
- response := api.QueryCurrentStateResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverQueryCurrentStatePath,
+ httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", r.QueryCurrentState),
)
- internalAPIMux.Handle(RoomserverQueryRoomsForUserPath,
- httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse {
- request := api.QueryRoomsForUserRequest{}
- response := api.QueryRoomsForUserResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverQueryRoomsForUserPath,
+ httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", r.QueryRoomsForUser),
)
- internalAPIMux.Handle(RoomserverQueryBulkStateContentPath,
- httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse {
- request := api.QueryBulkStateContentRequest{}
- response := api.QueryBulkStateContentResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverQueryBulkStateContentPath,
+ httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", r.QueryBulkStateContent),
)
- internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
- httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
- request := api.QuerySharedUsersRequest{}
- response := api.QuerySharedUsersResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverQuerySharedUsersPath,
+ httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", r.QuerySharedUsers),
)
- internalAPIMux.Handle(RoomserverQueryKnownUsersPath,
- httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse {
- request := api.QueryKnownUsersRequest{}
- response := api.QueryKnownUsersResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverQueryKnownUsersPath,
+ httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", r.QueryKnownUsers),
)
- internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath,
- httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse {
- request := api.QueryServerBannedFromRoomRequest{}
- response := api.QueryServerBannedFromRoomResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverQueryServerBannedFromRoomPath,
+ httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", r.QueryServerBannedFromRoom),
)
- internalAPIMux.Handle(RoomserverQueryAuthChainPath,
- httputil.MakeInternalAPI("queryAuthChain", func(req *http.Request) util.JSONResponse {
- request := api.QueryAuthChainRequest{}
- response := api.QueryAuthChainResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.QueryAuthChain(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverQueryAuthChainPath,
+ httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", r.QueryAuthChain),
)
- internalAPIMux.Handle(RoomserverQueryRestrictedJoinAllowed,
- httputil.MakeInternalAPI("queryRestrictedJoinAllowed", func(req *http.Request) util.JSONResponse {
- request := api.QueryRestrictedJoinAllowedRequest{}
- response := api.QueryRestrictedJoinAllowedResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := r.QueryRestrictedJoinAllowed(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ RoomserverQueryRestrictedJoinAllowed,
+ httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed),
+ )
+ internalAPIMux.Handle(
+ RoomserverQueryMembershipAtEventPath,
+ httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent),
)
}
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index d1d24b099..a40a2e9ba 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -23,12 +23,11 @@ import (
"sync"
"time"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
-
- "github.com/matrix-org/dendrite/roomserver/types"
- "github.com/matrix-org/gomatrixserverlib"
)
type StateResolutionStorage interface {
@@ -124,6 +123,84 @@ func (v *StateResolution) LoadStateAtEvent(
return stateEntries, nil
}
+func (v *StateResolution) LoadMembershipAtEvent(
+ ctx context.Context, eventIDs []string, stateKeyNID types.EventStateKeyNID,
+) (map[string][]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent")
+ defer span.Finish()
+
+ // De-dupe snapshotNIDs
+ snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs
+ for i := range eventIDs {
+ eventID := eventIDs[i]
+ snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
+ if err != nil {
+ return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
+ }
+ if snapshotNID == 0 {
+ return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
+ }
+ snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID)
+ }
+
+ snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap))
+ for nid := range snapshotNIDMap {
+ snapshotNIDs = append(snapshotNIDs, nid)
+ }
+
+ stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, snapshotNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ result := make(map[string][]types.StateEntry)
+ for _, stateBlockNIDList := range stateBlockNIDLists {
+ // Query the membership event for the user at the given stateblocks
+ stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{
+ {
+ EventTypeNID: types.MRoomMemberNID,
+ EventStateKeyNID: stateKeyNID,
+ },
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID]
+
+ for _, evID := range evIDs {
+ for _, x := range stateEntryLists {
+ result[evID] = append(result[evID], x.StateEntries...)
+ }
+ }
+ }
+
+ return result, nil
+}
+
+// LoadStateAtEvent loads the full state of a room before a particular event.
+func (v *StateResolution) LoadStateAtEventForHistoryVisibility(
+ ctx context.Context, eventID string,
+) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent")
+ defer span.Finish()
+
+ snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
+ if err != nil {
+ return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
+ }
+ if snapshotNID == 0 {
+ return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
+ }
+
+ stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
+ if err != nil {
+ return nil, err
+ }
+
+ return stateEntries, nil
+}
+
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
// and combines those snapshots together into a single list. At this point it is
// possible to run into duplicate (type, state key) tuples.
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index a98fda073..b12025c41 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -166,4 +166,6 @@ type Database interface {
GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
+
+ GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
}
diff --git a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go
index f3bd8632f..61d4dba87 100644
--- a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go
+++ b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go
@@ -15,32 +15,21 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
- goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
-}
-
-func LoadAddForgottenColumn(m *sqlutil.Migrations) {
- m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
-}
-
-func UpAddForgottenColumn(tx *sql.Tx) error {
- _, err := tx.Exec(`ALTER TABLE roomserver_membership ADD COLUMN IF NOT EXISTS forgotten BOOLEAN NOT NULL DEFAULT false;`)
+func UpAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_membership ADD COLUMN IF NOT EXISTS forgotten BOOLEAN NOT NULL DEFAULT false;`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
-func DownAddForgottenColumn(tx *sql.Tx) error {
- _, err := tx.Exec(`ALTER TABLE roomserver_membership DROP COLUMN IF EXISTS forgotten;`)
+func DownAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_membership DROP COLUMN IF EXISTS forgotten;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
diff --git a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go
index 06442a4c3..355c49b14 100644
--- a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go
+++ b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go
@@ -15,11 +15,11 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
"github.com/lib/pq"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@@ -36,48 +36,44 @@ type stateBlockData struct {
EventNIDs types.EventNIDs
}
-func LoadStateBlocksRefactor(m *sqlutil.Migrations) {
- m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
-}
-
// nolint:gocyclo
-func UpStateBlocksRefactor(tx *sql.Tx) error {
+func UpStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!")
defer logrus.Warn("State storage upgrade complete")
var snapshotcount int
var maxsnapshotid int
var maxblockid int
- if err := tx.QueryRow(`SELECT COUNT(DISTINCT state_snapshot_nid) FROM roomserver_state_snapshots;`).Scan(&snapshotcount); err != nil {
- return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ if err := tx.QueryRowContext(ctx, `SELECT COUNT(DISTINCT state_snapshot_nid) FROM roomserver_state_snapshots;`).Scan(&snapshotcount); err != nil {
+ return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
}
- if err := tx.QueryRow(`SELECT COALESCE(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
- return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ if err := tx.QueryRowContext(ctx, `SELECT COALESCE(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
}
- if err := tx.QueryRow(`SELECT COALESCE(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
- return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ if err := tx.QueryRowContext(ctx, `SELECT COALESCE(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
+ return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
}
maxsnapshotid++
maxblockid++
- if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
- return fmt.Errorf("tx.Exec: %w", err)
+ if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
+ return fmt.Errorf("tx.ExecContext: %w", err)
}
- if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
- return fmt.Errorf("tx.Exec: %w", err)
+ if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
+ return fmt.Errorf("tx.ExecContext: %w", err)
}
// We create new sequences starting with the maximum state snapshot and block NIDs.
// This means that all newly created snapshots and blocks by the migration will have
// NIDs higher than these values, so that when we come to update the references to
// these NIDs using UPDATE statements, we can guarantee we are only ever updating old
// values and not accidentally overwriting new ones.
- if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_block_nid_sequence START WITH %d;`, maxblockid)); err != nil {
- return fmt.Errorf("tx.Exec: %w", err)
+ if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE roomserver_state_block_nid_sequence START WITH %d;`, maxblockid)); err != nil {
+ return fmt.Errorf("tx.ExecContext: %w", err)
}
- if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_snapshot_nid_sequence START WITH %d;`, maxsnapshotid)); err != nil {
- return fmt.Errorf("tx.Exec: %w", err)
+ if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE roomserver_state_snapshot_nid_sequence START WITH %d;`, maxsnapshotid)); err != nil {
+ return fmt.Errorf("tx.ExecContext: %w", err)
}
- _, err := tx.Exec(`
+ _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_block_nid_sequence'),
state_block_hash BYTEA UNIQUE,
@@ -87,7 +83,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
if err != nil {
return fmt.Errorf("tx.Exec (create blocks table): %w", err)
}
- _, err = tx.Exec(`
+ _, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_sequence'),
state_snapshot_hash BYTEA UNIQUE,
@@ -104,7 +100,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// in question a state snapshot NID of 0 to indicate 'no snapshot'.
// If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget
// any snapshots.
- _, err = tx.Exec(
+ _, err = tx.ExecContext(ctx,
`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2`,
types.MRoomCreateNID, types.EmptyStateKeyNID,
)
@@ -115,7 +111,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
batchsize := 100
for batchoffset := 0; batchoffset < snapshotcount; batchoffset += batchsize {
var snapshotrows *sql.Rows
- snapshotrows, err = tx.Query(`
+ snapshotrows, err = tx.QueryContext(ctx, `
SELECT
state_snapshot_nid,
room_nid,
@@ -146,7 +142,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
state_block_nid;
`, batchsize, batchoffset)
if err != nil {
- return fmt.Errorf("tx.Query: %w", err)
+ return fmt.Errorf("tx.QueryContext: %w", err)
}
logrus.Warnf("Rewriting snapshots %d-%d of %d...", batchoffset, batchoffset+batchsize, snapshotcount)
@@ -183,7 +179,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// fill in bad create snapshots
for _, s := range badCreateSnapshots {
var createEventNID types.EventNID
- err = tx.QueryRow(
+ err = tx.QueryRowContext(ctx,
`SELECT event_nid FROM roomserver_events WHERE state_snapshot_nid = $1 AND event_type_nid = 1`, s.StateSnapshotNID,
).Scan(&createEventNID)
if err == sql.ErrNoRows {
@@ -208,7 +204,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
}
var blocknid types.StateBlockNID
- err = tx.QueryRow(`
+ err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_block (state_block_hash, event_nids)
VALUES ($1, $2)
ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2
@@ -227,7 +223,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
}
var newNID types.StateSnapshotNID
- err = tx.QueryRow(`
+ err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)
VALUES ($1, $2, $3)
ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2
@@ -237,12 +233,12 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err)
}
- if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
- return fmt.Errorf("tx.Exec (update events): %w", err)
+ if _, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.ExecContext (update events): %w", err)
}
- if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
- return fmt.Errorf("tx.Exec (update rooms): %w", err)
+ if _, err = tx.ExecContext(ctx, `UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.ExecContext (update rooms): %w", err)
}
}
}
@@ -252,13 +248,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// in roomserver_state_snapshots
var count int64
- if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
+ if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err)
}
if count > 0 {
var res sql.Result
var c int64
- res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid)
+ res, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid)
if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("failed to reset invalid state snapshots: %w", err)
}
@@ -268,13 +264,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c)
}
}
- if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
+ if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err)
}
if count > 0 {
var debugRoomID string
var debugSnapNID, debugLastEventNID int64
- err = tx.QueryRow(
+ err = tx.QueryRowContext(ctx,
`SELECT room_id, state_snapshot_nid, last_event_sent_nid FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid,
).Scan(&debugRoomID, &debugSnapNID, &debugLastEventNID)
if err != nil {
@@ -291,13 +287,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
}
- if _, err = tx.Exec(`
+ if _, err = tx.ExecContext(ctx, `
DROP TABLE _roomserver_state_snapshots;
DROP SEQUENCE roomserver_state_snapshot_nid_seq;
`); err != nil {
return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err)
}
- if _, err = tx.Exec(`
+ if _, err = tx.ExecContext(ctx, `
DROP TABLE _roomserver_state_block;
DROP SEQUENCE roomserver_state_block_nid_seq;
`); err != nil {
@@ -307,6 +303,6 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return nil
}
-func DownStateBlocksRefactor(tx *sql.Tx) error {
+func DownStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
panic("Downgrading state storage is not supported")
}
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index ce626ad1d..bd3fd5592 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -23,6 +23,7 @@ import (
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -86,24 +87,24 @@ const insertMembershipSQL = "" +
const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
- " WHERE room_nid = $1 AND target_nid = $2"
+ " WHERE room_nid = $1 AND event_nid != 0 AND target_nid = $2"
const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
- " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
+ " WHERE room_nid = $1 AND event_nid != 0 AND membership_nid = $2 and forgotten = false"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
- " WHERE room_nid = $1 AND membership_nid = $2" +
+ " WHERE room_nid = $1 AND event_nid != 0 AND membership_nid = $2" +
" AND target_local = true and forgotten = false"
const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
- " WHERE room_nid = $1 and forgotten = false"
+ " WHERE room_nid = $1 AND event_nid != 0 and forgotten = false"
const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
- " WHERE room_nid = $1" +
+ " WHERE room_nid = $1 AND event_nid != 0" +
" AND target_local = true and forgotten = false"
const selectMembershipForUpdateSQL = "" +
@@ -118,6 +119,9 @@ const updateMembershipForgetRoom = "" +
"UPDATE roomserver_membership SET forgotten = $3" +
" WHERE room_nid = $1 AND target_nid = $2"
+const deleteMembershipSQL = "" +
+ "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
+
const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
@@ -165,11 +169,20 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
+ deleteMembershipStmt *sql.Stmt
}
func CreateMembershipTable(db *sql.DB) error {
_, err := db.Exec(membershipSchema)
- return err
+ if err != nil {
+ return err
+ }
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "roomserver: add forgotten column",
+ Up: deltas.UpAddForgottenColumn,
+ })
+ return m.Up(context.Background())
}
func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
@@ -191,6 +204,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
+ {&s.deleteMembershipStmt, deleteMembershipSQL},
}.Prepare(db)
}
@@ -412,3 +426,13 @@ func (s *membershipStatements) SelectServerInRoom(
}
return roomNID == nid, nil
}
+
+func (s *membershipStatements) DeleteMembership(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteMembershipStmt).ExecContext(
+ ctx, roomNID, targetUserNID,
+ )
+ return err
+}
diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go
index 24362af74..994399532 100644
--- a/roomserver/storage/postgres/rooms_table.go
+++ b/roomserver/storage/postgres/rooms_table.go
@@ -147,14 +147,16 @@ func (s *roomStatements) InsertRoomNID(
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDs pq.Int64Array
+ var stateSnapshotNID types.StateSnapshotNID
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
- &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
+ &info.RoomVersion, &info.RoomNID, &stateSnapshotNID, &latestNIDs,
)
if err == sql.ErrNoRows {
return nil, nil
}
- info.IsStub = len(latestNIDs) == 0
+ info.SetStateSnapshotNID(stateSnapshotNID)
+ info.SetIsStub(len(latestNIDs) == 0)
return &info, err
}
diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go
index a24b7f3f0..99c76befe 100644
--- a/roomserver/storage/postgres/state_snapshot_table.go
+++ b/roomserver/storage/postgres/state_snapshot_table.go
@@ -72,9 +72,35 @@ const bulkSelectStateBlockNIDsSQL = "" +
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
" WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC"
+// Looks up both the history visibility event and relevant membership events from
+// a given domain name from a given state snapshot. This is used to optimise the
+// helpers.CheckServerAllowedToSeeEvent function.
+// TODO: There's a sequence scan here because of the hash join strategy, which is
+// probably O(n) on state key entries, so there must be a way to avoid that somehow.
+// Event type NIDs are:
+// - 5: m.room.member as per https://github.com/matrix-org/dendrite/blob/c7f7aec4d07d59120d37d5b16a900f6d608a75c4/roomserver/storage/postgres/event_types_table.go#L40
+// - 7: m.room.history_visibility as per https://github.com/matrix-org/dendrite/blob/c7f7aec4d07d59120d37d5b16a900f6d608a75c4/roomserver/storage/postgres/event_types_table.go#L42
+const bulkSelectStateForHistoryVisibilitySQL = `
+ SELECT event_nid FROM (
+ SELECT event_nid, event_type_nid, event_state_key_nid FROM roomserver_events
+ WHERE (event_type_nid = 5 OR event_type_nid = 7)
+ AND event_nid = ANY(
+ SELECT UNNEST(event_nids) FROM roomserver_state_block
+ WHERE state_block_nid = ANY(
+ SELECT UNNEST(state_block_nids) FROM roomserver_state_snapshots
+ WHERE state_snapshot_nid = $1
+ )
+ )
+ ) AS roomserver_events
+ INNER JOIN roomserver_event_state_keys
+ ON roomserver_events.event_state_key_nid = roomserver_event_state_keys.event_state_key_nid
+ AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2);
+`
+
type stateSnapshotStatements struct {
- insertStateStmt *sql.Stmt
- bulkSelectStateBlockNIDsStmt *sql.Stmt
+ insertStateStmt *sql.Stmt
+ bulkSelectStateBlockNIDsStmt *sql.Stmt
+ bulkSelectStateForHistoryVisibilityStmt *sql.Stmt
}
func CreateStateSnapshotTable(db *sql.DB) error {
@@ -88,6 +114,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
return s, sqlutil.StatementList{
{&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
+ {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL},
}.Prepare(db)
}
@@ -136,3 +163,23 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
}
return results, nil
}
+
+func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
+ ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string,
+) ([]types.EventNID, error) {
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectStateForHistoryVisibilityStmt)
+ rows, err := stmt.QueryContext(ctx, stateSnapshotNID, domain)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make([]types.EventNID, 0, 16)
+ for rows.Next() {
+ var eventNID types.EventNID
+ if err = rows.Scan(&eventNID); err != nil {
+ return nil, err
+ }
+ results = append(results, eventNID)
+ }
+ return results, rows.Err()
+}
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 70ea4d8ba..f47a64c80 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -19,6 +19,7 @@ import (
"database/sql"
"fmt"
+ "github.com/lib/pq"
// Import the postgres database driver.
_ "github.com/lib/pq"
@@ -45,22 +46,41 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c
}
// Create the tables.
- if err := d.create(db); err != nil {
+ if err = d.create(db); err != nil {
return nil, err
}
- // Then execute the migrations. By this point the tables are created with the latest
- // schemas.
- m := sqlutil.NewMigrations()
- deltas.LoadAddForgottenColumn(m)
- deltas.LoadStateBlocksRefactor(m)
- if err := m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
+ // Special case, since this migration uses several tables, so it needs to
+ // be sure that all tables are created first.
+ // TODO: Remove when we are sure we are not having goose artefacts in the db
+ // This forces an error, which indicates the migration is already applied, since the
+ // column event_nid was removed from the table
+ var eventNID int
+ err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&eventNID)
+ if err == nil {
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "roomserver: state blocks refactor",
+ Up: deltas.UpStateBlocksRefactor,
+ })
+ if err = m.Up(base.Context()); err != nil {
+ return nil, err
+ }
+ } else {
+ switch e := err.(type) {
+ case *pq.Error:
+ // ignore undefined_column (42703) errors, as this is expected at this point
+ if e.Code != "42703" {
+ return nil, err
+ }
+ default:
+ return nil, err
+ }
}
// Then prepare the statements. Now that the migrations have run, any columns referred
// to in the database code should now exist.
- if err := d.prepare(db, writer, cache); err != nil {
+ if err = d.prepare(db, writer, cache); err != nil {
return nil, err
}
diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go
index ebfcef569..07fb697f9 100644
--- a/roomserver/storage/shared/membership_updater.go
+++ b/roomserver/storage/shared/membership_updater.go
@@ -15,7 +15,7 @@ type MembershipUpdater struct {
d *Database
roomNID types.RoomNID
targetUserNID types.EventStateKeyNID
- membership tables.MembershipState
+ oldMembership tables.MembershipState
}
func NewMembershipUpdater(
@@ -30,7 +30,6 @@ func NewMembershipUpdater(
if err != nil {
return err
}
-
targetUserNID, err = d.assignStateKeyNID(ctx, targetUserID)
if err != nil {
return err
@@ -73,139 +72,62 @@ func (d *Database) membershipUpdaterTxn(
// IsInvite implements types.MembershipUpdater
func (u *MembershipUpdater) IsInvite() bool {
- return u.membership == tables.MembershipStateInvite
+ return u.oldMembership == tables.MembershipStateInvite
}
// IsJoin implements types.MembershipUpdater
func (u *MembershipUpdater) IsJoin() bool {
- return u.membership == tables.MembershipStateJoin
+ return u.oldMembership == tables.MembershipStateJoin
}
// IsLeave implements types.MembershipUpdater
func (u *MembershipUpdater) IsLeave() bool {
- return u.membership == tables.MembershipStateLeaveOrBan
+ return u.oldMembership == tables.MembershipStateLeaveOrBan
}
// IsKnock implements types.MembershipUpdater
func (u *MembershipUpdater) IsKnock() bool {
- return u.membership == tables.MembershipStateKnock
+ return u.oldMembership == tables.MembershipStateKnock
}
-// SetToInvite implements types.MembershipUpdater
-func (u *MembershipUpdater) SetToInvite(event *gomatrixserverlib.Event) (bool, error) {
- var inserted bool
- err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+func (u *MembershipUpdater) Delete() error {
+ if _, err := u.d.InvitesTable.UpdateInviteRetired(u.ctx, u.txn, u.roomNID, u.targetUserNID); err != nil {
+ return err
+ }
+ return u.d.MembershipTable.DeleteMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID)
+}
+
+func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event *types.Event) (bool, []string, error) {
+ var inserted bool // Did the query result in a membership change?
+ var retired []string // Did we retire any updates in the process?
+ return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender())
if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
- inserted, err = u.d.InvitesTable.InsertInviteEvent(
- u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
- )
+ inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, newMembership, event.EventNID, false)
if err != nil {
- return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
+ return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
- if u.membership != tables.MembershipStateInvite {
- if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil {
- return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
+ if !inserted {
+ return nil
+ }
+ switch {
+ case u.oldMembership != tables.MembershipStateInvite && newMembership == tables.MembershipStateInvite:
+ inserted, err = u.d.InvitesTable.InsertInviteEvent(
+ u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
+ )
+ if err != nil {
+ return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
}
- }
- return nil
- })
- return inserted, err
-}
-
-// SetToJoin implements types.MembershipUpdater
-func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
- var inviteEventIDs []string
-
- err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
- senderUserNID, err := u.d.assignStateKeyNID(u.ctx, senderUserID)
- if err != nil {
- return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
- }
-
- // If this is a join event update, there is no invite to update
- if !isUpdate {
- inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
+ case u.oldMembership == tables.MembershipStateInvite && newMembership != tables.MembershipStateInvite:
+ retired, err = u.d.InvitesTable.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err)
}
}
-
- // Look up the NID of the new join event
- nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
- if err != nil {
- return fmt.Errorf("u.d.EventNIDs: %w", err)
- }
-
- if u.membership != tables.MembershipStateJoin || isUpdate {
- if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil {
- return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
- }
- }
-
return nil
})
-
- return inviteEventIDs, err
-}
-
-// SetToLeave implements types.MembershipUpdater
-func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
- var inviteEventIDs []string
-
- err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
- senderUserNID, err := u.d.assignStateKeyNID(u.ctx, senderUserID)
- if err != nil {
- return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
- }
- inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
- u.ctx, u.txn, u.roomNID, u.targetUserNID,
- )
- if err != nil {
- return fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
- }
-
- // Look up the NID of the new leave event
- nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
- if err != nil {
- return fmt.Errorf("u.d.EventNIDs: %w", err)
- }
-
- if u.membership != tables.MembershipStateLeaveOrBan {
- if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil {
- return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
- }
- }
-
- return nil
- })
- return inviteEventIDs, err
-}
-
-// SetToKnock implements types.MembershipUpdater
-func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, error) {
- var inserted bool
- err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
- senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender())
- if err != nil {
- return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
- }
- if u.membership != tables.MembershipStateKnock {
- // Look up the NID of the new knock event
- nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false)
- if err != nil {
- return fmt.Errorf("u.d.EventNIDs: %w", err)
- }
-
- if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil {
- return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
- }
- }
- return nil
- })
- return inserted, err
}
diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go
index 8f4e011bf..42c0c8f2d 100644
--- a/roomserver/storage/shared/room_updater.go
+++ b/roomserver/storage/shared/room_updater.go
@@ -217,6 +217,14 @@ func (u *RoomUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
+ switch {
+ case len(latest) == 0:
+ return fmt.Errorf("cannot set latest events with no latest event references")
+ case currentStateSnapshotNID == 0:
+ return fmt.Errorf("cannot set latest events with invalid state snapshot NID")
+ case lastEventNIDSent == 0:
+ return fmt.Errorf("cannot set latest events with invalid latest event NID")
+ }
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
@@ -225,12 +233,13 @@ func (u *RoomUpdater) SetLatestEvents(
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
- if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
- if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
- roomInfo.StateSnapshotNID = currentStateSnapshotNID
- roomInfo.IsStub = false
- u.d.Cache.StoreRoomInfo(roomID, roomInfo)
- }
+
+ // Since it's entirely possible that this types.RoomInfo came from the
+ // cache, we should make sure to update that entry so that the next run
+ // works from live data.
+ if u.roomInfo != nil {
+ u.roomInfo.SetStateSnapshotNID(currentStateSnapshotNID)
+ u.roomInfo.SetIsStub(false)
}
return nil
})
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 692af1f6c..cbf9c8b20 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -72,7 +72,24 @@ func (d *Database) eventTypeNIDs(
func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
- return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
+ result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
+ fetch := make([]types.EventStateKeyNID, 0, len(eventStateKeyNIDs))
+ for _, nid := range eventStateKeyNIDs {
+ if key, ok := d.Cache.GetEventStateKey(nid); ok {
+ result[nid] = key
+ } else {
+ fetch = append(fetch, nid)
+ }
+ }
+ fromDB, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, fetch)
+ if err != nil {
+ return nil, err
+ }
+ for nid, key := range fromDB {
+ result[nid] = key
+ d.Cache.StoreEventStateKey(nid, key)
+ }
+ return result, nil
}
func (d *Database) EventStateKeyNIDs(
@@ -139,13 +156,13 @@ func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo
}
func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
- if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
- return &roomInfo, nil
- }
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
- if err == nil && roomInfo != nil {
+ if err != nil {
+ return nil, err
+ }
+ if roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
- d.Cache.StoreRoomInfo(roomID, *roomInfo)
+ d.Cache.StoreRoomVersion(roomID, roomInfo.RoomVersion)
}
return roomInfo, err
}
@@ -472,8 +489,8 @@ func (d *Database) events(
fetchNIDList := make([]types.RoomNID, 0, len(uniqueRoomNIDs))
for n := range uniqueRoomNIDs {
if roomID, ok := d.Cache.GetRoomServerRoomID(n); ok {
- if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
- roomVersions[n] = roomInfo.RoomVersion
+ if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok {
+ roomVersions[n] = roomVersion
continue
}
}
@@ -659,7 +676,7 @@ func (d *Database) storeEvent(
succeeded := false
if updater == nil {
var roomInfo *types.RoomInfo
- roomInfo, err = d.RoomInfo(ctx, event.RoomID())
+ roomInfo, err = d.roomInfo(ctx, txn, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
@@ -730,9 +747,6 @@ func (d *Database) MissingAuthPrevEvents(
func (d *Database) assignRoomNID(
ctx context.Context, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) {
- if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
- return roomInfo.RoomNID, nil
- }
// Check if we already have a numeric ID in the database.
roomNID, err := d.RoomsTable.SelectRoomNID(ctx, nil, roomID)
if err == sql.ErrNoRows {
@@ -805,8 +819,9 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
// "servers should not apply or send redactions to clients until both the redaction event and original event have been seen, and are valid."
// https://matrix.org/docs/spec/rooms/v3#authorization-rules-for-events
// These cases are:
-// - This is a redaction event, redact the event it references if we know about it.
-// - This is a normal event which may have been previously redacted.
+// - This is a redaction event, redact the event it references if we know about it.
+// - This is a normal event which may have been previously redacted.
+//
// In the first case, check if we have the referenced event then apply the redaction, else store it
// in the redactions table with validated=FALSE. In the second case, check if there is a redaction for it:
// if there is then apply the redactions and set validated=TRUE.
@@ -971,6 +986,38 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
return &evs[0]
}
+func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) {
+ eventStates, err := d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, []string{eventID})
+ if err != nil {
+ return nil, err
+ }
+ stateSnapshotNID := eventStates[0].BeforeStateSnapshotNID
+ if stateSnapshotNID == 0 {
+ return nil, nil
+ }
+ eventNIDs, err := d.StateSnapshotTable.BulkSelectStateForHistoryVisibility(ctx, nil, stateSnapshotNID, domain)
+ if err != nil {
+ return nil, err
+ }
+ eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
+ if err != nil {
+ eventIDs = map[types.EventNID]string{}
+ }
+ events := make([]*gomatrixserverlib.Event, 0, len(eventNIDs))
+ for _, eventNID := range eventNIDs {
+ data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{eventNID})
+ if err != nil {
+ return nil, err
+ }
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[eventNID], data[0].EventJSON, false, roomInfo.RoomVersion)
+ if err != nil {
+ return nil, err
+ }
+ events = append(events, ev)
+ }
+ return events, nil
+}
+
// GetStateEvent returns the current state event of a given type for a given room with a given state key
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
@@ -983,7 +1030,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
return nil, fmt.Errorf("room %s doesn't exist", roomID)
}
// e.g invited rooms
- if roomInfo.IsStub {
+ if roomInfo.IsStub() {
return nil, nil
}
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
@@ -1002,7 +1049,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
if err != nil {
return nil, err
}
- entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
+ entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID())
if err != nil {
return nil, err
}
@@ -1048,7 +1095,7 @@ func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evTy
return nil, fmt.Errorf("room %s doesn't exist", roomID)
}
// e.g invited rooms
- if roomInfo.IsStub {
+ if roomInfo.IsStub() {
return nil, nil
}
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
@@ -1059,7 +1106,7 @@ func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evTy
if err != nil {
return nil, err
}
- entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
+ entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID())
if err != nil {
return nil, err
}
@@ -1176,10 +1223,10 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2)
}
// for unknown rooms or rooms which we don't have the current state, skip them.
- if roomInfo == nil || roomInfo.IsStub {
+ if roomInfo == nil || roomInfo.IsStub() {
continue
}
- entries, err2 := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
+ entries, err2 := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID())
if err2 != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load state for room %s : %w", roomID, err2)
}
diff --git a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go
index d08ab02d5..4c002e33d 100644
--- a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go
+++ b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go
@@ -15,24 +15,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
- goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
-}
-
-func LoadAddForgottenColumn(m *sqlutil.Migrations) {
- m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
-}
-
-func UpAddForgottenColumn(tx *sql.Tx) error {
- _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
+func UpAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
@@ -57,8 +46,8 @@ DROP TABLE roomserver_membership_tmp;`)
return nil
}
-func DownAddForgottenColumn(tx *sql.Tx) error {
- _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
+func DownAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL,
diff --git a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go
index 8f5ab8fc5..00978121f 100644
--- a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go
+++ b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go
@@ -21,40 +21,35 @@ import (
"fmt"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
-func LoadStateBlocksRefactor(m *sqlutil.Migrations) {
- m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
-}
-
// nolint:gocyclo
-func UpStateBlocksRefactor(tx *sql.Tx) error {
+func UpStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!")
defer logrus.Warn("State storage upgrade complete")
var maxsnapshotid int
var maxblockid int
- if err := tx.QueryRow(`SELECT IFNULL(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
- return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ if err := tx.QueryRowContext(ctx, `SELECT IFNULL(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
}
- if err := tx.QueryRow(`SELECT IFNULL(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
- return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ if err := tx.QueryRowContext(ctx, `SELECT IFNULL(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
+ return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
}
maxsnapshotid++
maxblockid++
oldMaxSnapshotID := maxsnapshotid
- if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
- return fmt.Errorf("tx.Exec: %w", err)
+ if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
+ return fmt.Errorf("tx.ExecContext: %w", err)
}
- if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
- return fmt.Errorf("tx.Exec: %w", err)
+ if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
+ return fmt.Errorf("tx.ExecContext: %w", err)
}
- _, err := tx.Exec(`
+ _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
state_block_hash BLOB UNIQUE,
@@ -62,9 +57,9 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
);
`)
if err != nil {
- return fmt.Errorf("tx.Exec: %w", err)
+ return fmt.Errorf("tx.ExecContext: %w", err)
}
- _, err = tx.Exec(`
+ _, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
state_snapshot_hash BLOB UNIQUE,
@@ -73,11 +68,11 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
);
`)
if err != nil {
- return fmt.Errorf("tx.Exec: %w", err)
+ return fmt.Errorf("tx.ExecContext: %w", err)
}
- snapshotrows, err := tx.Query(`SELECT state_snapshot_nid, room_nid, state_block_nids FROM _roomserver_state_snapshots;`)
+ snapshotrows, err := tx.QueryContext(ctx, `SELECT state_snapshot_nid, room_nid, state_block_nids FROM _roomserver_state_snapshots;`)
if err != nil {
- return fmt.Errorf("tx.Query: %w", err)
+ return fmt.Errorf("tx.QueryContext: %w", err)
}
defer internal.CloseAndLogIfError(context.TODO(), snapshotrows, "rows.close() failed")
for snapshotrows.Next() {
@@ -99,7 +94,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// in question a state snapshot NID of 0 to indicate 'no snapshot'.
// If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget
// any snapshots.
- _, err = tx.Exec(
+ _, err = tx.ExecContext(ctx,
`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2 AND state_snapshot_nid = $3`,
types.MRoomCreateNID, types.EmptyStateKeyNID, snapshot,
)
@@ -109,9 +104,9 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
}
for _, block := range blocks {
if err = func() error {
- blockrows, berr := tx.Query(`SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block)
+ blockrows, berr := tx.QueryContext(ctx, `SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block)
if berr != nil {
- return fmt.Errorf("tx.Query (event nids from old block): %w", berr)
+ return fmt.Errorf("tx.QueryContext (event nids from old block): %w", berr)
}
defer internal.CloseAndLogIfError(context.TODO(), blockrows, "rows.close() failed")
events := types.EventNIDs{}
@@ -129,14 +124,14 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
}
var blocknid types.StateBlockNID
- err = tx.QueryRow(`
+ err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_block (state_block_nid, state_block_hash, event_nids)
VALUES ($1, $2, $3)
ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$3
RETURNING state_block_nid
`, maxblockid, events.Hash(), eventjson).Scan(&blocknid)
if err != nil {
- return fmt.Errorf("tx.QueryRow.Scan (insert new block): %w", err)
+ return fmt.Errorf("tx.QueryRowContext.Scan (insert new block): %w", err)
}
maxblockid++
newblocks = append(newblocks, blocknid)
@@ -151,22 +146,22 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
}
var newsnapshot types.StateSnapshotNID
- err = tx.QueryRow(`
+ err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_snapshots (state_snapshot_nid, state_snapshot_hash, room_nid, state_block_nids)
VALUES ($1, $2, $3, $4)
ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$3
RETURNING state_snapshot_nid
`, maxsnapshotid, newblocks.Hash(), room, newblocksjson).Scan(&newsnapshot)
if err != nil {
- return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err)
+ return fmt.Errorf("tx.QueryRowContext.Scan (insert new snapshot): %w", err)
}
maxsnapshotid++
- _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid)
+ _, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid)
if err != nil {
- return fmt.Errorf("tx.Exec (update events): %w", err)
+ return fmt.Errorf("tx.ExecContext (update events): %w", err)
}
- if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil {
- return fmt.Errorf("tx.Exec (update rooms): %w", err)
+ if _, err = tx.ExecContext(ctx, `UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.ExecContext (update rooms): %w", err)
}
}
}
@@ -175,13 +170,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// If we do, this is a problem if Dendrite tries to load the snapshot as it will not exist
// in roomserver_state_snapshots
var count int64
- if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
+ if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err)
}
if count > 0 {
var res sql.Result
var c int64
- res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID)
+ res, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID)
if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("failed to reset invalid state snapshots: %w", err)
}
@@ -191,23 +186,23 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c)
}
}
- if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
+ if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err)
}
if count > 0 {
return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
}
- if _, err = tx.Exec(`DROP TABLE _roomserver_state_snapshots;`); err != nil {
+ if _, err = tx.ExecContext(ctx, `DROP TABLE _roomserver_state_snapshots;`); err != nil {
return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err)
}
- if _, err = tx.Exec(`DROP TABLE _roomserver_state_block;`); err != nil {
+ if _, err = tx.ExecContext(ctx, `DROP TABLE _roomserver_state_block;`); err != nil {
return fmt.Errorf("tx.Exec (delete old block table): %w", err)
}
return nil
}
-func DownStateBlocksRefactor(tx *sql.Tx) error {
+func DownStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
panic("Downgrading state storage is not supported")
}
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index 570d3919c..f3303eb0e 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -62,24 +63,24 @@ const insertMembershipSQL = "" +
const selectMembershipFromRoomAndTargetSQL = "" +
"SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" +
- " WHERE room_nid = $1 AND target_nid = $2"
+ " WHERE room_nid = $1 AND event_nid != 0 AND target_nid = $2"
const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
- " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false"
+ " WHERE room_nid = $1 AND event_nid != 0 AND membership_nid = $2 and forgotten = false"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
- " WHERE room_nid = $1 AND membership_nid = $2" +
+ " WHERE room_nid = $1 AND event_nid != 0 AND membership_nid = $2" +
" AND target_local = true and forgotten = false"
const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
- " WHERE room_nid = $1 and forgotten = false"
+ " WHERE room_nid = $1 AND event_nid != 0 and forgotten = false"
const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
- " WHERE room_nid = $1" +
+ " WHERE room_nid = $1 AND event_nid != 0" +
" AND target_local = true and forgotten = false"
const selectMembershipForUpdateSQL = "" +
@@ -125,6 +126,9 @@ const selectServerInRoomSQL = "" +
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
+const deleteMembershipSQL = "" +
+ "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
+
type membershipStatements struct {
db *sql.DB
insertMembershipStmt *sql.Stmt
@@ -140,11 +144,20 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt
selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt
+ deleteMembershipStmt *sql.Stmt
}
func CreateMembershipTable(db *sql.DB) error {
_, err := db.Exec(membershipSchema)
- return err
+ if err != nil {
+ return err
+ }
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "roomserver: add forgotten column",
+ Up: deltas.UpAddForgottenColumn,
+ })
+ return m.Up(context.Background())
}
func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
@@ -166,6 +179,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
+ {&s.deleteMembershipStmt, deleteMembershipSQL},
}.Prepare(db)
}
@@ -383,3 +397,13 @@ func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.
}
return roomNID == nid, nil
}
+
+func (s *membershipStatements) DeleteMembership(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteMembershipStmt).ExecContext(
+ ctx, roomNID, targetUserNID,
+ )
+ return err
+}
diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go
index 03ad4b3d0..25b611b3e 100644
--- a/roomserver/storage/sqlite3/rooms_table.go
+++ b/roomserver/storage/sqlite3/rooms_table.go
@@ -129,9 +129,10 @@ func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.T
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDsJSON string
+ var stateSnapshotNID types.StateSnapshotNID
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
- &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
+ &info.RoomVersion, &info.RoomNID, &stateSnapshotNID, &latestNIDsJSON,
)
if err != nil {
if err == sql.ErrNoRows {
@@ -143,7 +144,8 @@ func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID
if err = json.Unmarshal([]byte(latestNIDsJSON), &latestNIDs); err != nil {
return nil, err
}
- info.IsStub = len(latestNIDs) == 0
+ info.SetStateSnapshotNID(stateSnapshotNID)
+ info.SetIsStub(len(latestNIDs) == 0)
return &info, err
}
diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go
index b8136b758..73827522c 100644
--- a/roomserver/storage/sqlite3/state_snapshot_table.go
+++ b/roomserver/storage/sqlite3/state_snapshot_table.go
@@ -140,3 +140,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
}
return results, nil
}
+
+func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
+ ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string,
+) ([]types.EventNID, error) {
+ return nil, tables.OptimisationNotSupportedError
+}
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index 8325fdad5..9f8a1b118 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -20,6 +20,8 @@ import (
"database/sql"
"fmt"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
@@ -27,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/gomatrixserverlib"
)
// A Database is used to store room events and stream offsets.
@@ -54,22 +55,31 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c
// db.SetMaxOpenConns(20)
// Create the tables.
- if err := d.create(db); err != nil {
+ if err = d.create(db); err != nil {
return nil, err
}
- // Then execute the migrations. By this point the tables are created with the latest
- // schemas.
- m := sqlutil.NewMigrations()
- deltas.LoadAddForgottenColumn(m)
- deltas.LoadStateBlocksRefactor(m)
- if err := m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
+ // Special case, since this migration uses several tables, so it needs to
+ // be sure that all tables are created first.
+ // TODO: Remove when we are sure we are not having goose artefacts in the db
+ // This forces an error, which indicates the migration is already applied, since the
+ // column event_nid was removed from the table
+ var eventNID int
+ err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan(&eventNID)
+ if err == nil {
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "roomserver: state blocks refactor",
+ Up: deltas.UpStateBlocksRefactor,
+ })
+ if err = m.Up(base.Context()); err != nil {
+ return nil, err
+ }
}
// Then prepare the statements. Now that the migrations have run, any columns referred
// to in the database code should now exist.
- if err := d.prepare(db, writer, cache); err != nil {
+ if err = d.prepare(db, writer, cache); err != nil {
return nil, err
}
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index 116e11c4e..0bc389b80 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -3,12 +3,16 @@ package tables
import (
"context"
"database/sql"
+ "errors"
- "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
+
+ "github.com/matrix-org/dendrite/roomserver/types"
)
+var OptimisationNotSupportedError = errors.New("optimisation not supported")
+
type EventJSONPair struct {
EventNID types.EventNID
EventJSON []byte
@@ -80,6 +84,10 @@ type Rooms interface {
type StateSnapshot interface {
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
+ // BulkSelectStateForHistoryVisibility is a PostgreSQL-only optimisation for finding
+ // which users are in a room faster than having to load the entire room state. In the
+ // case of SQLite, this will return tables.OptimisationNotSupportedError.
+ BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error)
}
type StateBlock interface {
@@ -133,6 +141,7 @@ type Membership interface {
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
+ DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
}
type Published interface {
@@ -170,7 +179,7 @@ type StrippedEvent struct {
}
// ExtractContentValue from the given state event. For example, given an m.room.name event with:
-// content: { name: "Foo" }
+// content: { name: "Foo" }
// this returns "Foo".
func ExtractContentValue(ev *gomatrixserverlib.HeaderedEvent) string {
content := ev.Content()
diff --git a/roomserver/storage/tables/membership_table_test.go b/roomserver/storage/tables/membership_table_test.go
index 14e8ce50a..f789ef4ac 100644
--- a/roomserver/storage/tables/membership_table_test.go
+++ b/roomserver/storage/tables/membership_table_test.go
@@ -60,6 +60,9 @@ func TestMembershipTable(t *testing.T) {
// This inserts a left user to the room
err = tab.InsertMembership(ctx, nil, 1, stateKeyNID, true)
assert.NoError(t, err)
+ // We must update the membership with a non-zero event NID or it will get filtered out in later queries
+ _, err = tab.UpdateMembership(ctx, nil, 1, stateKeyNID, userNIDs[0], tables.MembershipStateLeaveOrBan, 1, false)
+ assert.NoError(t, err)
}
// ... so this should be false
diff --git a/roomserver/storage/tables/rooms_table_test.go b/roomserver/storage/tables/rooms_table_test.go
index 0a02369a1..eddd012c8 100644
--- a/roomserver/storage/tables/rooms_table_test.go
+++ b/roomserver/storage/tables/rooms_table_test.go
@@ -63,12 +63,12 @@ func TestRoomsTable(t *testing.T) {
roomInfo, err := tab.SelectRoomInfo(ctx, nil, room.ID)
assert.NoError(t, err)
- assert.Equal(t, &types.RoomInfo{
- RoomNID: wantRoomNID,
- RoomVersion: room.Version,
- StateSnapshotNID: 0,
- IsStub: true, // there are no latestEventNIDs
- }, roomInfo)
+ expected := &types.RoomInfo{
+ RoomNID: wantRoomNID,
+ RoomVersion: room.Version,
+ }
+ expected.SetIsStub(true) // there are no latestEventNIDs
+ assert.Equal(t, expected, roomInfo)
roomInfo, err = tab.SelectRoomInfo(ctx, nil, "!doesnotexist:localhost")
assert.NoError(t, err)
@@ -103,12 +103,12 @@ func TestRoomsTable(t *testing.T) {
roomInfo, err = tab.SelectRoomInfo(ctx, nil, room.ID)
assert.NoError(t, err)
- assert.Equal(t, &types.RoomInfo{
- RoomNID: wantRoomNID,
- RoomVersion: room.Version,
- StateSnapshotNID: 1,
- IsStub: false,
- }, roomInfo)
+ expected = &types.RoomInfo{
+ RoomNID: wantRoomNID,
+ RoomVersion: room.Version,
+ }
+ expected.SetStateSnapshotNID(1)
+ assert.Equal(t, expected, roomInfo)
eventNIDs, snapshotNID, err := tab.SelectLatestEventNIDs(ctx, nil, wantRoomNID)
assert.NoError(t, err)
diff --git a/roomserver/storage/tables/state_snapshot_table_test.go b/roomserver/storage/tables/state_snapshot_table_test.go
index dcdb5d8f1..b2e59377d 100644
--- a/roomserver/storage/tables/state_snapshot_table_test.go
+++ b/roomserver/storage/tables/state_snapshot_table_test.go
@@ -23,6 +23,15 @@ func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
+ // for the PostgreSQL history visibility optimisation to work,
+ // we also need some other tables to exist
+ err = postgres.CreateEventStateKeysTable(db)
+ assert.NoError(t, err)
+ err = postgres.CreateEventsTable(db)
+ assert.NoError(t, err)
+ err = postgres.CreateStateBlockTable(db)
+ assert.NoError(t, err)
+ // ... and then the snapshot table itself
err = postgres.CreateStateSnapshotTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareStateSnapshotTable(db)
diff --git a/roomserver/types/types.go b/roomserver/types/types.go
index bc01ca33c..f40980994 100644
--- a/roomserver/types/types.go
+++ b/roomserver/types/types.go
@@ -19,6 +19,7 @@ import (
"encoding/json"
"sort"
"strings"
+ "sync"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -279,8 +280,46 @@ func (e RejectedError) Error() string { return string(e) }
// RoomInfo contains metadata about a room
type RoomInfo struct {
+ mu sync.RWMutex
RoomNID RoomNID
RoomVersion gomatrixserverlib.RoomVersion
- StateSnapshotNID StateSnapshotNID
- IsStub bool
+ stateSnapshotNID StateSnapshotNID
+ isStub bool
+}
+
+func (r *RoomInfo) StateSnapshotNID() StateSnapshotNID {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.stateSnapshotNID
+}
+
+func (r *RoomInfo) IsStub() bool {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.isStub
+}
+
+func (r *RoomInfo) SetStateSnapshotNID(nid StateSnapshotNID) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.stateSnapshotNID = nid
+}
+
+func (r *RoomInfo) SetIsStub(isStub bool) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.isStub = isStub
+}
+
+func (r *RoomInfo) CopyFrom(r2 *RoomInfo) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ r2.mu.RLock()
+ defer r2.mu.RUnlock()
+
+ r.RoomNID = r2.RoomNID
+ r.RoomVersion = r2.RoomVersion
+ r.stateSnapshotNID = r2.stateSnapshotNID
+ r.isStub = r2.isStub
}
diff --git a/roomserver/version/version.go b/roomserver/version/version.go
index 1f66995d8..729d00a80 100644
--- a/roomserver/version/version.go
+++ b/roomserver/version/version.go
@@ -23,7 +23,7 @@ import (
// DefaultRoomVersion contains the room version that will, by
// default, be used to create new rooms on this server.
func DefaultRoomVersion() gomatrixserverlib.RoomVersion {
- return gomatrixserverlib.RoomVersionV6
+ return gomatrixserverlib.RoomVersionV9
}
// RoomVersions returns a map of all known room versions to this
diff --git a/run-sytest.sh b/run-sytest.sh
index 47635fd12..e23982397 100755
--- a/run-sytest.sh
+++ b/run-sytest.sh
@@ -17,7 +17,7 @@ main() {
if [ -d ../sytest ]; then
local tmpdir
- tmpdir="$(mktemp -d --tmpdir run-systest.XXXXXXXXXX)"
+ tmpdir="$(mktemp -d -t run-systest.XXXXXXXXXX)"
trap "rm -r '$tmpdir'" EXIT
if [ -z "$DISABLE_BUILDING_SYTEST" ]; then
diff --git a/setup/base/base.go b/setup/base/base.go
index 93ab87de1..b21eeba47 100644
--- a/setup/base/base.go
+++ b/setup/base/base.go
@@ -369,6 +369,25 @@ func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationCli
return client
}
+func (b *BaseDendrite) configureHTTPErrors() {
+ notAllowedHandler := func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ _, _ = w.Write([]byte(fmt.Sprintf("405 %s not allowed on this endpoint", r.Method)))
+ }
+
+ notFoundCORSHandler := httputil.WrapHandlerInCORS(http.NotFoundHandler())
+ notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler))
+
+ for _, router := range []*mux.Router{
+ b.PublicClientAPIMux, b.PublicMediaAPIMux,
+ b.DendriteAdminMux, b.SynapseAdminMux,
+ b.PublicWellKnownAPIMux,
+ } {
+ router.NotFoundHandler = notFoundCORSHandler
+ router.MethodNotAllowedHandler = notAllowedCORSHandler
+ }
+}
+
// SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on
// ApiMux under /api/ and adds a prometheus handler under /metrics.
func (b *BaseDendrite) SetupAndServeHTTP(
@@ -409,6 +428,8 @@ func (b *BaseDendrite) SetupAndServeHTTP(
}
}
+ b.configureHTTPErrors()
+
internalRouter.PathPrefix(httputil.InternalPathPrefix).Handler(b.InternalAPIMux)
if b.Cfg.Global.Metrics.Enabled {
internalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth))
diff --git a/setup/config/config.go b/setup/config/config.go
index 9b9000a62..924b51f22 100644
--- a/setup/config/config.go
+++ b/setup/config/config.go
@@ -19,8 +19,8 @@ import (
"encoding/pem"
"fmt"
"io"
- "io/ioutil"
"net/url"
+ "os"
"path/filepath"
"regexp"
"strings"
@@ -191,7 +191,7 @@ type ConfigErrors []string
// Load a yaml config file for a server run as multiple processes or as a monolith.
// Checks the config to ensure that it is valid.
func Load(configPath string, monolith bool) (*Dendrite, error) {
- configData, err := ioutil.ReadFile(configPath)
+ configData, err := os.ReadFile(configPath)
if err != nil {
return nil, err
}
@@ -199,9 +199,9 @@ func Load(configPath string, monolith bool) (*Dendrite, error) {
if err != nil {
return nil, err
}
- // Pass the current working directory and ioutil.ReadFile so that they can
+ // Pass the current working directory and os.ReadFile so that they can
// be mocked in the tests
- return loadConfig(basePath, configData, ioutil.ReadFile, monolith)
+ return loadConfig(basePath, configData, os.ReadFile, monolith)
}
func loadConfig(
@@ -530,7 +530,7 @@ func (config *Dendrite) KeyServerURL() string {
// SetupTracing configures the opentracing using the supplied configuration.
func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err error) {
if !config.Tracing.Enabled {
- return ioutil.NopCloser(bytes.NewReader([]byte{})), nil
+ return io.NopCloser(bytes.NewReader([]byte{})), nil
}
return config.Tracing.Jaeger.InitGlobalTracer(
serviceName,
diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go
index 9b89fc9af..b8f99a612 100644
--- a/setup/config/config_appservice.go
+++ b/setup/config/config_appservice.go
@@ -16,7 +16,7 @@ package config
import (
"fmt"
- "io/ioutil"
+ "os"
"path/filepath"
"regexp"
"strings"
@@ -181,7 +181,7 @@ func loadAppServices(config *AppServiceAPI, derived *Derived) error {
}
// Read the application service's config file
- configData, err := ioutil.ReadFile(absPath)
+ configData, err := os.ReadFile(absPath)
if err != nil {
return err
}
diff --git a/setup/config/config_global.go b/setup/config/config_global.go
index ac1380a4e..d4e54e203 100644
--- a/setup/config/config_global.go
+++ b/setup/config/config_global.go
@@ -46,6 +46,9 @@ type Global struct {
// The server name to delegate server-server communications to, with optional port
WellKnownServerName string `yaml:"well_known_server_name"`
+ // The server name to delegate client-server communications to, with optional port
+ WellKnownClientName string `yaml:"well_known_client_name"`
+
// Disables federation. Dendrite will not be able to make any outbound HTTP requests
// to other servers and the federation API will not be exposed.
DisableFederation bool `yaml:"disable_federation"`
@@ -73,7 +76,7 @@ type Global struct {
// ServerNotices configuration used for sending server notices
ServerNotices ServerNotices `yaml:"server_notices"`
- // ReportStats configures opt-in anonymous stats reporting.
+ // ReportStats configures opt-in phone-home statistics reporting.
ReportStats ReportStats `yaml:"report_stats"`
// Configuration for the caches.
@@ -189,9 +192,9 @@ func (c *Cache) Verify(errors *ConfigErrors, isMonolith bool) {
checkPositive(errors, "max_size_estimated", int64(c.EstimatedMaxSize))
}
-// ReportStats configures opt-in anonymous stats reporting.
+// ReportStats configures opt-in phone-home statistics reporting.
type ReportStats struct {
- // Enabled configures anonymous usage stats of the server
+ // Enabled configures phone-home statistics of the server
Enabled bool `yaml:"enabled"`
// Endpoint the endpoint to report stats to
diff --git a/setup/config/config_jetstream.go b/setup/config/config_jetstream.go
index e4cfd4d3b..a7827597e 100644
--- a/setup/config/config_jetstream.go
+++ b/setup/config/config_jetstream.go
@@ -17,6 +17,10 @@ type JetStream struct {
TopicPrefix string `yaml:"topic_prefix"`
// Keep all storage in memory. This is mostly useful for unit tests.
InMemory bool `yaml:"in_memory"`
+ // Disable logging. This is mostly useful for unit tests.
+ NoLog bool `yaml:"-"`
+ // Disables TLS validation. This should NOT be used in production
+ DisableTLSValidation bool `yaml:"disable_tls_validation"`
}
func (c *JetStream) Prefixed(name string) string {
@@ -32,6 +36,8 @@ func (c *JetStream) Defaults(generate bool) {
c.TopicPrefix = "Dendrite"
if generate {
c.StoragePath = Path("./")
+ c.NoLog = true
+ c.DisableTLSValidation = true
}
}
diff --git a/setup/config/config_test.go b/setup/config/config_test.go
index b9b1e7bb5..ee7e7389c 100644
--- a/setup/config/config_test.go
+++ b/setup/config/config_test.go
@@ -42,6 +42,7 @@ global:
key_id: ed25519:auto
key_validity_period: 168h0m0s
well_known_server_name: "localhost:443"
+ well_known_client_name: "localhost:443"
trusted_third_party_id_servers:
- matrix.org
- vector.im
diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go
index 248b0e656..051d55a35 100644
--- a/setup/jetstream/nats.go
+++ b/setup/jetstream/nats.go
@@ -1,6 +1,7 @@
package jetstream
import (
+ "crypto/tls"
"fmt"
"reflect"
"strings"
@@ -13,16 +14,16 @@ import (
"github.com/sirupsen/logrus"
natsserver "github.com/nats-io/nats-server/v2/server"
- "github.com/nats-io/nats.go"
natsclient "github.com/nats-io/nats.go"
)
type NATSInstance struct {
*natsserver.Server
- sync.Mutex
}
-func DeleteAllStreams(js nats.JetStreamContext, cfg *config.JetStream) {
+var natsLock sync.Mutex
+
+func DeleteAllStreams(js natsclient.JetStreamContext, cfg *config.JetStream) {
for _, stream := range streams { // streams are defined in streams.go
name := cfg.Prefixed(stream.Name)
_ = js.DeleteStream(name)
@@ -30,11 +31,12 @@ func DeleteAllStreams(js nats.JetStreamContext, cfg *config.JetStream) {
}
func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) {
+ natsLock.Lock()
+ defer natsLock.Unlock()
// check if we need an in-process NATS Server
if len(cfg.Addresses) != 0 {
return setupNATS(process, cfg, nil)
}
- s.Lock()
if s.Server == nil {
var err error
s.Server, err = natsserver.NewServer(&natsserver.Options{
@@ -45,6 +47,7 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS
NoSystemAccount: true,
MaxPayload: 16 * 1024 * 1024,
NoSigs: true,
+ NoLog: cfg.NoLog,
})
if err != nil {
panic(err)
@@ -61,7 +64,6 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS
process.ComponentFinished()
}()
}
- s.Unlock()
if !s.ReadyForConnections(time.Second * 10) {
logrus.Fatalln("NATS did not start in time")
}
@@ -75,7 +77,13 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS
func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsclient.Conn) (natsclient.JetStreamContext, *natsclient.Conn) {
if nc == nil {
var err error
- nc, err = natsclient.Connect(strings.Join(cfg.Addresses, ","))
+ opts := []natsclient.Option{}
+ if cfg.DisableTLSValidation {
+ opts = append(opts, natsclient.Secure(&tls.Config{
+ InsecureSkipVerify: true,
+ }))
+ }
+ nc, err = natsclient.Connect(strings.Join(cfg.Addresses, ","), opts...)
if err != nil {
logrus.WithError(err).Panic("Unable to connect to NATS")
return nil, nil
diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go
index 9044823af..3e9d90a1f 100644
--- a/setup/mscs/msc2836/msc2836_test.go
+++ b/setup/mscs/msc2836/msc2836_test.go
@@ -7,7 +7,7 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
- "io/ioutil"
+ "io"
"net/http"
"sort"
"strings"
@@ -15,6 +15,8 @@ import (
"time"
"github.com/gorilla/mux"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/internal/hooks"
"github.com/matrix-org/dendrite/internal/httputil"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
@@ -22,7 +24,6 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/mscs/msc2836"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrixserverlib"
)
var (
@@ -32,15 +33,17 @@ var (
)
// Basic sanity check of MSC2836 logic. Injects a thread that looks like:
-// A
-// |
-// B
-// / \
-// C D
-// /|\
-// E F G
-// |
-// H
+//
+// A
+// |
+// B
+// / \
+// C D
+// /|\
+// E F G
+// |
+// H
+//
// And makes sure POST /event_relationships works with various parameters
func TestMSC2836(t *testing.T) {
alice := "@alice:localhost"
@@ -161,9 +164,9 @@ func TestMSC2836(t *testing.T) {
// make everyone joined to each other's rooms
nopRsAPI := &testRoomserverAPI{
userToJoinedRooms: map[string][]string{
- alice: []string{roomID},
- bob: []string{roomID},
- charlie: []string{roomID},
+ alice: {roomID},
+ bob: {roomID},
+ charlie: {roomID},
},
events: map[string]*gomatrixserverlib.HeaderedEvent{
eventA.EventID(): eventA,
@@ -425,12 +428,12 @@ func postRelationships(t *testing.T, expectCode int, accessToken string, req *ms
t.Fatalf("failed to do request: %s", err)
}
if res.StatusCode != expectCode {
- body, _ := ioutil.ReadAll(res.Body)
+ body, _ := io.ReadAll(res.Body)
t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body))
}
if res.StatusCode == 200 {
var result msc2836.EventRelationshipResponse
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("response 200 OK but failed to read response body: %s", err)
}
diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go
index 61520d50e..4cffa82ad 100644
--- a/setup/mscs/msc2946/msc2946.go
+++ b/setup/mscs/msc2946/msc2946.go
@@ -708,7 +708,6 @@ func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEve
StateKey: *ev.StateKey(),
Content: ev.Content(),
Sender: ev.Sender(),
- RoomID: ev.RoomID(),
OriginServerTS: ev.OriginServerTS(),
}
}
diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go
index eec369c1a..02633b567 100644
--- a/syncapi/consumers/clientapi.go
+++ b/syncapi/consumers/clientapi.go
@@ -21,6 +21,11 @@ import (
"fmt"
"github.com/getsentry/sentry-go"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ "github.com/sirupsen/logrus"
+ log "github.com/sirupsen/logrus"
+
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
@@ -29,10 +34,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/nats-io/nats.go"
- "github.com/sirupsen/logrus"
- log "github.com/sirupsen/logrus"
)
// OutputClientDataConsumer consumes events that originated in the client API server.
@@ -107,7 +108,8 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg)
"type": output.Type,
"room_id": output.RoomID,
log.ErrorKey: err,
- }).Panicf("could not save account data")
+ }).Errorf("could not save account data")
+ return false
}
if err = s.sendReadUpdate(ctx, userID, output); err != nil {
diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go
index f0ca2106f..f77b1673b 100644
--- a/syncapi/consumers/roomserver.go
+++ b/syncapi/consumers/roomserver.go
@@ -240,6 +240,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
msg.RemovesStateEventIDs,
msg.TransactionID,
false,
+ msg.HistoryVisibility,
)
if err != nil {
// panic rather than continue with an inconsistent database
@@ -289,7 +290,8 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
[]string{}, // adds no state
[]string{}, // removes no state
nil, // no transaction
- ev.StateKey() != nil, // exclude from sync?
+ ev.StateKey() != nil, // exclude from sync?,
+ msg.HistoryVisibility,
)
if err != nil {
// panic rather than continue with an inconsistent database
@@ -363,7 +365,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
"event": string(msg.Event.JSON()),
"pdupos": pduPos,
log.ErrorKey: err,
- }).Panicf("roomserver output log: write invite failure")
+ }).Errorf("roomserver output log: write invite failure")
return
}
@@ -383,7 +385,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
log.WithFields(log.Fields{
"event_id": msg.EventID,
log.ErrorKey: err,
- }).Panicf("roomserver output log: remove invite failure")
+ }).Errorf("roomserver output log: remove invite failure")
return
}
@@ -401,7 +403,7 @@ func (s *OutputRoomEventConsumer) onNewPeek(
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
log.ErrorKey: err,
- }).Panicf("roomserver output log: write peek failure")
+ }).Errorf("roomserver output log: write peek failure")
return
}
@@ -420,7 +422,7 @@ func (s *OutputRoomEventConsumer) onRetirePeek(
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
log.ErrorKey: err,
- }).Panicf("roomserver output log: write peek failure")
+ }).Errorf("roomserver output log: write peek failure")
return
}
diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go
new file mode 100644
index 000000000..e73c004e5
--- /dev/null
+++ b/syncapi/internal/history_visibility.go
@@ -0,0 +1,217 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "math"
+ "time"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/tidwall/gjson"
+)
+
+func init() {
+ prometheus.MustRegister(calculateHistoryVisibilityDuration)
+}
+
+// calculateHistoryVisibilityDuration stores the time it takes to
+// calculate the history visibility. In polylith mode the roundtrip
+// to the roomserver is included in this time.
+var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Namespace: "dendrite",
+ Subsystem: "syncapi",
+ Name: "calculateHistoryVisibility_duration_millis",
+ Help: "How long it takes to calculate the history visibility",
+ Buckets: []float64{ // milliseconds
+ 5, 10, 25, 50, 75, 100, 250, 500,
+ 1000, 2000, 3000, 4000, 5000, 6000,
+ 7000, 8000, 9000, 10000, 15000, 20000,
+ },
+ },
+ []string{"api"},
+)
+
+var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{
+ gomatrixserverlib.WorldReadable: 0,
+ gomatrixserverlib.HistoryVisibilityShared: 1,
+ gomatrixserverlib.HistoryVisibilityInvited: 2,
+ gomatrixserverlib.HistoryVisibilityJoined: 3,
+}
+
+// eventVisibility contains the history visibility and membership state at a given event
+type eventVisibility struct {
+ visibility gomatrixserverlib.HistoryVisibility
+ membershipAtEvent string
+ membershipCurrent string
+}
+
+// allowed checks the eventVisibility if the user is allowed to see the event.
+// Rules as defined by https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5
+func (ev eventVisibility) allowed() (allowed bool) {
+ switch ev.visibility {
+ case gomatrixserverlib.HistoryVisibilityWorldReadable:
+ // If the history_visibility was set to world_readable, allow.
+ return true
+ case gomatrixserverlib.HistoryVisibilityJoined:
+ // If the user’s membership was join, allow.
+ if ev.membershipAtEvent == gomatrixserverlib.Join {
+ return true
+ }
+ return false
+ case gomatrixserverlib.HistoryVisibilityShared:
+ // If the user’s membership was join, allow.
+ // If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow.
+ if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join {
+ return true
+ }
+ return false
+ case gomatrixserverlib.HistoryVisibilityInvited:
+ // If the user’s membership was join, allow.
+ if ev.membershipAtEvent == gomatrixserverlib.Join {
+ return true
+ }
+ if ev.membershipAtEvent == gomatrixserverlib.Invite {
+ return true
+ }
+ return false
+ default:
+ return false
+ }
+}
+
+// ApplyHistoryVisibilityFilter applies the room history visibility filter on gomatrixserverlib.HeaderedEvents.
+// Returns the filtered events and an error, if any.
+func ApplyHistoryVisibilityFilter(
+ ctx context.Context,
+ syncDB storage.Database,
+ rsAPI api.SyncRoomserverAPI,
+ events []*gomatrixserverlib.HeaderedEvent,
+ alwaysIncludeEventIDs map[string]struct{},
+ userID, endpoint string,
+) ([]*gomatrixserverlib.HeaderedEvent, error) {
+ if len(events) == 0 {
+ return events, nil
+ }
+ start := time.Now()
+
+ // try to get the current membership of the user
+ membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64)
+ if err != nil {
+ return nil, err
+ }
+
+ // Get the mapping from eventID -> eventVisibility
+ eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events))
+ visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID())
+ if err != nil {
+ return eventsFiltered, err
+ }
+ for _, ev := range events {
+ evVis := visibilities[ev.EventID()]
+ evVis.membershipCurrent = membershipCurrent
+ // Always include specific state events for /sync responses
+ if alwaysIncludeEventIDs != nil {
+ if _, ok := alwaysIncludeEventIDs[ev.EventID()]; ok {
+ eventsFiltered = append(eventsFiltered, ev)
+ continue
+ }
+ }
+ // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
+ if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) {
+ eventsFiltered = append(eventsFiltered, ev)
+ continue
+ }
+ // Always allow history evVis events on boundaries. This is done
+ // by setting the effective evVis to the least restrictive
+ // of the old vs new.
+ // https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5
+ if hisVis, err := ev.HistoryVisibility(); err == nil {
+ prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String()
+ oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)]
+ // if we can't get the previous history visibility, default to shared.
+ if !ok {
+ oldPrio = historyVisibilityPriority[gomatrixserverlib.HistoryVisibilityShared]
+ }
+ // no OK check, since this should have been validated when setting the value
+ newPrio := historyVisibilityPriority[hisVis]
+ if oldPrio < newPrio {
+ evVis.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis)
+ }
+ }
+ // do the actual check
+ allowed := evVis.allowed()
+ if allowed {
+ eventsFiltered = append(eventsFiltered, ev)
+ }
+ }
+ calculateHistoryVisibilityDuration.With(prometheus.Labels{"api": endpoint}).Observe(float64(time.Since(start).Milliseconds()))
+ return eventsFiltered, nil
+}
+
+// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership
+// of `userID` at the given event.
+// Returns an error if the roomserver can't calculate the memberships.
+func visibilityForEvents(
+ ctx context.Context,
+ rsAPI api.SyncRoomserverAPI,
+ events []*gomatrixserverlib.HeaderedEvent,
+ userID, roomID string,
+) (map[string]eventVisibility, error) {
+ eventIDs := make([]string, len(events))
+ for i := range events {
+ eventIDs[i] = events[i].EventID()
+ }
+
+ result := make(map[string]eventVisibility, len(eventIDs))
+
+ // get the membership events for all eventIDs
+ membershipResp := &api.QueryMembershipAtEventResponse{}
+ err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{
+ RoomID: roomID,
+ EventIDs: eventIDs,
+ UserID: userID,
+ }, membershipResp)
+ if err != nil {
+ return result, err
+ }
+
+ // Create a map from eventID -> eventVisibility
+ for _, event := range events {
+ eventID := event.EventID()
+ vis := eventVisibility{
+ membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident
+ visibility: event.Visibility,
+ }
+ membershipEvs, ok := membershipResp.Memberships[eventID]
+ if !ok {
+ result[eventID] = vis
+ continue
+ }
+ for _, ev := range membershipEvs {
+ membership, err := ev.Membership()
+ if err != nil {
+ return result, err
+ }
+ vis.membershipAtEvent = membership
+ }
+ result[eventID] = vis
+ }
+ return result, nil
+}
diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go
index d96718d20..23824e366 100644
--- a/syncapi/internal/keychange.go
+++ b/syncapi/internal/keychange.go
@@ -21,17 +21,17 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
keytypes "github.com/matrix-org/dendrite/keyserver/types"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
)
-const DeviceListLogName = "dl"
-
// DeviceOTKCounts adds one-time key counts to the /sync response
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
var queryRes keyapi.QueryOneTimeKeysResponse
- keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{
+ _ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{
UserID: userID,
DeviceID: deviceID,
}, &queryRes)
@@ -46,7 +46,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.
func DeviceListCatchup(
- ctx context.Context, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
+ ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
userID string, res *types.Response, from, to types.StreamPosition,
) (newPos types.StreamPosition, hasNew bool, err error) {
@@ -73,7 +73,7 @@ func DeviceListCatchup(
offset = int64(from)
}
var queryRes keyapi.QueryKeyChangesResponse
- keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
+ _ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
Offset: offset,
ToOffset: toOffset,
}, &queryRes)
@@ -92,18 +92,13 @@ func DeviceListCatchup(
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
- var sharedUsersMap map[string]int
- sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
- util.GetLogger(ctx).Debugf(
- "QueryKeyChanges request off=%d,to=%d response off=%d uids=%v",
- offset, toOffset, queryRes.Offset, queryRes.UserIDs,
- )
+ sharedUsersMap := filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
userSet := make(map[string]bool)
for _, userID := range res.DeviceLists.Changed {
userSet[userID] = true
}
- for _, userID := range queryRes.UserIDs {
- if !userSet[userID] {
+ for userID, count := range sharedUsersMap {
+ if !userSet[userID] && count > 0 {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true
userSet[userID] = true
@@ -112,7 +107,7 @@ func DeviceListCatchup(
// Finally, add in users who have joined or left.
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
for _, userID := range joinUserIDs {
- if !userSet[userID] {
+ if !userSet[userID] && sharedUsersMap[userID] > 0 {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true
userSet[userID] = true
@@ -125,6 +120,13 @@ func DeviceListCatchup(
}
}
+ util.GetLogger(ctx).WithFields(logrus.Fields{
+ "user_id": userID,
+ "from": offset,
+ "to": toOffset,
+ "response_offset": queryRes.Offset,
+ }).Debugf("QueryKeyChanges request result: %+v", res.DeviceLists)
+
return types.StreamPosition(queryRes.Offset), hasNew, nil
}
@@ -215,30 +217,31 @@ func TrackChangedUsers(
return changed, left, nil
}
+// filterSharedUsers takes a list of remote users whose keys have changed and filters
+// it down to include only users who the requesting user shares a room with.
func filterSharedUsers(
- ctx context.Context, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, usersWithChangedKeys []string,
-) (map[string]int, []string) {
- var result []string
- var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
- err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
- UserID: userID,
- OtherUserIDs: usersWithChangedKeys,
- }, &sharedUsersRes)
- if err != nil {
- // default to all users so we do needless queries rather than miss some important device update
- return nil, usersWithChangedKeys
- }
- // We forcibly put ourselves in this list because we should be notified about our own device updates
- // and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
- // be notified about key changes.
- sharedUsersRes.UserIDsToCount[userID] = 1
-
- for _, uid := range usersWithChangedKeys {
- if sharedUsersRes.UserIDsToCount[uid] > 0 {
- result = append(result, uid)
+ ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string,
+) map[string]int {
+ sharedUsersMap := make(map[string]int, len(usersWithChangedKeys))
+ for _, changedUserID := range usersWithChangedKeys {
+ sharedUsersMap[changedUserID] = 0
+ if changedUserID == userID {
+ // We forcibly put ourselves in this list because we should be notified about our own device updates
+ // and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
+ // be notified about key changes.
+ sharedUsersMap[userID] = 1
}
}
- return sharedUsersRes.UserIDsToCount, result
+ sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Errorf("db.SharedUsers failed: %s", err)
+ // default to all users so we do needless queries rather than miss some important device update
+ return sharedUsersMap
+ }
+ for _, userID := range sharedUsers {
+ sharedUsersMap[userID]++
+ }
+ return sharedUsersMap
}
func joinedRooms(res *types.Response, userID string) []string {
diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go
index 219b35e2c..80d2811be 100644
--- a/syncapi/internal/keychange_test.go
+++ b/syncapi/internal/keychange_test.go
@@ -6,11 +6,13 @@ import (
"sort"
"testing"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrixserverlib"
)
var (
@@ -20,31 +22,41 @@ var (
type mockKeyAPI struct{}
-func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) {
+func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error {
+ return nil
}
func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {}
// PerformClaimKeys claims one-time keys for use in pre-key messages
-func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) {
+func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error {
+ return nil
}
-func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) {
+func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error {
+ return nil
}
-func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) {
+func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error {
+ return nil
}
-func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) {
+func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error {
+ return nil
}
-func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
+func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error {
+ return nil
}
-func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
+func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
+ return nil
}
-func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) {
+func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
+ return nil
}
-func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) {
+func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error {
+ return nil
}
-func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) {
+func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error {
+ return nil
}
type mockRoomserverAPI struct {
@@ -105,6 +117,22 @@ func (s *mockRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.Query
return nil
}
+// This is actually a database function, but seeing as we track the state inside the
+// *mockRoomserverAPI, we'll just comply with the interface here instead.
+func (s *mockRoomserverAPI) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
+ commonUsers := []string{}
+ for _, members := range s.roomIDToJoinedMembers {
+ for _, member := range members {
+ for _, userID := range otherUserIDs {
+ if member == userID {
+ commonUsers = append(commonUsers, userID)
+ }
+ }
+ }
+ }
+ return util.UniqueStrings(commonUsers), nil
+}
+
type wantCatchup struct {
hasNew bool
changed []string
@@ -112,6 +140,7 @@ type wantCatchup struct {
}
func assertCatchup(t *testing.T, hasNew bool, syncResponse *types.Response, want wantCatchup) {
+ t.Helper()
if hasNew != want.hasNew {
t.Errorf("got hasNew=%v want %v", hasNew, want.hasNew)
}
@@ -178,7 +207,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
"!another:room": {syncingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
@@ -201,7 +230,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
"!another:room": {syncingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
@@ -224,7 +253,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
}
@@ -246,7 +275,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
@@ -305,7 +334,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
roomID: {syncingUser, existingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
@@ -333,7 +362,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
"!another:room": {syncingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
}
@@ -346,13 +375,14 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
// tests that joining/leaving the SAME room puts users in `left` if the final state is leave.
// NB: Consider the case:
-// - Alice and Bob are in a room.
-// - Alice goes offline, Charlie joins, sends encrypted messages then leaves the room.
-// - Alice comes back online. Technically nothing has changed in the set of users between those two points in time,
-// it's still just (Alice,Bob) but then we won't be tracking Charlie -- is this okay though? It's device keys
-// which are only relevant when actively sending events I think? And if Alice does need the keys she knows
-// charlie's (user_id, device_id) so can just hit /keys/query - no need to keep updated about it because she
-// doesn't share any rooms with him.
+// - Alice and Bob are in a room.
+// - Alice goes offline, Charlie joins, sends encrypted messages then leaves the room.
+// - Alice comes back online. Technically nothing has changed in the set of users between those two points in time,
+// it's still just (Alice,Bob) but then we won't be tracking Charlie -- is this okay though? It's device keys
+// which are only relevant when actively sending events I think? And if Alice does need the keys she knows
+// charlie's (user_id, device_id) so can just hit /keys/query - no need to keep updated about it because she
+// doesn't share any rooms with him.
+//
// Ergo, we put them in `left` as it is simpler.
func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
newShareUser := "@berta:localhost"
@@ -419,7 +449,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
},
}
_, hasNew, err := DeviceListCatchup(
- context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
+ context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go
index f6b4d15e0..13c4e9d89 100644
--- a/syncapi/routing/context.go
+++ b/syncapi/routing/context.go
@@ -21,10 +21,12 @@ import (
"fmt"
"net/http"
"strconv"
+ "time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
@@ -95,24 +97,6 @@ func Context(
ContainsURL: filter.ContainsURL,
}
- // TODO: Get the actual state at the last event returned by SelectContextAfterEvent
- state, _ := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
- // verify the user is allowed to see the context for this room/event
- for _, x := range state {
- var hisVis gomatrixserverlib.HistoryVisibility
- hisVis, err = x.HistoryVisibility()
- if err != nil {
- continue
- }
- allowed := hisVis == gomatrixserverlib.WorldReadable || membershipRes.Membership == gomatrixserverlib.Join
- if !allowed {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: jsonerror.Forbidden("User is not allowed to query context"),
- }
- }
- }
-
id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
if err != nil {
if err == sql.ErrNoRows {
@@ -125,6 +109,24 @@ func Context(
return jsonerror.InternalServerError()
}
+ // verify the user is allowed to see the context for this room/event
+ startTime := time.Now()
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
+ if err != nil {
+ logrus.WithError(err).Error("unable to apply history visibility filter")
+ return jsonerror.InternalServerError()
+ }
+ logrus.WithFields(logrus.Fields{
+ "duration": time.Since(startTime),
+ "room_id": roomID,
+ }).Debug("applied history visibility (context)")
+ if len(filteredEvents) == 0 {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: jsonerror.Forbidden("User is not allowed to query context"),
+ }
+ }
+
eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch before events")
@@ -137,8 +139,27 @@ func Context(
return jsonerror.InternalServerError()
}
- eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatAll)
- eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatAll)
+ startTime = time.Now()
+ eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID)
+ if err != nil {
+ logrus.WithError(err).Error("unable to apply history visibility filter")
+ return jsonerror.InternalServerError()
+ }
+
+ logrus.WithFields(logrus.Fields{
+ "duration": time.Since(startTime),
+ "room_id": roomID,
+ }).Debug("applied history visibility (context eventsBefore/eventsAfter)")
+
+ // TODO: Get the actual state at the last event returned by SelectContextAfterEvent
+ state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
+ if err != nil {
+ logrus.WithError(err).Error("unable to fetch current room state")
+ return jsonerror.InternalServerError()
+ }
+
+ eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll)
+ eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll)
newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache)
response := ContextRespsonse{
@@ -162,6 +183,44 @@ func Context(
}
}
+// applyHistoryVisibilityOnContextEvents is a helper function to avoid roundtrips to the roomserver
+// by combining the events before and after the context event. Returns the filtered events,
+// and an error, if any.
+func applyHistoryVisibilityOnContextEvents(
+ ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI,
+ eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
+ userID string,
+) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
+ eventIDsBefore := make(map[string]struct{}, len(eventsBefore))
+ eventIDsAfter := make(map[string]struct{}, len(eventsAfter))
+
+ // Remember before/after eventIDs, so we can restore them
+ // after applying history visibility checks
+ for _, ev := range eventsBefore {
+ eventIDsBefore[ev.EventID()] = struct{}{}
+ }
+ for _, ev := range eventsAfter {
+ eventIDsAfter[ev.EventID()] = struct{}{}
+ }
+
+ allEvents := append(eventsBefore, eventsAfter...)
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // "Restore" events in the correct context
+ for _, ev := range filteredEvents {
+ if _, ok := eventIDsBefore[ev.EventID()]; ok {
+ filteredBefore = append(filteredBefore, ev)
+ }
+ if _, ok := eventIDsAfter[ev.EventID()]; ok {
+ filteredAfter = append(filteredAfter, ev)
+ }
+ }
+ return filteredBefore, filteredAfter, nil
+}
+
func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
if len(startEvents) > 0 {
start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID())
diff --git a/syncapi/routing/filter.go b/syncapi/routing/filter.go
index 1a10bd649..f5acdbde3 100644
--- a/syncapi/routing/filter.go
+++ b/syncapi/routing/filter.go
@@ -16,16 +16,17 @@ package routing
import (
"encoding/json"
- "io/ioutil"
+ "io"
"net/http"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/tidwall/gjson"
+
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
- "github.com/tidwall/gjson"
)
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
@@ -65,7 +66,9 @@ type filterResponse struct {
FilterID string `json:"filter_id"`
}
-//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
+// PutFilter implements
+//
+// POST /_matrix/client/r0/user/{userId}/filter
func PutFilter(
req *http.Request, device *api.Device, syncDB storage.Database, userID string,
) util.JSONResponse {
@@ -85,7 +88,7 @@ func PutFilter(
var filter gomatrixserverlib.Filter
defer req.Body.Close() // nolint:errcheck
- body, err := ioutil.ReadAll(req.Body)
+ body, err := io.ReadAll(req.Body)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go
index 24745cd55..9db3d8e17 100644
--- a/syncapi/routing/messages.go
+++ b/syncapi/routing/messages.go
@@ -19,18 +19,21 @@ import (
"fmt"
"net/http"
"sort"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
- "github.com/sirupsen/logrus"
)
type messagesReq struct {
@@ -262,7 +265,7 @@ func (m *messagesResp) applyLazyLoadMembers(
}
}
for _, evt := range membershipToUser {
- m.State = append(m.State, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatSync))
+ m.State = append(m.State, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll))
}
}
@@ -323,6 +326,9 @@ func (r *messagesReq) retrieveEvents() (
// reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database.
start, end, err = r.getStartEnd(events)
+ if err != nil {
+ return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, err
+ }
// Sort the events to ensure we send them in the right order.
if r.backwardOrdering {
@@ -336,97 +342,18 @@ func (r *messagesReq) retrieveEvents() (
}
events = reversed(events)
}
- events = r.filterHistoryVisible(events)
if len(events) == 0 {
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil
}
- // Convert all of the events into client events.
- clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll)
- return clientEvents, start, end, err
-}
-
-func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
- // TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
- // user shouldn't see, we check the recent events and remove any prior to the join event of the user
- // which is equiv to history_visibility: joined
- joinEventIndex := -1
- for i, ev := range events {
- if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(r.device.UserID) {
- membership, _ := ev.Membership()
- if membership == "join" {
- joinEventIndex = i
- break
- }
- }
- }
-
- var result []*gomatrixserverlib.HeaderedEvent
- var eventsToCheck []*gomatrixserverlib.HeaderedEvent
- if joinEventIndex != -1 {
- if r.backwardOrdering {
- result = events[:joinEventIndex+1]
- eventsToCheck = append(eventsToCheck, result[0])
- } else {
- result = events[joinEventIndex:]
- eventsToCheck = append(eventsToCheck, result[len(result)-1])
- }
- } else {
- eventsToCheck = []*gomatrixserverlib.HeaderedEvent{events[0], events[len(events)-1]}
- result = events
- }
- // make sure the user was in the room for both the earliest and latest events, we need this because
- // some backpagination results will not have the join event (e.g if they hit /messages at the join event itself)
- wasJoined := true
- for _, ev := range eventsToCheck {
- var queryRes api.QueryStateAfterEventsResponse
- err := r.rsAPI.QueryStateAfterEvents(r.ctx, &api.QueryStateAfterEventsRequest{
- RoomID: ev.RoomID(),
- PrevEventIDs: ev.PrevEventIDs(),
- StateToFetch: []gomatrixserverlib.StateKeyTuple{
- {EventType: gomatrixserverlib.MRoomMember, StateKey: r.device.UserID},
- {EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""},
- },
- }, &queryRes)
- if err != nil {
- wasJoined = false
- break
- }
- var hisVisEvent, membershipEvent *gomatrixserverlib.HeaderedEvent
- for i := range queryRes.StateEvents {
- switch queryRes.StateEvents[i].Type() {
- case gomatrixserverlib.MRoomMember:
- membershipEvent = queryRes.StateEvents[i]
- case gomatrixserverlib.MRoomHistoryVisibility:
- hisVisEvent = queryRes.StateEvents[i]
- }
- }
- if hisVisEvent == nil {
- return events // apply no filtering as it defaults to Shared.
- }
- hisVis, _ := hisVisEvent.HistoryVisibility()
- if hisVis == "shared" || hisVis == "world_readable" {
- return events // apply no filtering
- }
- if membershipEvent == nil {
- wasJoined = false
- break
- }
- membership, err := membershipEvent.Membership()
- if err != nil {
- wasJoined = false
- break
- }
- if membership != "join" {
- wasJoined = false
- break
- }
- }
- if !wasJoined {
- util.GetLogger(r.ctx).WithField("num_events", len(events)).Warnf("%s was not joined to room during these events, omitting them", r.device.UserID)
- return []*gomatrixserverlib.HeaderedEvent{}
- }
- return result
+ // Apply room history visibility filter
+ startTime := time.Now()
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages")
+ logrus.WithFields(logrus.Fields{
+ "duration": time.Since(startTime),
+ "room_id": r.roomID,
+ }).Debug("applied history visibility (messages)")
+ return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err
}
func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
@@ -594,6 +521,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
[]string{},
[]string{},
nil, true,
+ gomatrixserverlib.HistoryVisibilityShared,
)
if err != nil {
return nil, err
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index 5a036d889..43a75da95 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -27,6 +27,8 @@ import (
type Database interface {
Presence
+ SharedUsers
+
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
@@ -67,7 +69,9 @@ type Database interface {
// when generating the sync stream position for this event. Returns the sync stream position for the inserted event.
// Returns an error if there was a problem inserting this event.
WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent,
- addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error)
+ addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool,
+ historyVisibility gomatrixserverlib.HistoryVisibility,
+ ) (types.StreamPosition, error)
// PurgeRoomState completely purges room state from the sync API. This is done when
// receiving an output event that completely resets the state.
PurgeRoomState(ctx context.Context, roomID string) error
@@ -157,6 +161,10 @@ type Database interface {
IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
+ // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
+ // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
+ // string as the membership.
+ SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
}
type Presence interface {
@@ -165,3 +173,8 @@ type Presence interface {
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
}
+
+type SharedUsers interface {
+ // SharedUsers returns a subset of otherUserIDs that share a room with userID.
+ SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
+}
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go
index 8ee387b39..58f404511 100644
--- a/syncapi/storage/postgres/current_room_state_table.go
+++ b/syncapi/storage/postgres/current_room_state_table.go
@@ -23,6 +23,7 @@ import (
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -51,6 +52,7 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
-- The serial ID of the output_room_events table when this event became
-- part of the current state of the room.
added_at BIGINT,
+ history_visibility SMALLINT NOT NULL DEFAULT 2,
-- Clobber based on 3-uple of room_id, type and state_key
CONSTRAINT syncapi_room_state_unique UNIQUE (room_id, type, state_key)
);
@@ -63,8 +65,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON sync
`
const upsertRoomStateSQL = "" +
- "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" +
- " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
+ "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at, history_visibility)" +
+ " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" +
" ON CONFLICT ON CONSTRAINT syncapi_room_state_unique" +
" DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"
@@ -100,13 +102,18 @@ const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
const selectEventsWithEventIDsSQL = "" +
- // TODO: The session_id and transaction_id blanks are here because otherwise
- // the rowsToStreamEvents expects there to be exactly six columns. We need to
+ // TODO: The session_id and transaction_id blanks are here because
+ // the rowsToStreamEvents expects there to be exactly seven columns. We need to
// figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020
- "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
+ "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id, history_visibility" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)"
+const selectSharedUsersSQL = "" +
+ "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
+ " SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
+ ") AND state_key = ANY($2) AND membership IN ('join', 'invite');"
+
type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
@@ -118,6 +125,7 @@ type currentRoomStateStatements struct {
selectJoinedUsersInRoomStmt *sql.Stmt
selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt
+ selectSharedUsersStmt *sql.Stmt
}
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@@ -126,6 +134,17 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if err != nil {
return nil, err
}
+
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "syncapi: add history visibility column (current_room_state)",
+ Up: deltas.UpAddHistoryVisibilityColumnCurrentRoomState,
+ })
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
+
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
return nil, err
}
@@ -156,6 +175,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err
}
+ if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
@@ -327,6 +349,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
headeredJSON,
membership,
addedAt,
+ event.Visibility,
)
return err
}
@@ -379,3 +402,24 @@ func (s *currentRoomStateStatements) SelectStateEvent(
}
return &ev, err
}
+
+func (s *currentRoomStateStatements) SelectSharedUsers(
+ ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
+) ([]string, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt)
+ rows, err := stmt.QueryContext(ctx, userID, pq.Array(otherUserIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
+
+ var stateKey string
+ result := make([]string, 0, len(otherUserIDs))
+ for rows.Next() {
+ if err := rows.Scan(&stateKey); err != nil {
+ return nil, err
+ }
+ result = append(result, stateKey)
+ }
+ return result, rows.Err()
+}
diff --git a/syncapi/storage/postgres/deltas/20201211125500_sequences.go b/syncapi/storage/postgres/deltas/20201211125500_sequences.go
index 7db524da5..6303c9472 100644
--- a/syncapi/storage/postgres/deltas/20201211125500_sequences.go
+++ b/syncapi/storage/postgres/deltas/20201211125500_sequences.go
@@ -15,24 +15,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpFixSequences, DownFixSequences)
- goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
-}
-
-func LoadFixSequences(m *sqlutil.Migrations) {
- m.AddMigration(UpFixSequences, DownFixSequences)
-}
-
-func UpFixSequences(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpFixSequences(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
@@ -49,8 +38,8 @@ func UpFixSequences(tx *sql.Tx) error {
return nil
}
-func DownFixSequences(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownFixSequences(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
diff --git a/syncapi/storage/postgres/deltas/20210112130000_sendtodevice_sentcolumn.go b/syncapi/storage/postgres/deltas/20210112130000_sendtodevice_sentcolumn.go
index 3690eca8e..77b083ae2 100644
--- a/syncapi/storage/postgres/deltas/20210112130000_sendtodevice_sentcolumn.go
+++ b/syncapi/storage/postgres/deltas/20210112130000_sendtodevice_sentcolumn.go
@@ -15,18 +15,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
- m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
-}
-
-func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_send_to_device
DROP COLUMN IF EXISTS sent_by_token;
`)
@@ -36,8 +31,8 @@ func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
return nil
}
-func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_send_to_device
ADD COLUMN IF NOT EXISTS sent_by_token TEXT;
`)
diff --git a/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go b/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go
new file mode 100644
index 000000000..d68ed8d5f
--- /dev/null
+++ b/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go
@@ -0,0 +1,109 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ ALTER TABLE syncapi_output_room_events ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2;
+ UPDATE syncapi_output_room_events SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ return nil
+}
+
+// UpSetHistoryVisibility sets the history visibility for already stored events.
+// Requires current_room_state and output_room_events to be created.
+func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error {
+ // get the current room history visibilities
+ historyVisibilities, err := currentHistoryVisibilities(ctx, tx)
+ if err != nil {
+ return err
+ }
+
+ // update the history visibility
+ for roomID, hisVis := range historyVisibilities {
+ _, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1
+ WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID)
+ if err != nil {
+ return fmt.Errorf("failed to update history visibility: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2;
+ UPDATE syncapi_current_room_state SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+
+ return nil
+}
+
+// currentHistoryVisibilities returns a map from roomID to current history visibility.
+// If the history visibility was changed after room creation, defaults to joined.
+func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) {
+ rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state
+ WHERE type = 'm.room.history_visibility' AND state_key = '';
+`)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query current room state: %w", err)
+ }
+ defer rows.Close() // nolint: errcheck
+ var eventBytes []byte
+ var roomID string
+ var event gomatrixserverlib.HeaderedEvent
+ var hisVis gomatrixserverlib.HistoryVisibility
+ historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility)
+ for rows.Next() {
+ if err = rows.Scan(&roomID, &eventBytes); err != nil {
+ return nil, fmt.Errorf("failed to scan row: %w", err)
+ }
+ if err = json.Unmarshal(eventBytes, &event); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal event: %w", err)
+ }
+ historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined
+ if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 {
+ historyVisibilities[roomID] = hisVis
+ }
+ }
+ return historyVisibilities, nil
+}
+
+func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility;
+ ALTER TABLE syncapi_current_room_state DROP COLUMN IF EXISTS history_visibility;
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go
index 00223c57a..939d6b3f5 100644
--- a/syncapi/storage/postgres/memberships_table.go
+++ b/syncapi/storage/postgres/memberships_table.go
@@ -66,10 +66,14 @@ const selectMembershipCountSQL = "" +
const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
+const selectMembershipBeforeSQL = "" +
+ "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
+
type membershipsStatements struct {
- upsertMembershipStmt *sql.Stmt
- selectMembershipCountStmt *sql.Stmt
- selectHeroesStmt *sql.Stmt
+ upsertMembershipStmt *sql.Stmt
+ selectMembershipCountStmt *sql.Stmt
+ selectHeroesStmt *sql.Stmt
+ selectMembershipForUserStmt *sql.Stmt
}
func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@@ -82,6 +86,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) {
{&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
{&s.selectHeroesStmt, selectHeroesSQL},
+ {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
}.Prepare(db)
}
@@ -132,3 +137,20 @@ func (s *membershipsStatements) SelectHeroes(
}
return heroes, rows.Err()
}
+
+// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
+// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
+// string as the membership.
+func (s *membershipsStatements) SelectMembershipForUser(
+ ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64,
+) (membership string, topologyPos int, err error) {
+ stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt)
+ err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return "leave", 0, nil
+ }
+ return "", 0, err
+ }
+ return membership, topologyPos, nil
+}
diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go
index f3fc4451f..9cd8b7362 100644
--- a/syncapi/storage/postgres/notification_data_table.go
+++ b/syncapi/storage/postgres/notification_data_table.go
@@ -58,7 +58,7 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
(user_id, room_id, notification_count, highlight_count)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id)
- DO UPDATE SET notification_count = $3, highlight_count = $4
+ DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT
diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go
index d84d0cfa2..8f633640e 100644
--- a/syncapi/storage/postgres/output_room_events_table.go
+++ b/syncapi/storage/postgres/output_room_events_table.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
@@ -67,7 +68,9 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
-- events retrieved through backfilling that have a position in the stream
-- that relates to the moment these were retrieved rather than the moment these
-- were emitted.
- exclude_from_sync BOOL DEFAULT FALSE
+ exclude_from_sync BOOL DEFAULT FALSE,
+ -- The history visibility before this event (1 - world_readable; 2 - shared; 3 - invited; 4 - joined)
+ history_visibility SMALLINT NOT NULL DEFAULT 2
);
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_type_idx ON syncapi_output_room_events (type);
@@ -78,16 +81,16 @@ CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON s
const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" +
- "room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
- ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) " +
+ "room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" +
+ ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
"ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " +
"RETURNING id"
const selectEventsSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectEventsWithFilterSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
@@ -96,7 +99,7 @@ const selectEventsWithFilterSQL = "" +
" LIMIT $7"
const selectRecentEventsSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
@@ -105,7 +108,7 @@ const selectRecentEventsSQL = "" +
" ORDER BY id DESC LIMIT $8"
const selectRecentEventsForSyncSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
@@ -114,7 +117,7 @@ const selectRecentEventsForSyncSQL = "" +
" ORDER BY id DESC LIMIT $8"
const selectEarlyEventsSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
@@ -130,7 +133,7 @@ const updateEventJSONSQL = "" +
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" +
- "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
+ "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" AND room_id = ANY($3)" +
@@ -146,10 +149,10 @@ const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
const selectContextEventSQL = "" +
- "SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND event_id = $2"
+ "SELECT id, headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND event_id = $2"
const selectContextBeforeEventSQL = "" +
- "SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2" +
+ "SELECT headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
@@ -157,7 +160,7 @@ const selectContextBeforeEventSQL = "" +
" ORDER BY id DESC LIMIT $3"
const selectContextAfterEventSQL = "" +
- "SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2" +
+ "SELECT id, headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2" +
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
@@ -186,6 +189,19 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
if err != nil {
return nil, err
}
+
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(
+ sqlutil.Migration{
+ Version: "syncapi: add history visibility column (output_room_events)",
+ Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
+ },
+ )
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
+
return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL},
@@ -246,14 +262,15 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
for rows.Next() {
var (
- eventID string
- streamPos types.StreamPosition
- eventBytes []byte
- excludeFromSync bool
- addIDs pq.StringArray
- delIDs pq.StringArray
+ eventID string
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ addIDs pq.StringArray
+ delIDs pq.StringArray
+ historyVisibility gomatrixserverlib.HistoryVisibility
)
- if err := rows.Scan(&eventID, &streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs); err != nil {
+ if err := rows.Scan(&eventID, &streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs, &historyVisibility); err != nil {
return nil, nil, err
}
// Sanity check for deleted state and whine if we see it. We don't need to do anything
@@ -283,6 +300,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
needSet[id] = true
}
stateNeeded[ev.RoomID()] = needSet
+ ev.Visibility = historyVisibility
eventIDToEvent[eventID] = types.StreamEvent{
HeaderedEvent: &ev,
@@ -314,7 +332,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
func (s *outputRoomEventsStatements) InsertEvent(
ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
- transactionID *api.TransactionID, excludeFromSync bool,
+ transactionID *api.TransactionID, excludeFromSync bool, historyVisibility gomatrixserverlib.HistoryVisibility,
) (streamPos types.StreamPosition, err error) {
var txnID *string
var sessionID *int64
@@ -351,6 +369,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
sessionID,
txnID,
excludeFromSync,
+ historyVisibility,
).Scan(&streamPos)
return
}
@@ -504,13 +523,15 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
row := sqlutil.TxStmt(txn, s.selectContextEventStmt).QueryRowContext(ctx, roomID, eventID)
var eventAsString string
- if err = row.Scan(&id, &eventAsString); err != nil {
+ var historyVisibility gomatrixserverlib.HistoryVisibility
+ if err = row.Scan(&id, &eventAsString, &historyVisibility); err != nil {
return 0, evt, err
}
if err = json.Unmarshal([]byte(eventAsString), &evt); err != nil {
return 0, evt, err
}
+ evt.Visibility = historyVisibility
return id, evt, nil
}
@@ -532,15 +553,17 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
for rows.Next() {
var (
- eventBytes []byte
- evt *gomatrixserverlib.HeaderedEvent
+ eventBytes []byte
+ evt *gomatrixserverlib.HeaderedEvent
+ historyVisibility gomatrixserverlib.HistoryVisibility
)
- if err = rows.Scan(&eventBytes); err != nil {
+ if err = rows.Scan(&eventBytes, &historyVisibility); err != nil {
return evts, err
}
if err = json.Unmarshal(eventBytes, &evt); err != nil {
return evts, err
}
+ evt.Visibility = historyVisibility
evts = append(evts, evt)
}
@@ -565,15 +588,17 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
for rows.Next() {
var (
- eventBytes []byte
- evt *gomatrixserverlib.HeaderedEvent
+ eventBytes []byte
+ evt *gomatrixserverlib.HeaderedEvent
+ historyVisibility gomatrixserverlib.HistoryVisibility
)
- if err = rows.Scan(&lastID, &eventBytes); err != nil {
+ if err = rows.Scan(&lastID, &eventBytes, &historyVisibility); err != nil {
return 0, evts, err
}
if err = json.Unmarshal(eventBytes, &evt); err != nil {
return 0, evts, err
}
+ evt.Visibility = historyVisibility
evts = append(evts, evt)
}
@@ -584,15 +609,16 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent
for rows.Next() {
var (
- eventID string
- streamPos types.StreamPosition
- eventBytes []byte
- excludeFromSync bool
- sessionID *int64
- txnID *string
- transactionID *api.TransactionID
+ eventID string
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ sessionID *int64
+ txnID *string
+ transactionID *api.TransactionID
+ historyVisibility gomatrixserverlib.HistoryVisibility
)
- if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
+ if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID, &historyVisibility); err != nil {
return nil, err
}
// TODO: Handle redacted events
@@ -607,7 +633,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
TransactionID: *txnID,
}
}
-
+ ev.Visibility = historyVisibility
result = append(result, types.StreamEvent{
HeaderedEvent: &ev,
StreamPosition: streamPos,
diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go
index 2a42ffd74..bbddaa939 100644
--- a/syncapi/storage/postgres/receipt_table.go
+++ b/syncapi/storage/postgres/receipt_table.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -73,6 +74,15 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
if err != nil {
return nil, err
}
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "syncapi: fix sequences",
+ Up: deltas.UpFixSequences,
+ })
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
r := &receiptStatements{
db: db,
}
diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go
index 47c1cdaed..fd0c1c56b 100644
--- a/syncapi/storage/postgres/send_to_device_table.go
+++ b/syncapi/storage/postgres/send_to_device_table.go
@@ -21,8 +21,10 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/sirupsen/logrus"
)
const sendToDeviceSchema = `
@@ -51,12 +53,12 @@ const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
- ORDER BY id DESC
+ ORDER BY id ASC
`
const deleteSendToDeviceMessagesSQL = `
DELETE FROM syncapi_send_to_device
- WHERE user_id = $1 AND device_id = $2 AND id < $3
+ WHERE user_id = $1 AND device_id = $2 AND id <= $3
`
const selectMaxSendToDeviceIDSQL = "" +
@@ -75,6 +77,15 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil {
return nil, err
}
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "syncapi: drop sent_by_token",
+ Up: deltas.UpRemoveSendToDeviceSentColumn,
+ })
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err
}
@@ -112,17 +123,18 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
return
}
- if id > lastPos {
- lastPos = id
- }
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
DeviceID: deviceID,
}
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
+ logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue
}
+ if id > lastPos {
+ lastPos = id
+ }
events = append(events, event)
}
if lastPos == 0 {
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go
index 9cfe7c070..979ff6647 100644
--- a/syncapi/storage/postgres/syncserver.go
+++ b/syncapi/storage/postgres/syncserver.go
@@ -98,12 +98,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil {
return nil, err
}
- m := sqlutil.NewMigrations()
- deltas.LoadFixSequences(m)
- deltas.LoadRemoveSendToDeviceSentColumn(m)
- if err = m.RunDeltas(d.db, dbProperties); err != nil {
+
+ // apply migrations which need multiple tables
+ m := sqlutil.NewMigrator(d.db)
+ m.AddMigrations(
+ sqlutil.Migration{
+ Version: "syncapi: set history visibility for existing events",
+ Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created.
+ },
+ )
+ err = m.Up(base.Context())
+ if err != nil {
return nil, err
}
+
d.Database = shared.Database{
DB: d.db,
Writer: d.writer,
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index 76114aff8..a46e55256 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -176,6 +176,10 @@ func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]t
return d.Peeks.SelectPeekingDevices(ctx)
}
+func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
+ return d.CurrentRoomState.SelectSharedUsers(ctx, nil, userID, otherUserIDs)
+}
+
func (d *Database) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {
@@ -227,7 +231,7 @@ func (d *Database) AddPeek(
return
}
-// DeletePeeks tracks the fact that a user has stopped peeking from the specified
+// DeletePeek tracks the fact that a user has stopped peeking from the specified
// device. If the peeks was successfully deleted this returns the stream ID it was
// stored at. Returns an error if there was a problem communicating with the database.
func (d *Database) DeletePeek(
@@ -364,11 +368,13 @@ func (d *Database) WriteEvent(
addStateEvents []*gomatrixserverlib.HeaderedEvent,
addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, excludeFromSync bool,
+ historyVisibility gomatrixserverlib.HistoryVisibility,
) (pduPosition types.StreamPosition, returnErr error) {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
+ ev.Visibility = historyVisibility
pos, err := d.OutputEvents.InsertEvent(
- ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
+ ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility,
)
if err != nil {
return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
@@ -387,7 +393,9 @@ func (d *Database) WriteEvent(
// Nothing to do, the event may have just been a message event.
return nil
}
-
+ for i := range addStateEvents {
+ addStateEvents[i].Visibility = historyVisibility
+ }
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition)
})
@@ -556,7 +564,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
return err
}
-// Retrieve the backward topology position, i.e. the position of the
+// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
// oldest event in the room's topology.
func (d *Database) GetBackwardTopologyPos(
ctx context.Context,
@@ -667,7 +675,7 @@ func (d *Database) fetchMissingStateEvents(
return events, nil
}
-// getStateDeltas returns the state deltas between fromPos and toPos,
+// GetStateDeltas returns the state deltas between fromPos and toPos,
// exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it.
@@ -805,7 +813,7 @@ func (d *Database) GetStateDeltas(
return deltas, joinedRoomIDs, nil
}
-// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
+// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
// requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms.
@@ -1032,37 +1040,41 @@ func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID s
return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to)
}
-func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
- return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
+func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
+ return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
}
-func (s *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
- return s.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
+func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
+ return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
}
-func (s *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
- return s.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
+func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
+ return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
}
-func (s *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
- return s.Ignores.SelectIgnores(ctx, userID)
+func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
+ return d.Ignores.SelectIgnores(ctx, userID)
}
-func (s *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
- return s.Ignores.UpsertIgnores(ctx, userID, ignores)
+func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
+ return d.Ignores.UpsertIgnores(ctx, userID, ignores)
}
-func (s *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
- return s.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync)
+func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
+ return d.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync)
}
-func (s *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
- return s.Presence.GetPresenceForUser(ctx, nil, userID)
+func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
+ return d.Presence.GetPresenceForUser(ctx, nil, userID)
}
-func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
- return s.Presence.GetPresenceAfter(ctx, nil, after, filter)
+func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
+ return d.Presence.GetPresenceAfter(ctx, nil, after, filter)
}
-func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
- return s.Presence.GetMaxPresenceID(ctx, nil)
+func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
+ return d.Presence.GetMaxPresenceID(ctx, nil)
+}
+
+func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
+ return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
}
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index f0a1c7bb7..3a10b2325 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -41,6 +42,7 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
headered_event_json TEXT NOT NULL,
membership TEXT,
added_at BIGINT,
+ history_visibility SMALLINT NOT NULL DEFAULT 2, -- The history visibility before this event (1 - world_readable; 2 - shared; 3 - invited; 4 - joined)
UNIQUE (room_id, type, state_key)
);
-- for event deletion
@@ -52,8 +54,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON sync
`
const upsertRoomStateSQL = "" +
- "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" +
- " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
+ "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at, history_visibility)" +
+ " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" +
" ON CONFLICT (room_id, type, state_key)" +
" DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"
@@ -84,13 +86,18 @@ const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
const selectEventsWithEventIDsSQL = "" +
- // TODO: The session_id and transaction_id blanks are here because otherwise
- // the rowsToStreamEvents expects there to be exactly six columns. We need to
+ // TODO: The session_id and transaction_id blanks are here because
+ // the rowsToStreamEvents expects there to be exactly seven columns. We need to
// figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020
- "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
+ "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id, history_visibility" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
+const selectSharedUsersSQL = "" +
+ "SELECT state_key FROM syncapi_current_room_state WHERE room_id IN(" +
+ " SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
+ ") AND state_key IN ($2) AND membership IN ('join', 'invite');"
+
type currentRoomStateStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
@@ -100,8 +107,9 @@ type currentRoomStateStatements struct {
selectRoomIDsWithMembershipStmt *sql.Stmt
selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt
- //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
+ //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
selectStateEventStmt *sql.Stmt
+ //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
}
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
@@ -113,6 +121,17 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
if err != nil {
return nil, err
}
+
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "syncapi: add history visibility column (current_room_state)",
+ Up: deltas.UpAddHistoryVisibilityColumnCurrentRoomState,
+ })
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
+
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
return nil, err
}
@@ -322,6 +341,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
headeredJSON,
membership,
addedAt,
+ event.Visibility,
)
return err
}
@@ -396,3 +416,32 @@ func (s *currentRoomStateStatements) SelectStateEvent(
}
return &ev, err
}
+
+func (s *currentRoomStateStatements) SelectSharedUsers(
+ ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
+) ([]string, error) {
+
+ params := make([]interface{}, len(otherUserIDs)+1)
+ params[0] = userID
+ for k, v := range otherUserIDs {
+ params[k+1] = v
+ }
+
+ result := make([]string, 0, len(otherUserIDs))
+ query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
+ err := sqlutil.RunLimitedVariablesQuery(
+ ctx, query, s.db, params, sqlutil.SQLite3MaxVariables,
+ func(rows *sql.Rows) error {
+ var stateKey string
+ for rows.Next() {
+ if err := rows.Scan(&stateKey); err != nil {
+ return err
+ }
+ result = append(result, stateKey)
+ }
+ return nil
+ },
+ )
+
+ return result, err
+}
diff --git a/syncapi/storage/sqlite3/deltas/20201211125500_sequences.go b/syncapi/storage/sqlite3/deltas/20201211125500_sequences.go
index 8e7ebff86..f476335d5 100644
--- a/syncapi/storage/sqlite3/deltas/20201211125500_sequences.go
+++ b/syncapi/storage/sqlite3/deltas/20201211125500_sequences.go
@@ -15,24 +15,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpFixSequences, DownFixSequences)
- goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
-}
-
-func LoadFixSequences(m *sqlutil.Migrations) {
- m.AddMigration(UpFixSequences, DownFixSequences)
-}
-
-func UpFixSequences(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpFixSequences(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
@@ -45,8 +34,8 @@ func UpFixSequences(tx *sql.Tx) error {
return nil
}
-func DownFixSequences(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownFixSequences(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
-- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence.
diff --git a/syncapi/storage/sqlite3/deltas/20210112130000_sendtodevice_sentcolumn.go b/syncapi/storage/sqlite3/deltas/20210112130000_sendtodevice_sentcolumn.go
index e0c514102..34cae2241 100644
--- a/syncapi/storage/sqlite3/deltas/20210112130000_sendtodevice_sentcolumn.go
+++ b/syncapi/storage/sqlite3/deltas/20210112130000_sendtodevice_sentcolumn.go
@@ -15,18 +15,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
- m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
-}
-
-func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
@@ -45,8 +40,8 @@ func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
return nil
}
-func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
diff --git a/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go b/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go
new file mode 100644
index 000000000..d23f07566
--- /dev/null
+++ b/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go
@@ -0,0 +1,137 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
+ // SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists.
+ // Required for unit tests, as otherwise a duplicate column error will show up.
+ _, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1")
+ if err == nil {
+ return nil
+ }
+ _, err = tx.ExecContext(ctx, `
+ ALTER TABLE syncapi_output_room_events ADD COLUMN history_visibility SMALLINT NOT NULL DEFAULT 2;
+ UPDATE syncapi_output_room_events SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ return nil
+}
+
+// UpSetHistoryVisibility sets the history visibility for already stored events.
+// Requires current_room_state and output_room_events to be created.
+func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error {
+ // get the current room history visibilities
+ historyVisibilities, err := currentHistoryVisibilities(ctx, tx)
+ if err != nil {
+ return err
+ }
+
+ // update the history visibility
+ for roomID, hisVis := range historyVisibilities {
+ _, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1
+ WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID)
+ if err != nil {
+ return fmt.Errorf("failed to update history visibility: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
+ // SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists.
+ // Required for unit tests, as otherwise a duplicate column error will show up.
+ _, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_current_room_state LIMIT 1")
+ if err == nil {
+ return nil
+ }
+ _, err = tx.ExecContext(ctx, `
+ ALTER TABLE syncapi_current_room_state ADD COLUMN history_visibility SMALLINT NOT NULL DEFAULT 2;
+ UPDATE syncapi_current_room_state SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+
+ return nil
+}
+
+// currentHistoryVisibilities returns a map from roomID to current history visibility.
+// If the history visibility was changed after room creation, defaults to joined.
+func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) {
+ rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state
+ WHERE type = 'm.room.history_visibility' AND state_key = '';
+`)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query current room state: %w", err)
+ }
+ defer rows.Close() // nolint: errcheck
+ var eventBytes []byte
+ var roomID string
+ var event gomatrixserverlib.HeaderedEvent
+ var hisVis gomatrixserverlib.HistoryVisibility
+ historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility)
+ for rows.Next() {
+ if err = rows.Scan(&roomID, &eventBytes); err != nil {
+ return nil, fmt.Errorf("failed to scan row: %w", err)
+ }
+ if err = json.Unmarshal(eventBytes, &event); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal event: %w", err)
+ }
+ historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined
+ if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 {
+ historyVisibilities[roomID] = hisVis
+ }
+ }
+ return historyVisibilities, nil
+}
+
+func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
+ // SQLite doesn't have "if exists", so check if the column exists.
+ _, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1")
+ if err != nil {
+ // The column probably doesn't exist
+ return nil
+ }
+ _, err = tx.ExecContext(ctx, `
+ ALTER TABLE syncapi_output_room_events DROP COLUMN history_visibility;
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ _, err = tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_current_room_state LIMIT 1")
+ if err != nil {
+ // The column probably doesn't exist
+ return nil
+ }
+ _, err = tx.ExecContext(ctx, `
+ ALTER TABLE syncapi_current_room_state DROP COLUMN history_visibility;
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go
index e4daa99c1..0c966fca0 100644
--- a/syncapi/storage/sqlite3/memberships_table.go
+++ b/syncapi/storage/sqlite3/memberships_table.go
@@ -66,11 +66,15 @@ const selectMembershipCountSQL = "" +
const selectHeroesSQL = "" +
"SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5"
+const selectMembershipBeforeSQL = "" +
+ "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1"
+
type membershipsStatements struct {
db *sql.DB
upsertMembershipStmt *sql.Stmt
selectMembershipCountStmt *sql.Stmt
//selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic
+ selectMembershipForUserStmt *sql.Stmt
}
func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
@@ -84,6 +88,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) {
return s, sqlutil.StatementList{
{&s.upsertMembershipStmt, upsertMembershipSQL},
{&s.selectMembershipCountStmt, selectMembershipCountSQL},
+ {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL},
// {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic
}.Prepare(db)
}
@@ -148,3 +153,20 @@ func (s *membershipsStatements) SelectHeroes(
}
return heroes, rows.Err()
}
+
+// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
+// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
+// string as the membership.
+func (s *membershipsStatements) SelectMembershipForUser(
+ ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64,
+) (membership string, topologyPos int, err error) {
+ stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt)
+ err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return "leave", 0, nil
+ }
+ return "", 0, err
+ }
+ return membership, topologyPos, nil
+}
diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go
index 4b3f074db..eaa11a8c0 100644
--- a/syncapi/storage/sqlite3/notification_data_table.go
+++ b/syncapi/storage/sqlite3/notification_data_table.go
@@ -25,12 +25,14 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
)
-func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
+func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.NotificationData, error) {
_, err := db.Exec(notificationDataSchema)
if err != nil {
return nil, err
}
- r := ¬ificationDataStatements{}
+ r := ¬ificationDataStatements{
+ streamIDStatements: streamID,
+ }
return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
@@ -39,6 +41,7 @@ func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error)
}
type notificationDataStatements struct {
+ streamIDStatements *StreamIDStatements
upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt
@@ -58,8 +61,7 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
(user_id, room_id, notification_count, highlight_count)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id)
- DO UPDATE SET notification_count = $3, highlight_count = $4
- RETURNING id`
+ DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
const selectUserUnreadNotificationCountsSQL = `SELECT
id, room_id, notification_count, highlight_count
@@ -71,7 +73,11 @@ const selectUserUnreadNotificationCountsSQL = `SELECT
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
- err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
+ pos, err = r.streamIDStatements.nextNotificationID(ctx, nil)
+ if err != nil {
+ return
+ }
+ _, err = r.upsertRoomUnreadCounts.ExecContext(ctx, userID, roomID, notificationCount, highlightCount, pos, notificationCount, highlightCount)
return
}
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
index f9961a9d1..91fd35b5b 100644
--- a/syncapi/storage/sqlite3/output_room_events_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
@@ -47,7 +48,8 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
remove_state_ids TEXT, -- JSON encoded string array
session_id BIGINT,
transaction_id TEXT,
- exclude_from_sync BOOL NOT NULL DEFAULT FALSE
+ exclude_from_sync BOOL NOT NULL DEFAULT FALSE,
+ history_visibility SMALLINT NOT NULL DEFAULT 2 -- The history visibility before this event (1 - world_readable; 2 - shared; 3 - invited; 4 - joined)
);
CREATE INDEX IF NOT EXISTS syncapi_output_room_events_type_idx ON syncapi_output_room_events (type);
@@ -58,27 +60,27 @@ CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON s
const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" +
- "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
- ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
- "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
+ "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" +
+ ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) " +
+ "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $14)"
const selectEventsSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events WHERE event_id IN ($1)"
const selectRecentEventsSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsForSyncSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectEarlyEventsSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
@@ -90,7 +92,7 @@ const updateEventJSONSQL = "" +
"UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
const selectStateInRangeSQL = "" +
- "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
+ "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids, history_visibility" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2)" +
" AND room_id IN ($3)" +
@@ -102,15 +104,15 @@ const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
const selectContextEventSQL = "" +
- "SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND event_id = $2"
+ "SELECT id, headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND event_id = $2"
const selectContextBeforeEventSQL = "" +
- "SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2"
+ "SELECT headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectContextAfterEventSQL = "" +
- "SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2"
+ "SELECT id, headered_event_json, history_visibility FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
@@ -135,6 +137,19 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
if err != nil {
return nil, err
}
+
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(
+ sqlutil.Migration{
+ Version: "syncapi: add history visibility column (output_room_events)",
+ Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
+ },
+ )
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
+
return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
@@ -196,14 +211,15 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
for rows.Next() {
var (
- eventID string
- streamPos types.StreamPosition
- eventBytes []byte
- excludeFromSync bool
- addIDsJSON string
- delIDsJSON string
+ eventID string
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ addIDsJSON string
+ delIDsJSON string
+ historyVisibility gomatrixserverlib.HistoryVisibility
)
- if err := rows.Scan(&eventID, &streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON); err != nil {
+ if err := rows.Scan(&eventID, &streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON, &historyVisibility); err != nil {
return nil, nil, err
}
@@ -239,6 +255,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
needSet[id] = true
}
stateNeeded[ev.RoomID()] = needSet
+ ev.Visibility = historyVisibility
eventIDToEvent[eventID] = types.StreamEvent{
HeaderedEvent: &ev,
@@ -270,7 +287,7 @@ func (s *outputRoomEventsStatements) SelectMaxEventID(
func (s *outputRoomEventsStatements) InsertEvent(
ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
- transactionID *api.TransactionID, excludeFromSync bool,
+ transactionID *api.TransactionID, excludeFromSync bool, historyVisibility gomatrixserverlib.HistoryVisibility,
) (types.StreamPosition, error) {
var txnID *string
var sessionID *int64
@@ -326,6 +343,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
sessionID,
txnID,
excludeFromSync,
+ historyVisibility,
excludeFromSync,
)
return streamPos, err
@@ -481,15 +499,16 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent
for rows.Next() {
var (
- eventID string
- streamPos types.StreamPosition
- eventBytes []byte
- excludeFromSync bool
- sessionID *int64
- txnID *string
- transactionID *api.TransactionID
+ eventID string
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ sessionID *int64
+ txnID *string
+ transactionID *api.TransactionID
+ historyVisibility gomatrixserverlib.HistoryVisibility
)
- if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
+ if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID, &historyVisibility); err != nil {
return nil, err
}
// TODO: Handle redacted events
@@ -505,6 +524,8 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
}
}
+ ev.Visibility = historyVisibility
+
result = append(result, types.StreamEvent{
HeaderedEvent: &ev,
StreamPosition: streamPos,
@@ -519,13 +540,15 @@ func (s *outputRoomEventsStatements) SelectContextEvent(
) (id int, evt gomatrixserverlib.HeaderedEvent, err error) {
row := sqlutil.TxStmt(txn, s.selectContextEventStmt).QueryRowContext(ctx, roomID, eventID)
var eventAsString string
- if err = row.Scan(&id, &eventAsString); err != nil {
+ var historyVisibility gomatrixserverlib.HistoryVisibility
+ if err = row.Scan(&id, &eventAsString, &historyVisibility); err != nil {
return 0, evt, err
}
if err = json.Unmarshal([]byte(eventAsString), &evt); err != nil {
return 0, evt, err
}
+ evt.Visibility = historyVisibility
return id, evt, nil
}
@@ -550,15 +573,17 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
for rows.Next() {
var (
- eventBytes []byte
- evt *gomatrixserverlib.HeaderedEvent
+ eventBytes []byte
+ evt *gomatrixserverlib.HeaderedEvent
+ historyVisibility gomatrixserverlib.HistoryVisibility
)
- if err = rows.Scan(&eventBytes); err != nil {
+ if err = rows.Scan(&eventBytes, &historyVisibility); err != nil {
return evts, err
}
if err = json.Unmarshal(eventBytes, &evt); err != nil {
return evts, err
}
+ evt.Visibility = historyVisibility
evts = append(evts, evt)
}
@@ -586,15 +611,17 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
for rows.Next() {
var (
- eventBytes []byte
- evt *gomatrixserverlib.HeaderedEvent
+ eventBytes []byte
+ evt *gomatrixserverlib.HeaderedEvent
+ historyVisibility gomatrixserverlib.HistoryVisibility
)
- if err = rows.Scan(&lastID, &eventBytes); err != nil {
+ if err = rows.Scan(&lastID, &eventBytes, &historyVisibility); err != nil {
return 0, evts, err
}
if err = json.Unmarshal(eventBytes, &evt); err != nil {
return 0, evts, err
}
+ evt.Visibility = historyVisibility
evts = append(evts, evt)
}
return lastID, evts, rows.Err()
diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go
index bd778bf3c..31adb005b 100644
--- a/syncapi/storage/sqlite3/receipt_table.go
+++ b/syncapi/storage/sqlite3/receipt_table.go
@@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -70,6 +71,15 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re
if err != nil {
return nil, err
}
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "syncapi: fix sequences",
+ Up: deltas.UpFixSequences,
+ })
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
r := &receiptStatements{
db: db,
streamIDStatements: streamID,
diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go
index 0b1d5bbf2..e3aa1b7a1 100644
--- a/syncapi/storage/sqlite3/send_to_device_table.go
+++ b/syncapi/storage/sqlite3/send_to_device_table.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus"
@@ -49,12 +50,12 @@ const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
- ORDER BY id DESC
+ ORDER BY id ASC
`
const deleteSendToDeviceMessagesSQL = `
DELETE FROM syncapi_send_to_device
- WHERE user_id = $1 AND device_id = $2 AND id < $3
+ WHERE user_id = $1 AND device_id = $2 AND id <= $3
`
const selectMaxSendToDeviceIDSQL = "" +
@@ -76,6 +77,15 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil {
return nil, err
}
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "syncapi: drop sent_by_token",
+ Up: deltas.UpRemoveSendToDeviceSentColumn,
+ })
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err
}
@@ -120,9 +130,6 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return
}
- if id > lastPos {
- lastPos = id
- }
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
@@ -132,6 +139,9 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
continue
}
+ if id > lastPos {
+ lastPos = id
+ }
events = append(events, event)
}
if lastPos == 0 {
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
index 71980b806..1160a437e 100644
--- a/syncapi/storage/sqlite3/stream_id_table.go
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -26,6 +26,8 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("presence", 0)
ON CONFLICT DO NOTHING;
+INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("notification", 0)
+ ON CONFLICT DO NOTHING;
`
const increaseStreamIDStmt = "" +
@@ -78,3 +80,9 @@ func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (p
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
return
}
+
+func (s *StreamIDStatements) nextNotificationID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+ increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
+ err = increaseStmt.QueryRowContext(ctx, "notification").Scan(&pos)
+ return
+}
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
index e08a0ba82..a84e2bd16 100644
--- a/syncapi/storage/sqlite3/syncserver.go
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -16,6 +16,7 @@
package sqlite3
import (
+ "context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
@@ -42,13 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err
}
- if err = d.prepare(dbProperties); err != nil {
+ if err = d.prepare(base.Context()); err != nil {
return nil, err
}
return &d, nil
}
-func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
+func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
if err = d.streamID.Prepare(d.db); err != nil {
return err
}
@@ -96,7 +97,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
if err != nil {
return err
}
- notificationData, err := NewSqliteNotificationDataTable(d.db)
+ notificationData, err := NewSqliteNotificationDataTable(d.db, &d.streamID)
if err != nil {
return err
}
@@ -108,10 +109,17 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
if err != nil {
return err
}
- m := sqlutil.NewMigrations()
- deltas.LoadFixSequences(m)
- deltas.LoadRemoveSendToDeviceSentColumn(m)
- if err = m.RunDeltas(d.db, dbProperties); err != nil {
+
+ // apply migrations which need multiple tables
+ m := sqlutil.NewMigrator(d.db)
+ m.AddMigrations(
+ sqlutil.Migration{
+ Version: "syncapi: set history visibility for existing events",
+ Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created.
+ },
+ )
+ err = m.Up(ctx)
+ if err != nil {
return err
}
d.Database = shared.Database{
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index 563c92e34..a62818e9b 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -1,7 +1,9 @@
package storage_test
import (
+ "bytes"
"context"
+ "encoding/json"
"fmt"
"reflect"
"testing"
@@ -10,20 +12,22 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
)
var ctx = context.Background()
-func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func(), func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewSyncServerDatasource(nil, &config.DatabaseOptions{
+ base, closeBase := testrig.CreateBaseDendrite(t, dbType)
+ db, err := storage.NewSyncServerDatasource(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
- return db, close
+ return db, close, closeBase
}
func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) {
@@ -35,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
addStateEvents = append(addStateEvents, ev)
addStateEventIDs = append(addStateEventIDs, ev.EventID())
}
- pos, err := db.WriteEvent(ctx, ev, addStateEvents, addStateEventIDs, removeStateEventIDs, nil, false)
+ pos, err := db.WriteEvent(ctx, ev, addStateEvents, addStateEventIDs, removeStateEventIDs, nil, false, gomatrixserverlib.HistoryVisibilityShared)
if err != nil {
t.Fatalf("WriteEvent failed: %s", err)
}
@@ -49,8 +53,9 @@ func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
- db, close := MustCreateDatabase(t, dbType)
+ db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
+ defer closeBase()
MustWriteEvents(t, db, r.Events())
})
}
@@ -58,8 +63,9 @@ func TestWriteEvents(t *testing.T) {
// These tests assert basic functionality of RecentEvents for PDUs
func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := MustCreateDatabase(t, dbType)
+ db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
+ defer closeBase()
alice := test.NewUser(t)
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
@@ -161,8 +167,9 @@ func TestRecentEventsPDU(t *testing.T) {
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := MustCreateDatabase(t, dbType)
+ db, close, closeBase := MustCreateDatabase(t, dbType)
defer close()
+ defer closeBase()
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ {
@@ -394,90 +401,113 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
from = topologyTokenBefore(t, db, paginatedEvents[len(paginatedEvents)-1].EventID())
}
}
+*/
func TestSendToDeviceBehaviour(t *testing.T) {
- //t.Parallel()
- db := MustCreateDatabase(t)
+ t.Parallel()
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ deviceID := "one"
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close, closeBase := MustCreateDatabase(t, dbType)
+ defer close()
+ defer closeBase()
+ // At this point there should be no messages. We haven't sent anything
+ // yet.
+ _, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 {
+ t.Fatal("first call should have no updates")
+ }
- // At this point there should be no messages. We haven't sent anything
- // yet.
- _, events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
- t.Fatal("first call should have no updates")
- }
- err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{})
- if err != nil {
- return
- }
+ // Try sending a message.
+ streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
+ Sender: bob.ID,
+ Type: "m.type",
+ Content: json.RawMessage("{}"),
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
- // Try sending a message.
- streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
- Sender: "bob",
- Type: "m.type",
- Content: json.RawMessage("{}"),
+ // At this point we should get exactly one message. We're sending the sync position
+ // that we were given from the update and the send-to-device update will be updated
+ // in the database to reflect that this was the sync position we sent the message at.
+ streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if count := len(events); count != 1 {
+ t.Fatalf("second call should have one update, got %d", count)
+ }
+
+ // At this point we should still have one message because we haven't progressed the
+ // sync position yet. This is equivalent to the client failing to /sync and retrying
+ // with the same position.
+ streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 1 {
+ t.Fatal("third call should have one update still")
+ }
+ err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
+ if err != nil {
+ return
+ }
+
+ // At this point we should now have no updates, because we've progressed the sync
+ // position. Therefore the update from before will not be sent again.
+ _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 {
+ t.Fatal("fourth call should have no updates")
+ }
+
+ // At this point we should still have no updates, because no new updates have been
+ // sent.
+ _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 {
+ t.Fatal("fifth call should have no updates")
+ }
+
+ // Send some more messages and verify the ordering is correct ("in order of arrival")
+ var lastPos types.StreamPosition = 0
+ for i := 0; i < 10; i++ {
+ streamPos, err = db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
+ Sender: bob.ID,
+ Type: "m.type",
+ Content: json.RawMessage(fmt.Sprintf(`{"count":%d}`, i)),
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ lastPos = streamPos
+ }
+
+ _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
+ if err != nil {
+ t.Fatalf("unable to get events: %v", err)
+ }
+
+ for i := 0; i < 10; i++ {
+ want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
+ got := events[i].Content
+ if !bytes.Equal(got, want) {
+ t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
+ }
+ }
})
- if err != nil {
- t.Fatal(err)
- }
-
- // At this point we should get exactly one message. We're sending the sync position
- // that we were given from the update and the send-to-device update will be updated
- // in the database to reflect that this was the sync position we sent the message at.
- _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
- t.Fatal("second call should have one update")
- }
- err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
- if err != nil {
- return
- }
-
- // At this point we should still have one message because we haven't progressed the
- // sync position yet. This is equivalent to the client failing to /sync and retrying
- // with the same position.
- _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
- t.Fatal("third call should have one update still")
- }
- err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos})
- if err != nil {
- return
- }
-
- // At this point we should now have no updates, because we've progressed the sync
- // position. Therefore the update from before will not be sent again.
- _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
- t.Fatal("fourth call should have no updates")
- }
- err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.StreamingToken{SendToDevicePosition: streamPos + 1})
- if err != nil {
- return
- }
-
- // At this point we should still have no updates, because no new updates have been
- // sent.
- _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2})
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
- t.Fatal("fifth call should have no updates")
- }
}
+/*
func TestInviteBehaviour(t *testing.T) {
db := MustCreateDatabase(t)
inviteRoom1 := "!inviteRoom1:somewhere"
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index ccdebfdbd..468d26aca 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -18,10 +18,11 @@ import (
"context"
"database/sql"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
- "github.com/matrix-org/gomatrixserverlib"
)
type AccountData interface {
@@ -52,7 +53,14 @@ type Peeks interface {
type Events interface {
SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string) (map[string]map[string]bool, map[string]types.StreamEvent, error)
SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error)
- InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error)
+ InsertEvent(
+ ctx context.Context, txn *sql.Tx,
+ event *gomatrixserverlib.HeaderedEvent,
+ addState, removeState []string,
+ transactionID *api.TransactionID,
+ excludeFromSync bool,
+ historyVisibility gomatrixserverlib.HistoryVisibility,
+ ) (streamPos types.StreamPosition, err error)
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
// Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`.
@@ -104,6 +112,8 @@ type CurrentRoomState interface {
SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
+ // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
+ SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
}
// BackwardsExtremities keeps track of backwards extremities for a room.
@@ -113,12 +123,14 @@ type CurrentRoomState interface {
//
// We persist the previous event IDs as well, one per row, so when we do fetch even
// earlier events we can simply delete rows which referenced it. Consider the graph:
-// A
-// | Event C has 1 prev_event ID: A.
-// B C
-// |___| Event D has 2 prev_event IDs: B and C.
-// |
-// D
+//
+// A
+// | Event C has 1 prev_event ID: A.
+// B C
+// |___| Event D has 2 prev_event IDs: B and C.
+// |
+// D
+//
// The earliest known event we have is D, so this table has 2 rows.
// A backfill request gives us C but not B. We delete rows where prev_event=C. This
// still means that D is a backwards extremity as we do not have event B. However, event
@@ -173,6 +185,7 @@ type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error)
SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error)
+ SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
}
type NotificationData interface {
diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go
index 69bbd04c9..bdb17ae20 100644
--- a/syncapi/storage/tables/output_room_events_test.go
+++ b/syncapi/storage/tables/output_room_events_test.go
@@ -53,7 +53,7 @@ func TestOutputRoomEventsTable(t *testing.T) {
events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
for _, ev := range events {
- _, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
+ _, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared)
if err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
@@ -79,7 +79,7 @@ func TestOutputRoomEventsTable(t *testing.T) {
"body": "test.txt",
"url": "mxc://test.txt",
})
- if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil {
+ if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false, gomatrixserverlib.HistoryVisibilityShared); err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
wantEventID := []string{urlEv.EventID()}
diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go
index f42099510..5448ee5bd 100644
--- a/syncapi/streams/stream_devicelist.go
+++ b/syncapi/streams/stream_devicelist.go
@@ -28,7 +28,7 @@ func (p *DeviceListStreamProvider) IncrementalSync(
from, to types.StreamPosition,
) types.StreamPosition {
var err error
- to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
+ to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from
diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go
index 00b3dfe3b..136cbea5a 100644
--- a/syncapi/streams/stream_pdu.go
+++ b/syncapi/streams/stream_pdu.go
@@ -10,11 +10,17 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/internal"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
+
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"go.uber.org/atomic"
+
+ "github.com/matrix-org/dendrite/syncapi/notifier"
)
// The max number of per-room goroutines to have running.
@@ -34,6 +40,7 @@ type PDUStreamProvider struct {
// userID+deviceID -> lazy loading cache
lazyLoadCache caching.LazyLoadCache
rsAPI roomserverAPI.SyncRoomserverAPI
+ notifier *notifier.Notifier
}
func (p *PDUStreamProvider) worker() {
@@ -100,6 +107,15 @@ func (p *PDUStreamProvider) CompleteSync(
req.Log.WithError(err).Error("unable to update event filter with ignored users")
}
+ // Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars
+ // TODO: This might be inefficient, when joined to many and/or large rooms.
+ for _, roomID := range joinedRoomIDs {
+ joinedUsers := p.notifier.JoinedUsers(roomID)
+ for _, sharedUser := range joinedUsers {
+ p.lazyLoadCache.InvalidateLazyLoadedUser(req.Device, roomID, sharedUser)
+ }
+ }
+
// Build up a /sync response. Add joined rooms.
var reqMutex sync.Mutex
var reqWaitGroup sync.WaitGroup
@@ -109,12 +125,11 @@ func (p *PDUStreamProvider) CompleteSync(
p.queue(func() {
defer reqWaitGroup.Done()
- var jr *types.JoinResponse
- jr, err = p.getJoinResponseForCompleteSync(
- ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device,
+ jr, jerr := p.getJoinResponseForCompleteSync(
+ ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
)
- if err != nil {
- req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
+ if jerr != nil {
+ req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
return
}
@@ -137,7 +152,7 @@ func (p *PDUStreamProvider) CompleteSync(
if !peek.Deleted {
var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync(
- ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device,
+ ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
)
if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
@@ -262,19 +277,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
var pos types.StreamPosition
if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
switch {
- case r.Backwards && pos > latestPosition:
+ case r.Backwards && pos < latestPosition:
fallthrough
- case !r.Backwards && pos < latestPosition:
+ case !r.Backwards && pos > latestPosition:
latestPosition = pos
}
}
}
- if len(recentEvents) > 0 {
- updateLatestPosition(recentEvents[len(recentEvents)-1].EventID())
- }
- if len(delta.StateEvents) > 0 {
- updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
- }
if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers(
@@ -294,6 +303,19 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
}
+ // Applies the history visibility rules
+ events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
+ if err != nil {
+ logrus.WithError(err).Error("unable to apply history visibility filter")
+ }
+
+ if len(events) > 0 {
+ updateLatestPosition(events[len(events)-1].EventID())
+ }
+ if len(delta.StateEvents) > 0 {
+ updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID())
+ }
+
switch delta.Membership {
case gomatrixserverlib.Join:
jr := types.NewJoinResponse()
@@ -301,14 +323,17 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition)
}
jr.Timeline.PrevBatch = &prevBatch
- jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
- jr.Timeline.Limited = limited
+ jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
+ // If we are limited by the filter AND the history visibility filter
+ // didn't "remove" events, return that the response is limited.
+ jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.RoomID] = *jr
case gomatrixserverlib.Peek:
jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = &prevBatch
+ // TODO: Apply history visibility on peeked rooms
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
@@ -318,12 +343,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
fallthrough // transitions to leave are the same as ban
case gomatrixserverlib.Ban:
- // TODO: recentEvents may contain events that this user is not allowed to see because they are
- // no longer in the room.
lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = &prevBatch
- lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
- lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
+ lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
+ // If we are limited by the filter AND the history visibility filter
+ // didn't "remove" events, return that the response is limited.
+ lr.Timeline.Limited = limited && len(events) == len(recentEvents)
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Leave[delta.RoomID] = *lr
}
@@ -331,6 +356,41 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
return latestPosition, nil
}
+// applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make
+// sure we always return the required events in the timeline.
+func applyHistoryVisibilityFilter(
+ ctx context.Context,
+ db storage.Database,
+ rsAPI roomserverAPI.SyncRoomserverAPI,
+ roomID, userID string,
+ limit int,
+ recentEvents []*gomatrixserverlib.HeaderedEvent,
+) ([]*gomatrixserverlib.HeaderedEvent, error) {
+ // We need to make sure we always include the latest states events, if they are in the timeline.
+ // We grep at least limit * 2 events, to ensure we really get the needed events.
+ stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
+ if err != nil {
+ // Not a fatal error, we can continue without the stateEvents,
+ // they are only needed if there are state events in the timeline.
+ logrus.WithError(err).Warnf("failed to get current room state")
+ }
+ alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents))
+ for _, ev := range stateEvents {
+ alwaysIncludeIDs[ev.EventID()] = struct{}{}
+ }
+ startTime := time.Now()
+ events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
+ if err != nil {
+
+ return nil, err
+ }
+ logrus.WithFields(logrus.Fields{
+ "duration": time.Since(startTime),
+ "room_id": roomID,
+ }).Debug("applied history visibility (sync)")
+ return events, nil
+}
+
func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
// Work out how many members are in the room.
joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
@@ -378,6 +438,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
eventFilter *gomatrixserverlib.RoomEventFilter,
wantFullState bool,
device *userapi.Device,
+ isPeek bool,
) (jr *types.JoinResponse, err error) {
jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events.
@@ -392,33 +453,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
return
}
- // TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
- // user shouldn't see, we check the recent events and remove any prior to the join event of the user
- // which is equiv to history_visibility: joined
- joinEventIndex := -1
- for i := len(recentStreamEvents) - 1; i >= 0; i-- {
- ev := recentStreamEvents[i]
- if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(device.UserID) {
- membership, _ := ev.Membership()
- if membership == "join" {
- joinEventIndex = i
- if i > 0 {
- // the create event happens before the first join, so we should cut it at that point instead
- if recentStreamEvents[i-1].Type() == gomatrixserverlib.MRoomCreate && recentStreamEvents[i-1].StateKeyEquals("") {
- joinEventIndex = i - 1
- break
- }
- }
- break
- }
- }
- }
- if joinEventIndex != -1 {
- // cut all events earlier than the join (but not the join itself)
- recentStreamEvents = recentStreamEvents[joinEventIndex:]
- limited = false // so clients know not to try to backpaginate
- }
-
// Work our way through the timeline events and pick out the event IDs
// of any state events that appear in the timeline. We'll specifically
// exclude them at the next step, so that we don't get duplicate state
@@ -462,6 +496,19 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
+ events := recentEvents
+ // Only apply history visibility checks if the response is for joined rooms
+ if !isPeek {
+ events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
+ if err != nil {
+ logrus.WithError(err).Error("unable to apply history visibility filter")
+ }
+ }
+
+ // If we are limited by the filter AND the history visibility filter
+ // didn't "remove" events, return that the response is limited.
+ limited = limited && len(events) == len(recentEvents)
+
if stateFilter.LazyLoadMembers {
if err != nil {
return nil, err
@@ -476,8 +523,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
}
jr.Timeline.PrevBatch = prevBatch
- jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
- jr.Timeline.Limited = limited
+ jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
+ // If we are limited by the filter AND the history visibility filter
+ // didn't "remove" events, return that the response is limited.
+ jr.Timeline.Limited = limited && len(events) == len(recentEvents)
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
return jr, nil
}
diff --git a/syncapi/streams/stream_sendtodevice.go b/syncapi/streams/stream_sendtodevice.go
index 6a18df506..31c6187cb 100644
--- a/syncapi/streams/stream_sendtodevice.go
+++ b/syncapi/streams/stream_sendtodevice.go
@@ -39,21 +39,13 @@ func (p *SendToDeviceStreamProvider) IncrementalSync(
return from
}
- if len(events) > 0 {
- // Clean up old send-to-device messages from before this stream position.
- if err := p.DB.CleanSendToDeviceUpdates(req.Context, req.Device.UserID, req.Device.ID, from); err != nil {
- req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed")
- return from
- }
-
- // Add the updates into the sync response.
- for _, event := range events {
- // skip ignored user events
- if _, ok := req.IgnoredUsers.List[event.Sender]; ok {
- continue
- }
- req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)
+ // Add the updates into the sync response.
+ for _, event := range events {
+ // skip ignored user events
+ if _, ok := req.IgnoredUsers.List[event.Sender]; ok {
+ continue
}
+ req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)
}
return lastPos
diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go
index 1ca4ee8c3..dbc053bd8 100644
--- a/syncapi/streams/streams.go
+++ b/syncapi/streams/streams.go
@@ -34,6 +34,7 @@ func NewSyncStreamProviders(
StreamProvider: StreamProvider{DB: d},
lazyLoadCache: lazyLoadCache,
rsAPI: rsAPI,
+ notifier: notifier,
},
TypingStreamProvider: &TypingStreamProvider{
StreamProvider: StreamProvider{DB: d},
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index 6f0849e08..d908a9629 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -25,6 +25,11 @@ import (
"sync"
"time"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/sirupsen/logrus"
+
"github.com/matrix-org/dendrite/clientapi/jsonerror"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
@@ -35,10 +40,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/sirupsen/logrus"
)
// RequestPool manages HTTP long-poll connections for /sync
@@ -251,6 +252,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
waitingSyncRequests.Inc()
defer waitingSyncRequests.Dec()
+ // Clean up old send-to-device messages from before this stream position.
+ // This is needed to avoid sending the same message multiple times
+ if err = rp.db.CleanSendToDeviceUpdates(syncReq.Context, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Since.SendToDevicePosition); err != nil {
+ syncReq.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed")
+ }
+
// loop until we get some data
for {
startTime := time.Now()
@@ -429,7 +436,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
}
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition)
_, _, err = internal.DeviceListCatchup(
- req.Context(), rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
+ req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
)
if err != nil {
diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go
index 48e6c6c7a..3e5769d8c 100644
--- a/syncapi/sync/requestpool_test.go
+++ b/syncapi/sync/requestpool_test.go
@@ -12,10 +12,13 @@ import (
)
type dummyPublisher struct {
+ lock sync.Mutex
count int
}
func (d *dummyPublisher) SendPresence(userID string, presence types.Presence, statusMsg *string) error {
+ d.lock.Lock()
+ defer d.lock.Unlock()
d.count++
return nil
}
@@ -125,11 +128,15 @@ func TestRequestPool_updatePresence(t *testing.T) {
go rp.cleanPresence(db, time.Millisecond*50)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ publisher.lock.Lock()
beforeCount := publisher.count
+ publisher.lock.Unlock()
rp.updatePresence(db, tt.args.presence, tt.args.userID)
+ publisher.lock.Lock()
if tt.wantIncrease && publisher.count <= beforeCount {
t.Fatalf("expected count to increase: %d <= %d", publisher.count, beforeCount)
}
+ publisher.lock.Unlock()
time.Sleep(tt.args.sleep)
})
}
diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go
index 3ce7c64b7..dc073a16e 100644
--- a/syncapi/syncapi_test.go
+++ b/syncapi/syncapi_test.go
@@ -3,12 +3,16 @@ package syncapi
import (
"context"
"encoding/json"
+ "fmt"
"net/http"
"net/http/httptest"
+ "reflect"
"testing"
"time"
+ "github.com/matrix-org/dendrite/clientapi/producers"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
+ "github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
@@ -51,6 +55,16 @@ func (s *syncRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *rsap
return nil
}
+func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsapi.QueryMembershipForUserRequest, res *rsapi.QueryMembershipForUserResponse) error {
+ res.IsRoomForgotten = false
+ res.RoomExists = true
+ return nil
+}
+
+func (s *syncRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, req *rsapi.QueryMembershipAtEventRequest, res *rsapi.QueryMembershipAtEventResponse) error {
+ return nil
+}
+
type syncUserAPI struct {
userapi.SyncUserAPI
accounts []userapi.Device
@@ -75,10 +89,11 @@ type syncKeyAPI struct {
keyapi.SyncKeyAPI
}
-func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
+func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
+ return nil
}
-func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) {
-
+func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
+ return nil
}
func TestSyncAPIAccessTokens(t *testing.T) {
@@ -103,7 +118,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
- msgs := toNATSMsgs(t, base, room.Events())
+ msgs := toNATSMsgs(t, base, room.Events()...)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
testrig.MustPublishMsgs(t, jsctx, msgs...)
@@ -196,7 +211,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
// m.room.power_levels
// m.room.join_rules
// m.room.history_visibility
- msgs := toNATSMsgs(t, base, room.Events())
+ msgs := toNATSMsgs(t, base, room.Events()...)
sinceTokens := make([]string, len(msgs))
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
for i, msg := range msgs {
@@ -311,7 +326,308 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
}
-func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
+// This is mainly what Sytest is doing in "test_history_visibility"
+func TestMessageHistoryVisibility(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ testHistoryVisibility(t, dbType)
+ })
+}
+
+func testHistoryVisibility(t *testing.T, dbType test.DBType) {
+ type result struct {
+ seeWithoutJoin bool
+ seeBeforeJoin bool
+ seeAfterInvite bool
+ }
+
+ // create the users
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+
+ bobDev := userapi.Device{
+ ID: "BOBID",
+ UserID: bob.ID,
+ AccessToken: "BOD_BEARER_TOKEN",
+ DisplayName: "BOB",
+ }
+
+ ctx := context.Background()
+ // check guest and normal user accounts
+ for _, accType := range []userapi.AccountType{userapi.AccountTypeGuest, userapi.AccountTypeUser} {
+ testCases := []struct {
+ historyVisibility gomatrixserverlib.HistoryVisibility
+ wantResult result
+ }{
+ {
+ historyVisibility: gomatrixserverlib.HistoryVisibilityWorldReadable,
+ wantResult: result{
+ seeWithoutJoin: true,
+ seeBeforeJoin: true,
+ seeAfterInvite: true,
+ },
+ },
+ {
+ historyVisibility: gomatrixserverlib.HistoryVisibilityShared,
+ wantResult: result{
+ seeWithoutJoin: false,
+ seeBeforeJoin: true,
+ seeAfterInvite: true,
+ },
+ },
+ {
+ historyVisibility: gomatrixserverlib.HistoryVisibilityInvited,
+ wantResult: result{
+ seeWithoutJoin: false,
+ seeBeforeJoin: false,
+ seeAfterInvite: true,
+ },
+ },
+ {
+ historyVisibility: gomatrixserverlib.HistoryVisibilityJoined,
+ wantResult: result{
+ seeWithoutJoin: false,
+ seeBeforeJoin: false,
+ seeAfterInvite: false,
+ },
+ },
+ }
+
+ bobDev.AccountType = accType
+ userType := "guest"
+ if accType == userapi.AccountTypeUser {
+ userType = "real user"
+ }
+
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ defer close()
+
+ jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
+ defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
+
+ // Use the actual internal roomserver API
+ rsAPI := roomserver.NewInternalAPI(base)
+ rsAPI.SetFederationAPI(nil, nil)
+
+ AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, rsAPI, &syncKeyAPI{})
+
+ for _, tc := range testCases {
+ testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
+ t.Run(testname, func(t *testing.T) {
+ // create a room with the given visibility
+ room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility))
+
+ // send the events/messages to NATS to create the rooms
+ beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)})
+ eventsToSend := append(room.Events(), beforeJoinEv)
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
+ t.Fatalf("failed to send events: %v", err)
+ }
+
+ // There is only one event, we expect only to be able to see this, if the room is world_readable
+ w := httptest.NewRecorder()
+ base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": bobDev.AccessToken,
+ "dir": "b",
+ })))
+ if w.Code != 200 {
+ t.Logf("%s", w.Body.String())
+ t.Fatalf("got HTTP %d want %d", w.Code, 200)
+ }
+ // We only care about the returned events at this point
+ var res struct {
+ Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
+ }
+ if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
+ t.Errorf("failed to decode response body: %s", err)
+ }
+
+ verifyEventVisible(t, tc.wantResult.seeWithoutJoin, beforeJoinEv, res.Chunk)
+
+ // Create invite, a message, join the room and create another message.
+ inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID))
+ afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)})
+ joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID))
+ msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)})
+
+ eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv)
+
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil {
+ t.Fatalf("failed to send events: %v", err)
+ }
+
+ // Verify the messages after/before invite are visible or not
+ w = httptest.NewRecorder()
+ base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": bobDev.AccessToken,
+ "dir": "b",
+ })))
+ if w.Code != 200 {
+ t.Logf("%s", w.Body.String())
+ t.Fatalf("got HTTP %d want %d", w.Code, 200)
+ }
+ if err := json.NewDecoder(w.Body).Decode(&res); err != nil {
+ t.Errorf("failed to decode response body: %s", err)
+ }
+ // verify results
+ verifyEventVisible(t, tc.wantResult.seeBeforeJoin, beforeJoinEv, res.Chunk)
+ verifyEventVisible(t, tc.wantResult.seeAfterInvite, afterInviteEv, res.Chunk)
+ })
+ }
+ }
+}
+
+func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatrixserverlib.HeaderedEvent, chunk []gomatrixserverlib.ClientEvent) {
+ t.Helper()
+ if wantVisible {
+ for _, ev := range chunk {
+ if ev.EventID == wantVisibleEvent.EventID() {
+ return
+ }
+ }
+ t.Fatalf("expected to see event %s but didn't: %+v", wantVisibleEvent.EventID(), chunk)
+ } else {
+ for _, ev := range chunk {
+ if ev.EventID == wantVisibleEvent.EventID() {
+ t.Fatalf("expected not to see event %s: %+v", wantVisibleEvent.EventID(), string(ev.Content))
+ }
+ }
+ }
+}
+
+func TestSendToDevice(t *testing.T) {
+ test.WithAllDatabases(t, testSendToDevice)
+}
+
+func testSendToDevice(t *testing.T, dbType test.DBType) {
+ user := test.NewUser(t)
+ alice := userapi.Device{
+ ID: "ALICEID",
+ UserID: user.ID,
+ AccessToken: "ALICE_BEARER_TOKEN",
+ DisplayName: "Alice",
+ AccountType: userapi.AccountTypeUser,
+ }
+
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ defer close()
+
+ jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
+ defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
+
+ AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{})
+
+ producer := producers.SyncAPIProducer{
+ TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
+ JetStream: jsctx,
+ }
+
+ msgCounter := 0
+
+ testCases := []struct {
+ name string
+ since string
+ want []string
+ sendMessagesCount int
+ }{
+ {
+ name: "initial sync, no messages",
+ want: []string{},
+ },
+ {
+ name: "initial sync, one new message",
+ sendMessagesCount: 1,
+ want: []string{
+ "message 1",
+ },
+ },
+ {
+ name: "initial sync, two new messages", // we didn't advance the since token, so we'll receive two messages
+ sendMessagesCount: 1,
+ want: []string{
+ "message 1",
+ "message 2",
+ },
+ },
+ {
+ name: "incremental sync, one message", // this deletes message 1, as we advanced the since token
+ since: types.StreamingToken{SendToDevicePosition: 1}.String(),
+ want: []string{
+ "message 2",
+ },
+ },
+ {
+ name: "failed incremental sync, one message", // didn't advance since, so still the same message
+ since: types.StreamingToken{SendToDevicePosition: 1}.String(),
+ want: []string{
+ "message 2",
+ },
+ },
+ {
+ name: "incremental sync, no message", // this should delete message 2
+ since: types.StreamingToken{SendToDevicePosition: 2}.String(), // next_batch from previous sync
+ want: []string{},
+ },
+ {
+ name: "incremental sync, three new messages",
+ since: types.StreamingToken{SendToDevicePosition: 2}.String(),
+ sendMessagesCount: 3,
+ want: []string{
+ "message 3", // message 2 was deleted in the previous test
+ "message 4",
+ "message 5",
+ },
+ },
+ {
+ name: "initial sync, three messages", // we expect three messages, as we didn't go beyond "2"
+ want: []string{
+ "message 3",
+ "message 4",
+ "message 5",
+ },
+ },
+ {
+ name: "incremental sync, no messages", // advance the sync token, no new messages
+ since: types.StreamingToken{SendToDevicePosition: 5}.String(),
+ want: []string{},
+ },
+ }
+
+ ctx := context.Background()
+ for _, tc := range testCases {
+ // Send to-device messages of type "m.dendrite.test" with content `{"dummy":"message $counter"}`
+ for i := 0; i < tc.sendMessagesCount; i++ {
+ msgCounter++
+ msg := map[string]string{
+ "dummy": fmt.Sprintf("message %d", msgCounter),
+ }
+ if err := producer.SendToDevice(ctx, user.ID, user.ID, alice.ID, "m.dendrite.test", msg); err != nil {
+ t.Fatalf("unable to send to device message: %v", err)
+ }
+ }
+ time.Sleep((time.Millisecond * 15) * time.Duration(tc.sendMessagesCount)) // wait a bit, so the messages can be processed
+ // Execute a /sync request, recording the response
+ w := httptest.NewRecorder()
+ base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
+ "access_token": alice.AccessToken,
+ "since": tc.since,
+ })))
+
+ // Extract the to_device.events, # gets all values of an array, in this case a string slice with "message $counter" entries
+ events := gjson.Get(w.Body.String(), "to_device.events.#.content.dummy").Array()
+ got := make([]string, len(events))
+ for i := range events {
+ got[i] = events[i].String()
+ }
+
+ // Ensure the messages we received are as we expect them to be
+ if !reflect.DeepEqual(got, tc.want) {
+ t.Logf("[%s|since=%s]: Sync: %s", tc.name, tc.since, w.Body.String())
+ t.Fatalf("[%s|since=%s]: got: %+v, want: %+v", tc.name, tc.since, got, tc.want)
+ }
+ }
+}
+
+func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg {
result := make([]*nats.Msg, len(input))
for i, ev := range input {
var addsStateIDs []string
@@ -323,6 +639,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli
NewRoomEvent: &rsapi.OutputNewRoomEvent{
Event: ev,
AddsStateEventIDs: addsStateIDs,
+ HistoryVisibility: ev.Visibility,
},
})
}
diff --git a/syncapi/types/types.go b/syncapi/types/types.go
index 159fa08b6..39b085d9c 100644
--- a/syncapi/types/types.go
+++ b/syncapi/types/types.go
@@ -21,9 +21,10 @@ import (
"strconv"
"strings"
- "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
)
var (
@@ -330,23 +331,23 @@ type Response struct {
NextBatch StreamingToken `json:"next_batch"`
AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"`
- } `json:"account_data"`
+ } `json:"account_data,omitempty"`
Presence struct {
Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"`
- } `json:"presence"`
+ } `json:"presence,omitempty"`
Rooms struct {
- Join map[string]JoinResponse `json:"join"`
- Peek map[string]JoinResponse `json:"peek"`
- Invite map[string]InviteResponse `json:"invite"`
- Leave map[string]LeaveResponse `json:"leave"`
- } `json:"rooms"`
+ Join map[string]JoinResponse `json:"join,omitempty"`
+ Peek map[string]JoinResponse `json:"peek,omitempty"`
+ Invite map[string]InviteResponse `json:"invite,omitempty"`
+ Leave map[string]LeaveResponse `json:"leave,omitempty"`
+ } `json:"rooms,omitempty"`
ToDevice struct {
- Events []gomatrixserverlib.SendToDeviceEvent `json:"events"`
- } `json:"to_device"`
+ Events []gomatrixserverlib.SendToDeviceEvent `json:"events,omitempty"`
+ } `json:"to_device,omitempty"`
DeviceLists struct {
Changed []string `json:"changed,omitempty"`
Left []string `json:"left,omitempty"`
- } `json:"device_lists"`
+ } `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
}
diff --git a/sytest-whitelist b/sytest-whitelist
index ea25c75d0..88dfe920b 100644
--- a/sytest-whitelist
+++ b/sytest-whitelist
@@ -110,8 +110,6 @@ Newly joined room is included in an incremental sync
User is offline if they set_presence=offline in their sync
Changes to state are included in an incremental sync
A change to displayname should appear in incremental /sync
-Current state appears in timeline in private history
-Current state appears in timeline in private history with many messages before
Rooms a user is invited to appear in an initial sync
Rooms a user is invited to appear in an incremental sync
Sync can be polled for updates
@@ -458,7 +456,6 @@ After changing password, a different session no longer works by default
Read markers appear in incremental v2 /sync
Read markers appear in initial v2 /sync
Read markers can be updated
-Local users can peek into world_readable rooms by room ID
We can't peek into rooms with shared history_visibility
We can't peek into rooms with invited history_visibility
We can't peek into rooms with joined history_visibility
@@ -719,4 +716,27 @@ registration is idempotent, with username specified
Setting state twice is idempotent
Joining room twice is idempotent
Inbound federation can return missing events for shared visibility
-Inbound federation ignores redactions from invalid servers room > v3
\ No newline at end of file
+Inbound federation ignores redactions from invalid servers room > v3
+Joining room twice is idempotent
+Getting messages going forward is limited for a departed room (SPEC-216)
+m.room.history_visibility == "shared" allows/forbids appropriately for Guest users
+m.room.history_visibility == "invited" allows/forbids appropriately for Guest users
+m.room.history_visibility == "default" allows/forbids appropriately for Guest users
+m.room.history_visibility == "shared" allows/forbids appropriately for Real users
+m.room.history_visibility == "invited" allows/forbids appropriately for Real users
+m.room.history_visibility == "default" allows/forbids appropriately for Real users
+Guest users can sync from world_readable guest_access rooms if joined
+Guest users can sync from shared guest_access rooms if joined
+Guest users can sync from invited guest_access rooms if joined
+Guest users can sync from joined guest_access rooms if joined
+Guest users can sync from default guest_access rooms if joined
+Real users can sync from world_readable guest_access rooms if joined
+Real users can sync from shared guest_access rooms if joined
+Real users can sync from invited guest_access rooms if joined
+Real users can sync from joined guest_access rooms if joined
+Real users can sync from default guest_access rooms if joined
+Only see history_visibility changes on boundaries
+Current state appears in timeline in private history
+Current state appears in timeline in private history with many messages before
+Local users can peek into world_readable rooms by room ID
+Newly joined room includes presence in incremental sync
\ No newline at end of file
diff --git a/test/keys.go b/test/keys.go
index 75e3800e0..327c6ed7b 100644
--- a/test/keys.go
+++ b/test/keys.go
@@ -22,7 +22,6 @@ import (
"encoding/pem"
"errors"
"fmt"
- "io/ioutil"
"math/big"
"os"
"strings"
@@ -144,7 +143,7 @@ func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPat
}
// load the authority key
- dat, err := ioutil.ReadFile(authorityKeyPath)
+ dat, err := os.ReadFile(authorityKeyPath)
if err != nil {
return err
}
@@ -158,7 +157,7 @@ func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPat
}
// load the authority certificate
- dat, err = ioutil.ReadFile(authorityCertPath)
+ dat, err = os.ReadFile(authorityCertPath)
if err != nil {
return err
}
diff --git a/test/room.go b/test/room.go
index 6ae403b3f..94eb51bbe 100644
--- a/test/room.go
+++ b/test/room.go
@@ -37,10 +37,11 @@ var (
)
type Room struct {
- ID string
- Version gomatrixserverlib.RoomVersion
- preset Preset
- creator *User
+ ID string
+ Version gomatrixserverlib.RoomVersion
+ preset Preset
+ visibility gomatrixserverlib.HistoryVisibility
+ creator *User
authEvents gomatrixserverlib.AuthEvents
currentState map[string]*gomatrixserverlib.HeaderedEvent
@@ -61,6 +62,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
preset: PresetPublicChat,
Version: gomatrixserverlib.RoomVersionV9,
currentState: make(map[string]*gomatrixserverlib.HeaderedEvent),
+ visibility: gomatrixserverlib.HistoryVisibilityShared,
}
for _, m := range modifiers {
m(t, r)
@@ -97,10 +99,14 @@ func (r *Room) insertCreateEvents(t *testing.T) {
fallthrough
case PresetPrivateChat:
joinRule.JoinRule = "invite"
- hisVis.HistoryVisibility = "shared"
+ hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
case PresetPublicChat:
joinRule.JoinRule = "public"
- hisVis.HistoryVisibility = "shared"
+ hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
+ }
+
+ if r.visibility != "" {
+ hisVis.HistoryVisibility = r.visibility
}
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
@@ -183,7 +189,9 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
}
- return ev.Headered(r.Version)
+ headeredEvent := ev.Headered(r.Version)
+ headeredEvent.Visibility = r.visibility
+ return headeredEvent
}
// Add a new event to this room DAG. Not thread-safe.
@@ -242,6 +250,12 @@ func RoomPreset(p Preset) roomModifier {
}
}
+func RoomHistoryVisibility(vis gomatrixserverlib.HistoryVisibility) roomModifier {
+ return func(t *testing.T, r *Room) {
+ r.visibility = vis
+ }
+}
+
func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
return func(t *testing.T, r *Room) {
r.Version = ver
diff --git a/test/testrig/base.go b/test/testrig/base.go
index facb49f3e..d13c43129 100644
--- a/test/testrig/base.go
+++ b/test/testrig/base.go
@@ -32,11 +32,11 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f
var cfg config.Dendrite
cfg.Defaults(false)
cfg.Global.JetStream.InMemory = true
-
switch dbType {
case test.DBTypePostgres:
cfg.Global.Defaults(true) // autogen a signing key
cfg.MediaAPI.Defaults(true) // autogen a media path
+ cfg.Global.ServerName = "test"
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
// the file system event with InMemory=true :(
cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)
@@ -50,6 +50,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f
return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close
case test.DBTypeSQLite:
cfg.Defaults(true) // sets a sqlite db per component
+ cfg.Global.ServerName = "test"
// use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use
// the file system event with InMemory=true :(
cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)
diff --git a/test/user.go b/test/user.go
index 0020098a5..692eae351 100644
--- a/test/user.go
+++ b/test/user.go
@@ -20,6 +20,7 @@ import (
"sync/atomic"
"testing"
+ "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -45,7 +46,8 @@ var (
)
type User struct {
- ID string
+ ID string
+ accountType api.AccountType
// key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
@@ -62,6 +64,12 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve
}
}
+func WithAccountType(accountType api.AccountType) UserOpt {
+ return func(u *User) {
+ u.accountType = accountType
+ }
+}
+
func NewUser(t *testing.T, opts ...UserOpt) *User {
counter := atomic.AddInt64(&userIDCounter, 1)
var u User
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 63daae914..f6a276707 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -101,7 +101,7 @@ type ClientUserAPI interface {
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
- QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
+ QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error
QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error
QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error
diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go
index 256365ea5..3ce304138 100644
--- a/userapi/api/api_trace.go
+++ b/userapi/api/api_trace.go
@@ -94,9 +94,10 @@ func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *Per
util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res))
return err
}
-func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) {
- t.Impl.QueryKeyBackup(ctx, req, res)
+func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error {
+ err := t.Impl.QueryKeyBackup(ctx, req, res)
util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res))
+ return err
}
func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error {
err := t.Impl.QueryProfile(ctx, req, res)
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index fedf2752f..dd0cc78a4 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -30,11 +30,13 @@ import (
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil"
+ "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
+ synctypes "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
@@ -64,7 +66,24 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
- return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
+ if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
+ return fmt.Errorf("failed to save account data: %w", err)
+ }
+ var ignoredUsers *synctypes.IgnoredUsers
+ if req.DataType == "m.ignored_user_list" {
+ ignoredUsers = &synctypes.IgnoredUsers{}
+ _ = json.Unmarshal(req.AccountData, ignoredUsers)
+ }
+ if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{
+ RoomID: req.RoomID,
+ Type: req.DataType,
+ IgnoredUsers: ignoredUsers,
+ }); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("a.SyncProducer.SendAccountData failed")
+ return fmt.Errorf("failed to send account data to output: %w", err)
+ }
+ return nil
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
@@ -93,7 +112,9 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
// Inform the SyncAPI about the newly created push_rules
- if err = a.SyncProducer.SendAccountData(acc.UserID, "", "m.push_rules"); err != nil {
+ if err = a.SyncProducer.SendAccountData(acc.UserID, eventutil.AccountData{
+ Type: "m.push_rules",
+ }); err != nil {
util.GetLogger(ctx).WithFields(logrus.Fields{
"user_id": acc.UserID,
}).WithError(err).Warn("failed to send account data to the SyncAPI")
@@ -171,7 +192,9 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID))
}
deleteRes := &keyapi.PerformDeleteKeysResponse{}
- a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes)
+ if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
+ return err
+ }
if err := deleteRes.Error; err != nil {
return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err)
}
@@ -190,10 +213,12 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er
}
var uploadRes keyapi.PerformUploadKeysResponse
- a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
+ if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: userID,
DeviceKeys: deviceKeys,
- }, &uploadRes)
+ }, &uploadRes); err != nil {
+ return err
+ }
if uploadRes.Error != nil {
return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error)
}
@@ -247,7 +272,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
// display name has changed: update the device key
var uploadRes keyapi.PerformUploadKeysResponse
- a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
+ if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
UserID: req.RequestingUserID,
DeviceKeys: []keyapi.DeviceKeys{
{
@@ -258,7 +283,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
},
},
OnlyDisplayNameUpdates: true,
- }, &uploadRes)
+ }, &uploadRes); err != nil {
+ return err
+ }
if uploadRes.Error != nil {
return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error)
}
@@ -458,7 +485,9 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
}
evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{}
- a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes)
+ if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil {
+ return err
+ }
if err := evacuateRes.Error; err != nil {
logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation")
}
@@ -517,9 +546,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
if req.Version == "" {
res.BadInput = true
res.Error = "must specify a version to delete"
- if res.Error != "" {
- return fmt.Errorf(res.Error)
- }
return nil
}
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
@@ -528,9 +554,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
res.Exists = exists
res.Version = req.Version
- if res.Error != "" {
- return fmt.Errorf(res.Error)
- }
return nil
}
// Create metadata
@@ -541,9 +564,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
res.Exists = err == nil
res.Version = version
- if res.Error != "" {
- return fmt.Errorf(res.Error)
- }
return nil
}
// Update metadata
@@ -554,16 +574,10 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
res.Exists = err == nil
res.Version = req.Version
- if res.Error != "" {
- return fmt.Errorf(res.Error)
- }
return nil
}
// Upload Keys for a specific version metadata
a.uploadBackupKeys(ctx, req, res)
- if res.Error != "" {
- return fmt.Errorf(res.Error)
- }
return nil
}
@@ -606,16 +620,16 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
res.KeyETag = etag
}
-func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
+func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error {
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
res.Exists = false
- return
+ return nil
}
res.Error = fmt.Sprintf("failed to query key backup: %s", err)
- return
+ return nil
}
res.Algorithm = algorithm
res.AuthData = authData
@@ -627,15 +641,16 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
- return
+ return nil
}
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
- return
+ return nil
}
res.Keys = result
+ return nil
}
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
@@ -732,11 +747,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil {
return err
}
-
- if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil {
+ if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{
+ Type: pushRulesAccountDataType,
+ }); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed")
}
-
return nil
}
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index 9f11f2353..f86b5a896 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP APIs
@@ -85,11 +84,10 @@ type httpUserInternalAPI struct {
}
func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData")
- defer span.Finish()
-
- apiURL := h.apiURL + InputAccountDataPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+ return httputil.CallInternalRPCAPI(
+ "InputAccountData", h.apiURL+InputAccountDataPath,
+ h.httpClient, ctx, req, res,
+ )
}
func (h *httpUserInternalAPI) PerformAccountCreation(
@@ -97,11 +95,10 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
request *api.PerformAccountCreationRequest,
response *api.PerformAccountCreationResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountCreation")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformAccountCreationPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformAccountCreation", h.apiURL+PerformAccountCreationPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) PerformPasswordUpdate(
@@ -109,11 +106,10 @@ func (h *httpUserInternalAPI) PerformPasswordUpdate(
request *api.PerformPasswordUpdateRequest,
response *api.PerformPasswordUpdateResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformPasswordUpdatePath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformPasswordUpdate", h.apiURL+PerformPasswordUpdatePath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) PerformDeviceCreation(
@@ -121,11 +117,10 @@ func (h *httpUserInternalAPI) PerformDeviceCreation(
request *api.PerformDeviceCreationRequest,
response *api.PerformDeviceCreationResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceCreation")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformDeviceCreationPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformDeviceCreation", h.apiURL+PerformDeviceCreationPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) PerformDeviceDeletion(
@@ -133,47 +128,54 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion(
request *api.PerformDeviceDeletionRequest,
response *api.PerformDeviceDeletionResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformDeviceDeletionPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformDeviceDeletion", h.apiURL+PerformDeviceDeletionPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) PerformLastSeenUpdate(
ctx context.Context,
- req *api.PerformLastSeenUpdateRequest,
- res *api.PerformLastSeenUpdateResponse,
+ request *api.PerformLastSeenUpdateRequest,
+ response *api.PerformLastSeenUpdateResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLastSeen")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformLastSeenUpdatePath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+ return httputil.CallInternalRPCAPI(
+ "PerformLastSeen", h.apiURL+PerformLastSeenUpdatePath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformDeviceUpdatePath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) PerformDeviceUpdate(
+ ctx context.Context,
+ request *api.PerformDeviceUpdateRequest,
+ response *api.PerformDeviceUpdateResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformDeviceUpdate", h.apiURL+PerformDeviceUpdatePath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountDeactivation")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformAccountDeactivationPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) PerformAccountDeactivation(
+ ctx context.Context,
+ request *api.PerformAccountDeactivationRequest,
+ response *api.PerformAccountDeactivationResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformAccountDeactivation", h.apiURL+PerformAccountDeactivationPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformOpenIDTokenCreationPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(
+ ctx context.Context,
+ request *api.PerformOpenIDTokenCreationRequest,
+ response *api.PerformOpenIDTokenCreationResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformOpenIDTokenCreation", h.apiURL+PerformOpenIDTokenCreationPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) QueryProfile(
@@ -181,11 +183,10 @@ func (h *httpUserInternalAPI) QueryProfile(
request *api.QueryProfileRequest,
response *api.QueryProfileResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryProfile")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryProfilePath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryProfile", h.apiURL+QueryProfilePath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) QueryDeviceInfos(
@@ -193,11 +194,10 @@ func (h *httpUserInternalAPI) QueryDeviceInfos(
request *api.QueryDeviceInfosRequest,
response *api.QueryDeviceInfosResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryDeviceInfosPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryDeviceInfos", h.apiURL+QueryDeviceInfosPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) QueryAccessToken(
@@ -205,72 +205,87 @@ func (h *httpUserInternalAPI) QueryAccessToken(
request *api.QueryAccessTokenRequest,
response *api.QueryAccessTokenResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccessToken")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryAccessTokenPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryAccessToken", h.apiURL+QueryAccessTokenPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDevices")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryDevicesPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QueryDevices(
+ ctx context.Context,
+ request *api.QueryDevicesRequest,
+ response *api.QueryDevicesResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryDevices", h.apiURL+QueryDevicesPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccountData")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryAccountDataPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QueryAccountData(
+ ctx context.Context,
+ request *api.QueryAccountDataRequest,
+ response *api.QueryAccountDataResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryAccountData", h.apiURL+QueryAccountDataPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles")
- defer span.Finish()
-
- apiURL := h.apiURL + QuerySearchProfilesPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QuerySearchProfiles(
+ ctx context.Context,
+ request *api.QuerySearchProfilesRequest,
+ response *api.QuerySearchProfilesResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QuerySearchProfiles", h.apiURL+QuerySearchProfilesPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryOpenIDTokenPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QueryOpenIDToken(
+ ctx context.Context,
+ request *api.QueryOpenIDTokenRequest,
+ response *api.QueryOpenIDTokenResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryOpenIDToken", h.apiURL+QueryOpenIDTokenPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformKeyBackupPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
- if err != nil {
- res.Error = err.Error()
- }
- return nil
-}
-func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryKeyBackupPath
- err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
- if err != nil {
- res.Error = err.Error()
- }
+func (h *httpUserInternalAPI) PerformKeyBackup(
+ ctx context.Context,
+ request *api.PerformKeyBackupRequest,
+ response *api.PerformKeyBackupResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformKeyBackup", h.apiURL+PerformKeyBackupPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications")
- defer span.Finish()
+func (h *httpUserInternalAPI) QueryKeyBackup(
+ ctx context.Context,
+ request *api.QueryKeyBackupRequest,
+ response *api.QueryKeyBackupResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryKeyBackup", h.apiURL+QueryKeyBackupPath,
+ h.httpClient, ctx, request, response,
+ )
+}
- return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res)
+func (h *httpUserInternalAPI) QueryNotifications(
+ ctx context.Context,
+ request *api.QueryNotificationsRequest,
+ response *api.QueryNotificationsResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryNotifications", h.apiURL+QueryNotificationsPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) PerformPusherSet(
@@ -278,27 +293,32 @@ func (h *httpUserInternalAPI) PerformPusherSet(
request *api.PerformPusherSetRequest,
response *struct{},
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformPusherSetPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformPusherSet", h.apiURL+PerformPusherSetPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformPusherDeletionPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) PerformPusherDeletion(
+ ctx context.Context,
+ request *api.PerformPusherDeletionRequest,
+ response *struct{},
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformPusherDeletion", h.apiURL+PerformPusherDeletionPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryPushersPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QueryPushers(
+ ctx context.Context,
+ request *api.QueryPushersRequest,
+ response *api.QueryPushersResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryPushers", h.apiURL+QueryPushersPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) PerformPushRulesPut(
@@ -306,91 +326,119 @@ func (h *httpUserInternalAPI) PerformPushRulesPut(
request *api.PerformPushRulesPutRequest,
response *struct{},
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformPushRulesPutPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformPushRulesPut", h.apiURL+PerformPushRulesPutPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryPushRulesPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QueryPushRules(
+ ctx context.Context,
+ request *api.QueryPushRulesRequest,
+ response *api.QueryPushRulesResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryPushRules", h.apiURL+QueryPushRulesPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetAvatarURLPath)
- defer span.Finish()
-
- apiURL := h.apiURL + PerformSetAvatarURLPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) SetAvatarURL(
+ ctx context.Context,
+ request *api.PerformSetAvatarURLRequest,
+ response *api.PerformSetAvatarURLResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "SetAvatarURL", h.apiURL+PerformSetAvatarURLPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, QueryNumericLocalpartPath)
- defer span.Finish()
-
- apiURL := h.apiURL + QueryNumericLocalpartPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, struct{}{}, res)
+func (h *httpUserInternalAPI) QueryNumericLocalpart(
+ ctx context.Context,
+ response *api.QueryNumericLocalpartResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
+ h.httpClient, ctx, &struct{}{}, response,
+ )
}
-func (h *httpUserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountAvailabilityPath)
- defer span.Finish()
-
- apiURL := h.apiURL + QueryAccountAvailabilityPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QueryAccountAvailability(
+ ctx context.Context,
+ request *api.QueryAccountAvailabilityRequest,
+ response *api.QueryAccountAvailabilityResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryAccountAvailability", h.apiURL+QueryAccountAvailabilityPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountByPasswordPath)
- defer span.Finish()
-
- apiURL := h.apiURL + QueryAccountByPasswordPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QueryAccountByPassword(
+ ctx context.Context,
+ request *api.QueryAccountByPasswordRequest,
+ response *api.QueryAccountByPasswordResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryAccountByPassword", h.apiURL+QueryAccountByPasswordPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *struct{}) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetDisplayNamePath)
- defer span.Finish()
-
- apiURL := h.apiURL + PerformSetDisplayNamePath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) SetDisplayName(
+ ctx context.Context,
+ request *api.PerformUpdateDisplayNameRequest,
+ response *struct{},
+) error {
+ return httputil.CallInternalRPCAPI(
+ "SetDisplayName", h.apiURL+PerformSetDisplayNamePath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForThreePIDPath)
- defer span.Finish()
-
- apiURL := h.apiURL + QueryLocalpartForThreePIDPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QueryLocalpartForThreePID(
+ ctx context.Context,
+ request *api.QueryLocalpartForThreePIDRequest,
+ response *api.QueryLocalpartForThreePIDResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryLocalpartForThreePID", h.apiURL+QueryLocalpartForThreePIDPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, QueryThreePIDsForLocalpartPath)
- defer span.Finish()
-
- apiURL := h.apiURL + QueryThreePIDsForLocalpartPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(
+ ctx context.Context,
+ request *api.QueryThreePIDsForLocalpartRequest,
+ response *api.QueryThreePIDsForLocalpartResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryThreePIDsForLocalpart", h.apiURL+QueryThreePIDsForLocalpartPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetThreePIDPath)
- defer span.Finish()
-
- apiURL := h.apiURL + PerformForgetThreePIDPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) PerformForgetThreePID(
+ ctx context.Context,
+ request *api.PerformForgetThreePIDRequest,
+ response *struct{},
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformForgetThreePID", h.apiURL+PerformForgetThreePIDPath,
+ h.httpClient, ctx, request, response,
+ )
}
-func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveThreePIDAssociationPath)
- defer span.Finish()
-
- apiURL := h.apiURL + PerformSaveThreePIDAssociationPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
+ ctx context.Context,
+ request *api.PerformSaveThreePIDAssociationRequest,
+ response *struct{},
+) error {
+ return httputil.CallInternalRPCAPI(
+ "PerformSaveThreePIDAssociation", h.apiURL+PerformSaveThreePIDAssociationPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) DeleteProfile(ctx context.Context, req *api.PerformDeleteProfileRequest, res *struct{}) error {
diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go
index 366a97099..211b1b7a1 100644
--- a/userapi/inthttp/client_logintoken.go
+++ b/userapi/inthttp/client_logintoken.go
@@ -19,7 +19,6 @@ import (
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/opentracing/opentracing-go"
)
const (
@@ -33,11 +32,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenCreation(
request *api.PerformLoginTokenCreationRequest,
response *api.PerformLoginTokenCreationResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformLoginTokenCreationPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformLoginTokenCreation", h.apiURL+PerformLoginTokenCreationPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) PerformLoginTokenDeletion(
@@ -45,11 +43,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenDeletion(
request *api.PerformLoginTokenDeletionRequest,
response *api.PerformLoginTokenDeletionResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion")
- defer span.Finish()
-
- apiURL := h.apiURL + PerformLoginTokenDeletionPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "PerformLoginTokenDeletion", h.apiURL+PerformLoginTokenDeletionPath,
+ h.httpClient, ctx, request, response,
+ )
}
func (h *httpUserInternalAPI) QueryLoginToken(
@@ -57,9 +54,8 @@ func (h *httpUserInternalAPI) QueryLoginToken(
request *api.QueryLoginTokenRequest,
response *api.QueryLoginTokenResponse,
) error {
- span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken")
- defer span.Finish()
-
- apiURL := h.apiURL + QueryLoginTokenPath
- return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+ return httputil.CallInternalRPCAPI(
+ "QueryLoginToken", h.apiURL+QueryLoginTokenPath,
+ h.httpClient, ctx, request, response,
+ )
}
diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go
index ee9429198..11971d80b 100644
--- a/userapi/inthttp/server.go
+++ b/userapi/inthttp/server.go
@@ -15,8 +15,6 @@
package inthttp
import (
- "encoding/json"
- "fmt"
"net/http"
"github.com/gorilla/mux"
@@ -29,339 +27,134 @@ import (
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
addRoutesLoginToken(internalAPIMux, s)
- internalAPIMux.Handle(PerformAccountCreationPath,
- httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
- request := api.PerformAccountCreationRequest{}
- response := api.PerformAccountCreationResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformAccountCreation(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(PerformPasswordUpdatePath,
- httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
- request := api.PerformPasswordUpdateRequest{}
- response := api.PerformPasswordUpdateResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(PerformDeviceCreationPath,
- httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
- request := api.PerformDeviceCreationRequest{}
- response := api.PerformDeviceCreationResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformDeviceCreation(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(PerformLastSeenUpdatePath,
- httputil.MakeInternalAPI("performLastSeenUpdate", func(req *http.Request) util.JSONResponse {
- request := api.PerformLastSeenUpdateRequest{}
- response := api.PerformLastSeenUpdateResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformLastSeenUpdate(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(PerformDeviceUpdatePath,
- httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse {
- request := api.PerformDeviceUpdateRequest{}
- response := api.PerformDeviceUpdateResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(PerformDeviceDeletionPath,
- httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse {
- request := api.PerformDeviceDeletionRequest{}
- response := api.PerformDeviceDeletionResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformDeviceDeletion(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(PerformAccountDeactivationPath,
- httputil.MakeInternalAPI("performAccountDeactivation", func(req *http.Request) util.JSONResponse {
- request := api.PerformAccountDeactivationRequest{}
- response := api.PerformAccountDeactivationResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformAccountDeactivation(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(PerformOpenIDTokenCreationPath,
- httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse {
- request := api.PerformOpenIDTokenCreationRequest{}
- response := api.PerformOpenIDTokenCreationResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(QueryProfilePath,
- httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
- request := api.QueryProfileRequest{}
- response := api.QueryProfileResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryProfile(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(QueryAccessTokenPath,
- httputil.MakeInternalAPI("queryAccessToken", func(req *http.Request) util.JSONResponse {
- request := api.QueryAccessTokenRequest{}
- response := api.QueryAccessTokenResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryAccessToken(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(QueryDevicesPath,
- httputil.MakeInternalAPI("queryDevices", func(req *http.Request) util.JSONResponse {
- request := api.QueryDevicesRequest{}
- response := api.QueryDevicesResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryDevices(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(QueryAccountDataPath,
- httputil.MakeInternalAPI("queryAccountData", func(req *http.Request) util.JSONResponse {
- request := api.QueryAccountDataRequest{}
- response := api.QueryAccountDataResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryAccountData(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(QueryDeviceInfosPath,
- httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse {
- request := api.QueryDeviceInfosRequest{}
- response := api.QueryDeviceInfosResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(QuerySearchProfilesPath,
- httputil.MakeInternalAPI("querySearchProfiles", func(req *http.Request) util.JSONResponse {
- request := api.QuerySearchProfilesRequest{}
- response := api.QuerySearchProfilesResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QuerySearchProfiles(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(QueryOpenIDTokenPath,
- httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse {
- request := api.QueryOpenIDTokenRequest{}
- response := api.QueryOpenIDTokenResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(InputAccountDataPath,
- httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
- request := api.InputAccountDataRequest{}
- response := api.InputAccountDataResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.InputAccountData(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(QueryKeyBackupPath,
- httputil.MakeInternalAPI("queryKeyBackup", func(req *http.Request) util.JSONResponse {
- request := api.QueryKeyBackupRequest{}
- response := api.QueryKeyBackupResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- s.QueryKeyBackup(req.Context(), &request, &response)
- if response.Error != "" {
- return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", response.Error))
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(PerformKeyBackupPath,
- httputil.MakeInternalAPI("performKeyBackup", func(req *http.Request) util.JSONResponse {
- request := api.PerformKeyBackupRequest{}
- response := api.PerformKeyBackupResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- err := s.PerformKeyBackup(req.Context(), &request, &response)
- if err != nil {
- return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response}
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(QueryNotificationsPath,
- httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse {
- var request api.QueryNotificationsRequest
- var response api.QueryNotificationsResponse
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryNotifications(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ internalAPIMux.Handle(
+ PerformAccountCreationPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", s.PerformAccountCreation),
)
- internalAPIMux.Handle(PerformPusherSetPath,
- httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse {
- request := api.PerformPusherSetRequest{}
- response := struct{}{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
- )
- internalAPIMux.Handle(PerformPusherDeletionPath,
- httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse {
- request := api.PerformPusherDeletionRequest{}
- response := struct{}{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ internalAPIMux.Handle(
+ PerformPasswordUpdatePath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", s.PerformPasswordUpdate),
)
- internalAPIMux.Handle(QueryPushersPath,
- httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse {
- request := api.QueryPushersRequest{}
- response := api.QueryPushersResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryPushers(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ internalAPIMux.Handle(
+ PerformDeviceCreationPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", s.PerformDeviceCreation),
)
- internalAPIMux.Handle(PerformPushRulesPutPath,
- httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse {
- request := api.PerformPushRulesPutRequest{}
- response := struct{}{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ internalAPIMux.Handle(
+ PerformLastSeenUpdatePath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", s.PerformLastSeenUpdate),
)
- internalAPIMux.Handle(QueryPushRulesPath,
- httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse {
- request := api.QueryPushRulesRequest{}
- response := api.QueryPushRulesResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryPushRules(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ internalAPIMux.Handle(
+ PerformDeviceUpdatePath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", s.PerformDeviceUpdate),
)
- internalAPIMux.Handle(PerformSetAvatarURLPath,
- httputil.MakeInternalAPI("performSetAvatarURL", func(req *http.Request) util.JSONResponse {
- request := api.PerformSetAvatarURLRequest{}
- response := api.PerformSetAvatarURLResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.SetAvatarURL(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ PerformDeviceDeletionPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", s.PerformDeviceDeletion),
)
+
+ internalAPIMux.Handle(
+ PerformAccountDeactivationPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", s.PerformAccountDeactivation),
+ )
+
+ internalAPIMux.Handle(
+ PerformOpenIDTokenCreationPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", s.PerformOpenIDTokenCreation),
+ )
+
+ internalAPIMux.Handle(
+ QueryProfilePath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryProfile", s.QueryProfile),
+ )
+
+ internalAPIMux.Handle(
+ QueryAccessTokenPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", s.QueryAccessToken),
+ )
+
+ internalAPIMux.Handle(
+ QueryDevicesPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryDevices", s.QueryDevices),
+ )
+
+ internalAPIMux.Handle(
+ QueryAccountDataPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", s.QueryAccountData),
+ )
+
+ internalAPIMux.Handle(
+ QueryDeviceInfosPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", s.QueryDeviceInfos),
+ )
+
+ internalAPIMux.Handle(
+ QuerySearchProfilesPath,
+ httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", s.QuerySearchProfiles),
+ )
+
+ internalAPIMux.Handle(
+ QueryOpenIDTokenPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", s.QueryOpenIDToken),
+ )
+
+ internalAPIMux.Handle(
+ InputAccountDataPath,
+ httputil.MakeInternalRPCAPI("UserAPIInputAccountData", s.InputAccountData),
+ )
+
+ internalAPIMux.Handle(
+ QueryKeyBackupPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", s.QueryKeyBackup),
+ )
+
+ internalAPIMux.Handle(
+ PerformKeyBackupPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", s.PerformKeyBackup),
+ )
+
+ internalAPIMux.Handle(
+ QueryNotificationsPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", s.QueryNotifications),
+ )
+
+ internalAPIMux.Handle(
+ PerformPusherSetPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", s.PerformPusherSet),
+ )
+
+ internalAPIMux.Handle(
+ PerformPusherDeletionPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", s.PerformPusherDeletion),
+ )
+
+ internalAPIMux.Handle(
+ QueryPushersPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryPushers", s.QueryPushers),
+ )
+
+ internalAPIMux.Handle(
+ PerformPushRulesPutPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", s.PerformPushRulesPut),
+ )
+
+ internalAPIMux.Handle(
+ QueryPushRulesPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", s.QueryPushRules),
+ )
+
+ internalAPIMux.Handle(
+ PerformSetAvatarURLPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
+ )
+
+ // TODO: Look at the shape of this
internalAPIMux.Handle(QueryNumericLocalpartPath,
- httputil.MakeInternalAPI("queryNumericLocalpart", func(req *http.Request) util.JSONResponse {
+ httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
response := api.QueryNumericLocalpartResponse{}
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
return util.ErrorResponse(err)
@@ -369,93 +162,40 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
- internalAPIMux.Handle(QueryAccountAvailabilityPath,
- httputil.MakeInternalAPI("queryAccountAvailability", func(req *http.Request) util.JSONResponse {
- request := api.QueryAccountAvailabilityRequest{}
- response := api.QueryAccountAvailabilityResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryAccountAvailability(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QueryAccountAvailabilityPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", s.QueryAccountAvailability),
)
- internalAPIMux.Handle(QueryAccountByPasswordPath,
- httputil.MakeInternalAPI("queryAccountByPassword", func(req *http.Request) util.JSONResponse {
- request := api.QueryAccountByPasswordRequest{}
- response := api.QueryAccountByPasswordResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryAccountByPassword(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QueryAccountByPasswordPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", s.QueryAccountByPassword),
)
- internalAPIMux.Handle(PerformSetDisplayNamePath,
- httputil.MakeInternalAPI("performSetDisplayName", func(req *http.Request) util.JSONResponse {
- request := api.PerformUpdateDisplayNameRequest{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.SetDisplayName(req.Context(), &request, &struct{}{}); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
- }),
+
+ internalAPIMux.Handle(
+ PerformSetDisplayNamePath,
+ httputil.MakeInternalRPCAPI("UserAPISetDisplayName", s.SetDisplayName),
)
- internalAPIMux.Handle(QueryLocalpartForThreePIDPath,
- httputil.MakeInternalAPI("queryLocalpartForThreePID", func(req *http.Request) util.JSONResponse {
- request := api.QueryLocalpartForThreePIDRequest{}
- response := api.QueryLocalpartForThreePIDResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryLocalpartForThreePID(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QueryLocalpartForThreePIDPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", s.QueryLocalpartForThreePID),
)
- internalAPIMux.Handle(QueryThreePIDsForLocalpartPath,
- httputil.MakeInternalAPI("queryThreePIDsForLocalpart", func(req *http.Request) util.JSONResponse {
- request := api.QueryThreePIDsForLocalpartRequest{}
- response := api.QueryThreePIDsForLocalpartResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryThreePIDsForLocalpart(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QueryThreePIDsForLocalpartPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", s.QueryThreePIDsForLocalpart),
)
- internalAPIMux.Handle(PerformForgetThreePIDPath,
- httputil.MakeInternalAPI("performForgetThreePID", func(req *http.Request) util.JSONResponse {
- request := api.PerformForgetThreePIDRequest{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformForgetThreePID(req.Context(), &request, &struct{}{}); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
- }),
+
+ internalAPIMux.Handle(
+ PerformForgetThreePIDPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", s.PerformForgetThreePID),
)
- internalAPIMux.Handle(PerformSaveThreePIDAssociationPath,
- httputil.MakeInternalAPI("performSaveThreePIDAssociation", func(req *http.Request) util.JSONResponse {
- request := api.PerformSaveThreePIDAssociationRequest{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformSaveThreePIDAssociation(req.Context(), &request, &struct{}{}); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}}
- }),
+
+ internalAPIMux.Handle(
+ PerformSaveThreePIDAssociationPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", s.PerformSaveThreePIDAssociation),
)
internalAPIMux.Handle(PerformDeleteUserProfilePath,
httputil.MakeInternalAPI("performDeleteUserProfilePath", func(req *http.Request) util.JSONResponse {
diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go
index 1f2eb34b9..b57348413 100644
--- a/userapi/inthttp/server_logintoken.go
+++ b/userapi/inthttp/server_logintoken.go
@@ -15,54 +15,25 @@
package inthttp
import (
- "encoding/json"
- "net/http"
-
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/util"
)
// addRoutesLoginToken adds routes for all login token API calls.
func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) {
- internalAPIMux.Handle(PerformLoginTokenCreationPath,
- httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse {
- request := api.PerformLoginTokenCreationRequest{}
- response := api.PerformLoginTokenCreationResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ internalAPIMux.Handle(
+ PerformLoginTokenCreationPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", s.PerformLoginTokenCreation),
)
- internalAPIMux.Handle(PerformLoginTokenDeletionPath,
- httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse {
- request := api.PerformLoginTokenDeletionRequest{}
- response := api.PerformLoginTokenDeletionResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ PerformLoginTokenDeletionPath,
+ httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", s.PerformLoginTokenDeletion),
)
- internalAPIMux.Handle(QueryLoginTokenPath,
- httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse {
- request := api.QueryLoginTokenRequest{}
- response := api.QueryLoginTokenResponse{}
- if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
- return util.MessageResponse(http.StatusBadRequest, err.Error())
- }
- if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+
+ internalAPIMux.Handle(
+ QueryLoginTokenPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", s.QueryLoginToken),
)
}
diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go
index 4a206f333..27cfc2848 100644
--- a/userapi/producers/syncapi.go
+++ b/userapi/producers/syncapi.go
@@ -34,7 +34,7 @@ func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic stri
}
// SendAccountData sends account data to the Sync API server.
-func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error {
+func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) error {
m := &nats.Msg{
Subject: p.clientDataTopic,
Header: nats.Header{},
@@ -42,18 +42,15 @@ func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string)
m.Header.Set(jetstream.UserID, userID)
var err error
- m.Data, err = json.Marshal(eventutil.AccountData{
- RoomID: roomID,
- Type: dataType,
- })
+ m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
- "room_id": roomID,
- "data_type": dataType,
+ "room_id": data.RoomID,
+ "data_type": data.Type,
}).Tracef("Producing to topic '%s'", p.clientDataTopic)
_, err = p.producer.PublishMsg(m)
diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go
index e3cab56ee..33fb6dd42 100644
--- a/userapi/storage/postgres/accounts_table.go
+++ b/userapi/storage/postgres/accounts_table.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
@@ -85,6 +86,23 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
if err != nil {
return nil, err
}
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations([]sqlutil.Migration{
+ {
+ Version: "userapi: add is active",
+ Up: deltas.UpIsActive,
+ Down: deltas.DownIsActive,
+ },
+ {
+ Version: "userapi: add account type",
+ Up: deltas.UpAddAccountType,
+ Down: deltas.DownAddAccountType,
+ },
+ }...)
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL},
diff --git a/userapi/storage/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go
index 32d3235be..24f87e073 100644
--- a/userapi/storage/postgres/deltas/20200929203058_is_active.go
+++ b/userapi/storage/postgres/deltas/20200929203058_is_active.go
@@ -1,33 +1,21 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/pressly/goose"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func LoadFromGoose() {
- goose.AddMigration(UpIsActive, DownIsActive)
- goose.AddMigration(UpAddAccountType, DownAddAccountType)
-}
-
-func LoadIsActive(m *sqlutil.Migrations) {
- m.AddMigration(UpIsActive, DownIsActive)
-}
-
-func UpIsActive(tx *sql.Tx) error {
- _, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
+func UpIsActive(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
-func DownIsActive(tx *sql.Tx) error {
- _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN is_deactivated;")
+func DownIsActive(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
diff --git a/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go
index 1bbb0a9d3..edd3353f0 100644
--- a/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go
+++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go
@@ -1,18 +1,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func LoadLastSeenTSIP(m *sqlutil.Migrations) {
- m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
-}
-
-func UpLastSeenTSIP(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
@@ -22,8 +17,8 @@ ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
return nil
}
-func DownLastSeenTSIP(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices DROP COLUMN last_seen_ts;
ALTER TABLE device_devices DROP COLUMN ip;
ALTER TABLE device_devices DROP COLUMN user_agent;`)
diff --git a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go
index 2fae00cb9..eb7c3a958 100644
--- a/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go
+++ b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go
@@ -1,20 +1,15 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func LoadAddAccountType(m *sqlutil.Migrations) {
- m.AddMigration(UpAddAccountType, DownAddAccountType)
-}
-
-func UpAddAccountType(tx *sql.Tx) error {
+func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
- _, err := tx.Exec(`ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
+ _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
@@ -25,8 +20,8 @@ ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
return nil
}
-func DownAddAccountType(tx *sql.Tx) error {
- _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN account_type;")
+func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN account_type;")
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index ccb776672..f65681aae 100644
--- a/userapi/storage/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -120,6 +121,15 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
if err != nil {
return nil, err
}
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "userapi: add last_seen_ts",
+ Up: deltas.UpLastSeenTSIP,
+ })
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
return s, sqlutil.StatementList{
{&s.insertDeviceStmt, insertDeviceSQL},
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index c70122d65..7d3b9b6a5 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver.
@@ -37,23 +36,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err
}
- m := sqlutil.NewMigrations()
- if _, err = db.Exec(accountsSchema); err != nil {
- // do this so that the migration can and we don't fail on
- // preparing statements for columns that don't exist yet
- return nil, err
- }
- if _, err = db.Exec(profilesSchema); err != nil {
- return nil, err
- }
- deltas.LoadIsActive(m)
- //deltas.LoadLastSeenTSIP(m)
- deltas.LoadAddAccountType(m)
- deltas.LoadProfilePrimaryKey(m, serverName)
- if err = m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
- }
-
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go
index 6c5fe3071..484e90056 100644
--- a/userapi/storage/sqlite3/accounts_table.go
+++ b/userapi/storage/sqlite3/accounts_table.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
@@ -87,6 +88,23 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
if err != nil {
return nil, err
}
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations([]sqlutil.Migration{
+ {
+ Version: "userapi: add is active",
+ Up: deltas.UpIsActive,
+ Down: deltas.DownIsActive,
+ },
+ {
+ Version: "userapi: add account type",
+ Up: deltas.UpAddAccountType,
+ Down: deltas.DownAddAccountType,
+ },
+ }...)
+ err = m.Up(context.Background())
+ if err != nil {
+ return nil, err
+ }
return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL},
diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
index c69614e83..e25efc695 100644
--- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
+++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
@@ -1,25 +1,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/pressly/goose"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func LoadFromGoose() {
- goose.AddMigration(UpIsActive, DownIsActive)
- goose.AddMigration(UpAddAccountType, DownAddAccountType)
-}
-
-func LoadIsActive(m *sqlutil.Migrations) {
- m.AddMigration(UpIsActive, DownIsActive)
-}
-
-func UpIsActive(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpIsActive(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
@@ -42,8 +30,8 @@ DROP TABLE account_accounts_tmp;`)
return nil
}
-func DownIsActive(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownIsActive(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
index ebf908001..7f7e95d2d 100644
--- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
+++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
@@ -1,18 +1,13 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func LoadLastSeenTSIP(m *sqlutil.Migrations) {
- m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
-}
-
-func UpLastSeenTSIP(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE device_devices (
access_token TEXT PRIMARY KEY,
@@ -39,8 +34,8 @@ func UpLastSeenTSIP(tx *sql.Tx) error {
return nil
}
-func DownLastSeenTSIP(tx *sql.Tx) error {
- _, err := tx.Exec(`
+func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY,
diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
index 9b058dedd..46532698c 100644
--- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
+++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
@@ -1,26 +1,15 @@
package deltas
import (
+ "context"
"database/sql"
"fmt"
-
- "github.com/pressly/goose"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
)
-func init() {
- goose.AddMigration(UpAddAccountType, DownAddAccountType)
-}
-
-func LoadAddAccountType(m *sqlutil.Migrations) {
- m.AddMigration(UpAddAccountType, DownAddAccountType)
-}
-
-func UpAddAccountType(tx *sql.Tx) error {
+func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4)
- _, err := tx.Exec(`ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
+ _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL,
@@ -45,8 +34,8 @@ DROP TABLE account_accounts_tmp;`)
return nil
}
-func DownAddAccountType(tx *sql.Tx) error {
- _, err := tx.Exec(`ALTER TABLE account_accounts DROP COLUMN account_type;`)
+func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts DROP COLUMN account_type;`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index 93291e6ad..27a7524d6 100644
--- a/userapi/storage/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/userutil"
@@ -107,6 +108,15 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
if err != nil {
return nil, err
}
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "userapi: add last_seen_ts",
+ Up: deltas.UpLastSeenTSIP,
+ })
+ if err = m.Up(context.Background()); err != nil {
+ return nil, err
+ }
+
return s, sqlutil.StatementList{
{&s.insertDeviceStmt, insertDeviceSQL},
{&s.selectDevicesCountStmt, selectDevicesCountSQL},
diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go
index e00ed417b..8aa1746c5 100644
--- a/userapi/storage/sqlite3/stats_table.go
+++ b/userapi/storage/sqlite3/stats_table.go
@@ -20,13 +20,14 @@ import (
"strings"
"time"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/sirupsen/logrus"
)
const userDailyVisitsSchema = `
@@ -297,11 +298,10 @@ func (s *statsStatements) monthlyUsers(ctx context.Context, txn *sql.Tx) (result
return
}
-/* R30Users counts the number of 30 day retained users, defined as:
-- Users who have created their accounts more than 30 days ago
-- Where last seen at most 30 days ago
-- Where account creation and last_seen are > 30 days apart
-*/
+// R30Users counts the number of 30 day retained users, defined as:
+// - Users who have created their accounts more than 30 days ago
+// - Where last seen at most 30 days ago
+// - Where account creation and last_seen are > 30 days apart
func (s *statsStatements) r30Users(ctx context.Context, txn *sql.Tx) (map[string]int64, error) {
stmt := sqlutil.TxStmt(txn, s.countR30UsersStmt)
lastSeenAfter := time.Now().AddDate(0, 0, -30)
@@ -334,7 +334,8 @@ func (s *statsStatements) r30Users(ctx context.Context, txn *sql.Tx) (map[string
return result, rows.Err()
}
-/* R30UsersV2 counts the number of 30 day retained users, defined as users that:
+/*
+R30UsersV2 counts the number of 30 day retained users, defined as users that:
- Appear more than once in the past 60 days
- Have more than 30 days between the most and least recent appearances that occurred in the past 60 days.
*/
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index cacf7e1b5..78b7ce588 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -25,10 +25,6 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/shared"
- "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
-
- // Import the postgres database driver.
- _ "github.com/lib/pq"
)
// NewDatabase creates a new accounts and profiles database
@@ -38,23 +34,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err
}
- m := sqlutil.NewMigrations()
- if _, err = db.Exec(accountsSchema); err != nil {
- // do this so that the migration can and we don't fail on
- // preparing statements for columns that don't exist yet
- return nil, err
- }
- if _, err = db.Exec(profilesSchema); err != nil {
- return nil, err
- }
- deltas.LoadIsActive(m)
- //deltas.LoadLastSeenTSIP(m)
- deltas.LoadAddAccountType(m)
- deltas.LoadProfilePrimaryKey(m, serverName)
- if err = m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
- }
-
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 1dc78f063..cb9a668fa 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -117,16 +117,20 @@ func TestQueryProfile(t *testing.T) {
},
}
- runCases := func(testAPI api.UserInternalAPI) {
+ runCases := func(testAPI api.UserInternalAPI, http bool) {
+ mode := "monolith"
+ if http {
+ mode = "HTTP"
+ }
for _, tc := range testCases {
var gotRes api.QueryProfileResponse
gotErr := testAPI.QueryProfile(context.TODO(), &tc.req, &gotRes)
if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil {
- t.Errorf("QueryProfile error, got %s want %s", gotErr, tc.wantErr)
+ t.Errorf("QueryProfile %s error, got %s want %s", mode, gotErr, tc.wantErr)
continue
}
if !reflect.DeepEqual(tc.wantRes, gotRes) {
- t.Errorf("QueryProfile response got %+v want %+v", gotRes, tc.wantRes)
+ t.Errorf("QueryProfile %s response got %+v want %+v", mode, gotRes, tc.wantRes)
}
}
}
@@ -140,10 +144,10 @@ func TestQueryProfile(t *testing.T) {
if err != nil {
t.Fatalf("failed to create HTTP client")
}
- runCases(httpAPI)
+ runCases(httpAPI, true)
})
t.Run("Monolith", func(t *testing.T) {
- runCases(userAPI)
+ runCases(userAPI, false)
})
}
diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go
index ad93a50e3..b17f62060 100644
--- a/userapi/util/phonehomestats.go
+++ b/userapi/util/phonehomestats.go
@@ -139,7 +139,7 @@ func (p *phoneHomeStats) collect() {
output := bytes.Buffer{}
if err = json.NewEncoder(&output).Encode(p.stats); err != nil {
- logrus.WithError(err).Error("unable to encode anonymous stats")
+ logrus.WithError(err).Error("Unable to encode phone-home statistics")
return
}
@@ -147,14 +147,14 @@ func (p *phoneHomeStats) collect() {
request, err := http.NewRequestWithContext(ctx, http.MethodPost, p.cfg.Global.ReportStats.Endpoint, &output)
if err != nil {
- logrus.WithError(err).Error("unable to create anonymous stats request")
+ logrus.WithError(err).Error("Unable to create phone-home statistics request")
return
}
request.Header.Set("User-Agent", "Dendrite/"+internal.VersionString())
_, err = p.client.Do(request)
if err != nil {
- logrus.WithError(err).Error("unable to send anonymous stats")
+ logrus.WithError(err).Error("Unable to send phone-home statistics")
return
}
}