diff --git a/CHANGES.md b/CHANGES.md index d05b871ac..effa5fdd1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,49 @@ # Changelog +## Dendrite 0.3.1 (2020-11-20) + +### Features + +* Memory optimisation by reference passing, significantly reducing the number of allocations and duplication in memory +* A hook API has been added for experimental MSCs, with an early implementation of MSC2836 +* The last seen timestamp and IP address are now updated automatically when calling `/sync` +* The last seen timestamp and IP address are now reported in `/_matrix/client/r0/devices` (contributed by [alexkursell](https://github.com/alexkursell)) +* An optional configuration option `sync_api.real_ip_header` has been added for specifying which HTTP header contains the real client IP address (for if Dendrite is running behind a reverse HTTP proxy) +* Partial implementation of `/_matrix/client/r0/admin/whois` (contributed by [DavidSpenler](https://github.com/DavidSpenler)) + +### Fixes + +* A concurrency bug has been fixed in the federation API that could cause Dendrite to crash +* The error when registering a username with invalid characters has been corrected (contributed by [bodqhrohro](https://github.com/bodqhrohro)) + +## Dendrite 0.3.0 (2020-11-16) + +### Features + +* Read receipts (both inbound and outbound) are now supported (contributed by [S7evinK](https://github.com/S7evinK)) +* Forgetting rooms is now supported (contributed by [S7evinK](https://github.com/S7evinK)) +* The `-version` command line flag has been added (contributed by [S7evinK](https://github.com/S7evinK)) + +### Fixes + +* User accounts that contain the `=` character can now be registered +* Backfilling should now work properly on rooms with world-readable history visibility (contributed by [MayeulC](https://github.com/MayeulC)) +* The `gjson` dependency has been updated for correct JSON integer ranges +* Some more client event fields have been marked as omit-when-empty (contributed by [S7evinK](https://github.com/S7evinK)) +* The `build.sh` script has been updated to work properly on all POSIX platforms (contributed by [felix](https://github.com/felix)) + +## Dendrite 0.2.1 (2020-10-22) + +### Fixes + +* Forward extremities are now calculated using only references from other extremities, rather than including outliers, which should fix cases where state can become corrupted ([#1556](https://github.com/matrix-org/dendrite/pull/1556)) +* Old state events will no longer be processed by the sync API as new, which should fix some cases where clients incorrectly believe they have joined or left rooms ([#1548](https://github.com/matrix-org/dendrite/pull/1548)) +* More SQLite database locking issues have been resolved in the latest events updater ([#1554](https://github.com/matrix-org/dendrite/pull/1554)) +* Internal HTTP API calls are now made using H2C (HTTP/2) in polylith mode, mitigating some potential head-of-line blocking issues ([#1541](https://github.com/matrix-org/dendrite/pull/1541)) +* Roomserver output events no longer incorrectly flag state rewrites ([#1557](https://github.com/matrix-org/dendrite/pull/1557)) +* Notification levels are now parsed correctly in power level events ([gomatrixserverlib#228](https://github.com/matrix-org/gomatrixserverlib/pull/228), contributed by [Pestdoktor](https://github.com/Pestdoktor)) +* Invalid UTF-8 is now correctly rejected when making federation requests ([gomatrixserverlib#229](https://github.com/matrix-org/gomatrixserverlib/pull/229), contributed by [Pestdoktor](https://github.com/Pestdoktor)) + ## Dendrite 0.2.0 (2020-10-20) ### Important diff --git a/README.md b/README.md index ea61dac13..40fd69ead 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ As of October 2020, Dendrite has now entered **beta** which means: This does not mean: - Dendrite is bug-free. It has not yet been battle-tested in the real world and so will be error prone initially. - All of the CS/Federation APIs are implemented. We are tracking progress via a script called 'Are We Synapse Yet?'. In particular, - read receipts, presence and push notifications are entirely missing from Dendrite. See [CHANGES.md](CHANGES.md) for updates. + presence and push notifications are entirely missing from Dendrite. See [CHANGES.md](CHANGES.md) for updates. - Dendrite is ready for massive homeserver deployments. You cannot shard each microservice, only run each one on a different machine. Currently, we expect Dendrite to function well for small (10s/100s of users) homeserver deployments as well as P2P Matrix nodes in-browser or on mobile devices. @@ -54,31 +54,32 @@ The following instructions are enough to get Dendrite started as a non-federatin ```bash $ git clone https://github.com/matrix-org/dendrite $ cd dendrite +$ ./build.sh -# generate self-signed certificate and an event signing key for federation -$ go build ./cmd/generate-keys -$ ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key +# Generate a Matrix signing key for federation (required) +$ ./bin/generate-keys --private-key matrix_key.pem -# Copy and modify the config file: -# you'll need to set a server name and paths to the keys at the very least, along with setting -# up the database filenames +# Generate a self-signed certificate (optional, but a valid TLS certificate is normally +# needed for Matrix federation/clients to work properly!) +$ ./bin/generate-keys --tls-cert server.crt --tls-key server.key + +# Copy and modify the config file - you'll need to set a server name and paths to the keys +# at the very least, along with setting up the database connection strings. $ cp dendrite-config.yaml dendrite.yaml -# build and run the server -$ go build ./cmd/dendrite-monolith-server -$ ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml +# Build and run the server: +$ ./bin/dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml ``` -Then point your favourite Matrix client at `http://localhost:8008`. +Then point your favourite Matrix client at `http://localhost:8008` or `https://localhost:8448`. ## Progress We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it -updates with CI. As of October 2020 we're at around 57% CS API coverage and 81% Federation coverage, though check +updates with CI. As of November 2020 we're at around 58% CS API coverage and 83% Federation coverage, though check CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse servers such as matrix.org reasonably well. There's a long list of features that are not implemented, notably: - - Receipts - Push - Search and Context - User Directory @@ -98,6 +99,7 @@ This means Dendrite supports amongst others: - Redaction - Tagging - E2E keys and device lists + - Receipts ## Contributing diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 560cd2373..4f6b4b4d4 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -88,7 +88,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { return nil } - events := []gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} + events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} events = append(events, output.NewRoomEvent.AddStateEvents...) // Send event to any relevant application services @@ -102,14 +102,14 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { // application service. func (s *OutputRoomEventConsumer) filterRoomserverEvents( ctx context.Context, - events []gomatrixserverlib.HeaderedEvent, + events []*gomatrixserverlib.HeaderedEvent, ) error { for _, ws := range s.workerStates { for _, event := range events { // Check if this event is interesting to this application service if s.appserviceIsInterestedInEvent(ctx, event, ws.AppService) { // Queue this event to be sent off to the application service - if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, &event); err != nil { + if err := s.asDB.StoreEvent(ctx, ws.AppService.ID, event); err != nil { log.WithError(err).Warn("failed to insert incoming event into appservices database") } else { // Tell our worker to send out new messages by updating remaining message @@ -125,7 +125,7 @@ func (s *OutputRoomEventConsumer) filterRoomserverEvents( // appserviceIsInterestedInEvent returns a boolean depending on whether a given // event falls within one of a given application service's namespaces. -func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool { +func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, appservice config.ApplicationService) bool { // No reason to queue events if they'll never be sent to the application // service if appservice.URL == "" { diff --git a/appservice/workers/transaction_scheduler.go b/appservice/workers/transaction_scheduler.go index b1735841d..cc9907481 100644 --- a/appservice/workers/transaction_scheduler.go +++ b/appservice/workers/transaction_scheduler.go @@ -185,7 +185,7 @@ func createTransaction( } } - var ev []gomatrixserverlib.Event + var ev []*gomatrixserverlib.Event for _, e := range events { ev = append(ev, e.Event) } diff --git a/build-dendritejs.sh b/build-dendritejs.sh index cd42a6bee..83ec3699c 100755 --- a/build-dendritejs.sh +++ b/build-dendritejs.sh @@ -1,4 +1,4 @@ -#!/bin/bash -eu +#!/bin/sh -eu export GIT_COMMIT=$(git rev-list -1 HEAD) && \ -GOOS=js GOARCH=wasm go build -ldflags "-X main.GitCommit=$GIT_COMMIT" -o main.wasm ./cmd/dendritejs \ No newline at end of file +GOOS=js GOARCH=wasm go build -ldflags "-X main.GitCommit=$GIT_COMMIT" -o main.wasm ./cmd/dendritejs diff --git a/build.sh b/build.sh index 31e0519f5..494d97eda 100755 --- a/build.sh +++ b/build.sh @@ -1,4 +1,4 @@ -#!/bin/bash -eu +#!/bin/sh -eu # Put installed packages into ./bin export GOBIN=$PWD/`dirname $0`/bin @@ -7,7 +7,7 @@ if [ -d ".git" ] then export BUILD=`git rev-parse --short HEAD || ""` export BRANCH=`(git symbolic-ref --short HEAD | tr -d \/ ) || ""` - if [[ $BRANCH == "master" ]] + if [ "$BRANCH" = master ] then export BRANCH="" fi diff --git a/build/docker/README.md b/build/docker/README.md index 7bf72e156..0e46e637a 100644 --- a/build/docker/README.md +++ b/build/docker/README.md @@ -2,6 +2,11 @@ These are Docker images for Dendrite! +They can be found on Docker Hub: + +- [matrixdotorg/dendrite-monolith](https://hub.docker.com/repository/docker/matrixdotorg/dendrite-monolith) for monolith deployments +- [matrixdotorg/dendrite-polylith](https://hub.docker.com/repository/docker/matrixdotorg/dendrite-polylith) for polylith deployments + ## Dockerfiles The `Dockerfile` builds the base image which contains all of the Dendrite diff --git a/build/docker/docker-compose.deps.yml b/build/docker/docker-compose.deps.yml index 1a27ffac0..454fddc29 100644 --- a/build/docker/docker-compose.deps.yml +++ b/build/docker/docker-compose.deps.yml @@ -1,5 +1,6 @@ version: "3.4" services: + # PostgreSQL is needed for both polylith and monolith modes. postgres: hostname: postgres image: postgres:9.6 @@ -15,12 +16,14 @@ services: networks: - internal + # Zookeeper is only needed for polylith mode! zookeeper: hostname: zookeeper image: zookeeper networks: - internal + # Kafka is only needed for polylith mode! kafka: container_name: dendrite_kafka hostname: kafka @@ -29,8 +32,6 @@ services: KAFKA_ADVERTISED_HOST_NAME: "kafka" KAFKA_DELETE_TOPIC_ENABLE: "true" KAFKA_ZOOKEEPER_CONNECT: "zookeeper:2181" - ports: - - 9092:9092 depends_on: - zookeeper networks: diff --git a/build/docker/docker-compose.monolith.yml b/build/docker/docker-compose.monolith.yml index 8fb798343..024183aa6 100644 --- a/build/docker/docker-compose.monolith.yml +++ b/build/docker/docker-compose.monolith.yml @@ -7,6 +7,9 @@ services: "--tls-cert=server.crt", "--tls-key=server.key" ] + ports: + - 8008:8008 + - 8448:8448 volumes: - ./config:/etc/dendrite networks: diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index fd010809c..fdb148775 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -112,22 +112,24 @@ func (m *DendriteMonolith) Start() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) - keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( base, keyRing, ) + fsAPI := federationsender.NewInternalAPI( + base, federation, rsAPI, keyRing, + ) + + keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + keyAPI.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI( base, cache.New(), userAPI, ) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, keyRing, - ) ygg.SetSessionFunc(func(address string) { req := &api.PerformServersAliveRequest{ diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 48303c97f..22e635139 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" + eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" @@ -148,7 +149,8 @@ type fullyReadEvent struct { // SaveReadMarker implements POST /rooms/{roomId}/read_markers func SaveReadMarker( - req *http.Request, userAPI api.UserInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, + req *http.Request, + userAPI api.UserInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, eduAPI eduserverAPI.EDUServerInputAPI, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, ) util.JSONResponse { // Verify that the user is a member of this room @@ -192,8 +194,10 @@ func SaveReadMarker( return jsonerror.InternalServerError() } - // TODO handle the read receipt that may be included in the read marker - // See https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-rooms-roomid-read-markers + // Handle the read receipt that may be included in the read marker + if r.Read != "" { + return SetReceipt(req, eduAPI, device, roomID, "m.read", r.Read) + } return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/admin_whois.go b/clientapi/routing/admin_whois.go new file mode 100644 index 000000000..b448791c3 --- /dev/null +++ b/clientapi/routing/admin_whois.go @@ -0,0 +1,88 @@ +// Copyright 2020 David Spenler +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/userapi/api" + + "github.com/matrix-org/util" +) + +type adminWhoisResponse struct { + UserID string `json:"user_id"` + Devices map[string]deviceInfo `json:"devices"` +} + +type deviceInfo struct { + Sessions []sessionInfo `json:"sessions"` +} + +type sessionInfo struct { + Connections []connectionInfo `json:"connections"` +} + +type connectionInfo struct { + IP string `json:"ip"` + LastSeen int64 `json:"last_seen"` + UserAgent string `json:"user_agent"` +} + +// GetAdminWhois implements GET /admin/whois/{userId} +func GetAdminWhois( + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, + userID string, +) util.JSONResponse { + if userID != device.UserID { + // TODO: Still allow if user is admin + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("userID does not match the current user"), + } + } + + var queryRes api.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &api.QueryDevicesRequest{ + UserID: userID, + }, &queryRes) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("GetAdminWhois failed to query user devices") + return jsonerror.InternalServerError() + } + + devices := make(map[string]deviceInfo) + for _, device := range queryRes.Devices { + connInfo := connectionInfo{ + IP: device.LastSeenIP, + LastSeen: device.LastSeenTS, + UserAgent: device.UserAgent, + } + dev, ok := devices[device.ID] + if !ok { + dev.Sessions = []sessionInfo{{}} + } + dev.Sessions[0].Connections = append(dev.Sessions[0].Connections, connInfo) + devices[device.ID] = dev + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: adminWhoisResponse{ + UserID: userID, + Devices: devices, + }, + } +} diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index cff3c9813..8f29fbe74 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -255,7 +255,7 @@ func createRoom( historyVisibility = historyVisibilityShared } - var builtEvents []gomatrixserverlib.HeaderedEvent + var builtEvents []*gomatrixserverlib.HeaderedEvent // send events into the room in order of: // 1- m.room.create @@ -327,13 +327,13 @@ func createRoom( return jsonerror.InternalServerError() } - if err = gomatrixserverlib.Allowed(*ev, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.Allowed failed") return jsonerror.InternalServerError() } // Add the event to the list of auth events - builtEvents = append(builtEvents, (*ev).Headered(roomVersion)) + builtEvents = append(builtEvents, ev.Headered(roomVersion)) err = authEvents.AddEvent(ev) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed") @@ -397,7 +397,7 @@ func createRoom( ev := event.Event globalStrippedState = append( globalStrippedState, - gomatrixserverlib.NewInviteV2StrippedState(&ev), + gomatrixserverlib.NewInviteV2StrippedState(ev), ) } } @@ -415,7 +415,7 @@ func createRoom( } inviteStrippedState := append( globalStrippedState, - gomatrixserverlib.NewInviteV2StrippedState(&inviteEvent.Event), + gomatrixserverlib.NewInviteV2StrippedState(inviteEvent.Event), ) // Send the invite event to the roomserver. err = roomserverAPI.SendInvite( @@ -488,5 +488,5 @@ func buildEvent( if err != nil { return nil, fmt.Errorf("cannot build event %s : Builder failed to build. %w", builder.Type, err) } - return &event, nil + return event, nil } diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index d50c73b35..6adaa7694 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -16,6 +16,7 @@ package routing import ( "io/ioutil" + "net" "net/http" "github.com/matrix-org/dendrite/clientapi/auth" @@ -32,7 +33,7 @@ type deviceJSON struct { DeviceID string `json:"device_id"` DisplayName string `json:"display_name"` LastSeenIP string `json:"last_seen_ip"` - LastSeenTS uint64 `json:"last_seen_ts"` + LastSeenTS int64 `json:"last_seen_ts"` } type devicesJSON struct { @@ -79,6 +80,8 @@ func GetDeviceByID( JSON: deviceJSON{ DeviceID: targetDevice.ID, DisplayName: targetDevice.DisplayName, + LastSeenIP: stripIPPort(targetDevice.LastSeenIP), + LastSeenTS: targetDevice.LastSeenTS, }, } } @@ -102,6 +105,8 @@ func GetDevicesByLocalpart( res.Devices = append(res.Devices, deviceJSON{ DeviceID: dev.ID, DisplayName: dev.DisplayName, + LastSeenIP: stripIPPort(dev.LastSeenIP), + LastSeenTS: dev.LastSeenTS, }) } @@ -230,3 +235,20 @@ func DeleteDevices( JSON: struct{}{}, } } + +// stripIPPort converts strings like "[::1]:12345" to "::1" +func stripIPPort(addr string) string { + ip := net.ParseIP(addr) + if ip != nil { + return addr + } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return "" + } + ip = net.ParseIP(host) + if ip != nil { + return host + } + return "" +} diff --git a/clientapi/routing/getevent.go b/clientapi/routing/getevent.go index 18b96e1ef..cccab1ee7 100644 --- a/clientapi/routing/getevent.go +++ b/clientapi/routing/getevent.go @@ -32,7 +32,7 @@ type getEventRequest struct { eventID string cfg *config.ClientAPI federation *gomatrixserverlib.FederationClient - requestedEvent gomatrixserverlib.Event + requestedEvent *gomatrixserverlib.Event } // GetEvent implements GET /_matrix/client/r0/rooms/{roomId}/event/{eventId} diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index fe0795577..5ca202b66 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -77,7 +77,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us if err = roomserverAPI.SendEvents( ctx, rsAPI, api.KindNew, - []gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, + []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, cfg.Matrix.ServerName, nil, ); err != nil { @@ -214,7 +214,7 @@ func SendInvite( err = roomserverAPI.SendInvite( req.Context(), rsAPI, - *event, + event, nil, // ask the roomserver to draw up invite room state for us cfg.Matrix.ServerName, nil, @@ -407,3 +407,47 @@ func checkMemberInRoom(ctx context.Context, rsAPI api.RoomserverInternalAPI, use } return nil } + +func SendForget( + req *http.Request, device *userapi.Device, + roomID string, rsAPI roomserverAPI.RoomserverInternalAPI, +) util.JSONResponse { + ctx := req.Context() + logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) + var membershipRes api.QueryMembershipForUserResponse + membershipReq := api.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: device.UserID, + } + err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) + if err != nil { + logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") + return jsonerror.InternalServerError() + } + if membershipRes.IsInRoom { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Forbidden("user is still a member of the room"), + } + } + if !membershipRes.HasBeenInRoom { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Forbidden("user did not belong to room"), + } + } + + request := api.PerformForgetRequest{ + RoomID: roomID, + UserID: device.UserID, + } + response := api.PerformForgetResponse{} + if err := rsAPI.PerformForget(ctx, &request, &response); err != nil { + logger.WithError(err).Error("PerformForget: unable to forget room") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index bbe35facd..971f23689 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -346,14 +346,14 @@ func buildMembershipEvents( roomIDs []string, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, evTime time.Time, rsAPI api.RoomserverInternalAPI, -) ([]gomatrixserverlib.HeaderedEvent, error) { - evs := []gomatrixserverlib.HeaderedEvent{} +) ([]*gomatrixserverlib.HeaderedEvent, error) { + evs := []*gomatrixserverlib.HeaderedEvent{} for _, roomID := range roomIDs { verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} verRes := api.QueryRoomVersionForRoomResponse{} if err := rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - return []gomatrixserverlib.HeaderedEvent{}, err + return nil, err } builder := gomatrixserverlib.EventBuilder{ @@ -379,7 +379,7 @@ func buildMembershipEvents( return nil, err } - evs = append(evs, (*event).Headered(verRes.RoomVersion)) + evs = append(evs, event.Headered(verRes.RoomVersion)) } return evs, nil diff --git a/clientapi/routing/receipt.go b/clientapi/routing/receipt.go new file mode 100644 index 000000000..fe8fe765d --- /dev/null +++ b/clientapi/routing/receipt.go @@ -0,0 +1,54 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "fmt" + "net/http" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/eduserver/api" + + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +func SetReceipt(req *http.Request, eduAPI api.EDUServerInputAPI, device *userapi.Device, roomId, receiptType, eventId string) util.JSONResponse { + timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + logrus.WithFields(logrus.Fields{ + "roomId": roomId, + "receiptType": receiptType, + "eventId": eventId, + "userId": device.UserID, + "timestamp": timestamp, + }).Debug("Setting receipt") + + // currently only m.read is accepted + if receiptType != "m.read" { + return util.MessageResponse(400, fmt.Sprintf("receipt type must be m.read not '%s'", receiptType)) + } + + if err := api.SendReceipt(req.Context(), eduAPI, device.UserID, roomId, eventId, receiptType, timestamp); err != nil { + return util.ErrorResponse(err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 266c0aff2..7696bec0f 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -121,7 +121,7 @@ func SendRedaction( JSON: jsonerror.NotFound("Room does not exist"), } } - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, api.KindNew, []gomatrixserverlib.HeaderedEvent{*e}, cfg.Matrix.ServerName, nil); err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 756eafe2f..528537ef4 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -113,7 +113,7 @@ var ( // TODO: Remove old sessions. Need to do so on a session-specific timeout. // sessions stores the completed flow stages for all sessions. Referenced using their sessionID. sessions = newSessionsDict() - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-./]+$`) + validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) ) // registerRequest represents the submitted registration request. @@ -209,7 +209,7 @@ func validateUsername(username string) *util.JSONResponse { } else if !validUsernameRegex.MatchString(username) { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./'"), + JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), } } else if username[0] == '_' { // Regex checks its not a zero length string return &util.JSONResponse{ @@ -230,7 +230,7 @@ func validateApplicationServiceUsername(username string) *util.JSONResponse { } else if !validUsernameRegex.MatchString(username) { return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./'"), + JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), } } return nil diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 4f99237f5..65b622b3a 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -651,6 +651,16 @@ func Setup( }), ).Methods(http.MethodGet) + r0mux.Handle("/admin/whois/{userID}", + httputil.MakeAuthAPI("admin_whois", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetAdminWhois(req, userAPI, device, vars["userID"]) + }), + ).Methods(http.MethodGet) + r0mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.rateLimit(req); r != nil { @@ -705,7 +715,20 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveReadMarker(req, userAPI, rsAPI, syncProducer, device, vars["roomID"]) + return SaveReadMarker(req, userAPI, rsAPI, eduAPI, syncProducer, device, vars["roomID"]) + }), + ).Methods(http.MethodPost, http.MethodOptions) + + r0mux.Handle("/rooms/{roomID}/forget", + httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return SendForget(req, device, vars["roomID"], rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -830,4 +853,17 @@ func Setup( return ClaimKeys(req, keyAPI) }), ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", + httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + + return SetReceipt(req, eduAPI, device, vars["roomId"], vars["receiptType"], vars["eventId"]) + }), + ).Methods(http.MethodPost, http.MethodOptions) } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 1303663ff..a4240cf34 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -93,7 +93,7 @@ func SendEvent( if err := api.SendEvents( req.Context(), rsAPI, api.KindNew, - []gomatrixserverlib.HeaderedEvent{ + []*gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, @@ -189,7 +189,7 @@ func generateSendEvent( // check to see if this user can perform this operation stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) for i := range queryRes.StateEvents { - stateEvents[i] = &queryRes.StateEvents[i].Event + stateEvents[i] = queryRes.StateEvents[i].Event } provider := gomatrixserverlib.NewAuthEvents(stateEvents) if err = gomatrixserverlib.Allowed(e.Event, &provider); err != nil { @@ -198,5 +198,5 @@ func generateSendEvent( JSON: jsonerror.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? } } - return &e.Event, nil + return e.Event, nil } diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index f69b54bbc..57014bc3b 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -267,7 +267,7 @@ func OnIncomingStateTypeRequest( // to find the state event, if provided. for _, ev := range stateRes.StateEvents { if ev.Type() == evType && ev.StateKeyEquals(stateKey) { - event = &ev + event = ev break } } @@ -290,7 +290,7 @@ func OnIncomingStateTypeRequest( return jsonerror.InternalServerError() } if len(stateAfterRes.StateEvents) > 0 { - event = &stateAfterRes.StateEvents[0] + event = stateAfterRes.StateEvents[0] } } @@ -304,7 +304,7 @@ func OnIncomingStateTypeRequest( } stateEvent := stateEventInStateResp{ - ClientEvent: gomatrixserverlib.HeaderedToClientEvent(*event, gomatrixserverlib.FormatAll), + ClientEvent: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll), } var res interface{} diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 272d3407d..f49796fde 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -362,8 +362,8 @@ func emit3PIDInviteEvent( return api.SendEvents( ctx, rsAPI, api.KindNew, - []gomatrixserverlib.HeaderedEvent{ - (*event).Headered(queryRes.RoomVersion), + []*gomatrixserverlib.HeaderedEvent{ + event.Headered(queryRes.RoomVersion), }, cfg.Matrix.ServerName, nil, diff --git a/cmd/create-room-events/main.go b/cmd/create-room-events/main.go index afe974643..23b44193a 100644 --- a/cmd/create-room-events/main.go +++ b/cmd/create-room-events/main.go @@ -123,7 +123,7 @@ func buildAndOutput() gomatrixserverlib.EventReference { } // Write an event to the output. -func writeEvent(event gomatrixserverlib.Event) { +func writeEvent(event *gomatrixserverlib.Event) { encoder := json.NewEncoder(os.Stdout) if *format == "InputRoomEvent" { var ire api.InputRoomEvent diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index e935805f6..70b81bbc8 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationsender" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/mscs" "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" @@ -148,6 +149,12 @@ func main() { base.PublicMediaAPIMux, ) + if len(base.Cfg.MSCs.MSCs) > 0 { + if err := mscs.Enable(base, &monolith); err != nil { + logrus.WithError(err).Fatalf("Failed to enable MSCs") + } + } + // Expose the matrix APIs directly rather than putting them under a /api path. go func() { base.SetupAndServeHTTP( diff --git a/cmd/dendrite-polylith-multi/personalities/keyserver.go b/cmd/dendrite-polylith-multi/personalities/keyserver.go index 8c159ad06..6a7907602 100644 --- a/cmd/dendrite-polylith-multi/personalities/keyserver.go +++ b/cmd/dendrite-polylith-multi/personalities/keyserver.go @@ -21,7 +21,8 @@ import ( ) func KeyServer(base *setup.BaseDendrite, cfg *config.Dendrite) { - intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, base.CreateFederationClient()) + fsAPI := base.FederationSenderHTTPClient() + intAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, fsAPI) intAPI.SetUserAPI(base.UserAPIClient()) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 9fb14f056..a622fbf28 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -80,9 +80,9 @@ func main() { } authEventIDMap := make(map[string]struct{}) - eventPtrs := make([]*gomatrixserverlib.Event, len(eventEntries)) + events := make([]*gomatrixserverlib.Event, len(eventEntries)) for i := range eventEntries { - eventPtrs[i] = &eventEntries[i].Event + events[i] = eventEntries[i].Event for _, authEventID := range eventEntries[i].AuthEventIDs() { authEventIDMap[authEventID] = struct{}{} } @@ -99,18 +99,9 @@ func main() { panic(err) } - authEventPtrs := make([]*gomatrixserverlib.Event, len(authEventEntries)) + authEvents := make([]*gomatrixserverlib.Event, len(authEventEntries)) for i := range authEventEntries { - authEventPtrs[i] = &authEventEntries[i].Event - } - - events := make([]gomatrixserverlib.Event, len(eventEntries)) - authEvents := make([]gomatrixserverlib.Event, len(authEventEntries)) - for i, ptr := range eventPtrs { - events[i] = *ptr - } - for i, ptr := range authEventPtrs { - authEvents[i] = *ptr + authEvents[i] = authEventEntries[i].Event } fmt.Println("Resolving state") diff --git a/cmd/syncserver-integration-tests/main.go b/cmd/syncserver-integration-tests/main.go index a11dd2a01..cbbcaa7e0 100644 --- a/cmd/syncserver-integration-tests/main.go +++ b/cmd/syncserver-integration-tests/main.go @@ -103,7 +103,7 @@ func clientEventJSONForOutputRoomEvent(outputRoomEvent string) string { if err := json.Unmarshal([]byte(outputRoomEvent), &out); err != nil { panic("failed to unmarshal output room event: " + err.Error()) } - clientEvs := gomatrixserverlib.ToClientEvents([]gomatrixserverlib.Event{ + clientEvs := gomatrixserverlib.ToClientEvents([]*gomatrixserverlib.Event{ out.NewRoomEvent.Event.Event, }, gomatrixserverlib.FormatSync) b, err := json.Marshal(clientEvs[0]) diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 6e87bc709..2a8650db6 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -36,6 +36,8 @@ global: server_name: localhost # The path to the signing private key file, used to sign requests and events. + # Note that this is NOT the same private key as used for TLS! To generate a + # signing key, use "./bin/generate-keys --private-key matrix_key.pem". private_key: matrix_key.pem # The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) @@ -74,6 +76,12 @@ global: # Kafka. use_naffka: true + # The max size a Kafka message is allowed to use. + # You only need to change this value, if you encounter issues with too large messages. + # Must be less than/equal to "max.message.bytes" configured in Kafka. + # Defaults to 8388608 bytes. + # max_message_bytes: 8388608 + # Naffka database options. Not required when using Kafka. naffka_database: connection_string: file:naffka.db @@ -292,6 +300,11 @@ sync_api: max_idle_conns: 2 conn_max_lifetime: -1 + # This option controls which HTTP header to inspect to find the real remote IP + # address of the client. This is likely required if Dendrite is running behind + # a reverse proxy server. + # real_ip_header: X-Real-IP + # Configuration for the User API. user_api: internal_api: diff --git a/docs/INSTALL.md b/docs/INSTALL.md index f804193cd..0b3f932be 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -12,6 +12,8 @@ Dendrite can be run in one of two configurations: lightweight implementation called [Naffka](https://github.com/matrix-org/naffka). This will usually be the preferred model for low-volume, low-user or experimental deployments. +For most deployments, it is **recommended to run in monolith mode with PostgreSQL databases**. + Regardless of whether you are running in polylith or monolith mode, each Dendrite component that requires storage has its own database. Both Postgres and SQLite are supported and can be mixed-and-matched across components as needed in the configuration file. @@ -30,23 +32,9 @@ If you want to run a polylith deployment, you also need: * Apache Kafka 0.10.2+ -## Building up a monolith deploment +Please note that Kafka is **not required** for a monolith deployment. -Start by cloning the code: - -```bash -git clone https://github.com/matrix-org/dendrite -cd dendrite -``` - -Then build it: - -```bash -go build -o bin/dendrite-monolith-server ./cmd/dendrite-monolith-server -go build -o bin/generate-keys ./cmd/generate-keys -``` - -## Building up a polylith deployment +## Building Dendrite Start by cloning the code: @@ -61,6 +49,8 @@ Then build it: ./build.sh ``` +## Install Kafka (polylith only) + Install and start Kafka (c.f. [scripts/install-local-kafka.sh](scripts/install-local-kafka.sh)): ```bash @@ -96,9 +86,9 @@ Dendrite can use the built-in SQLite database engine for small setups. The SQLite databases do not need to be pre-built - Dendrite will create them automatically at startup. -### Postgres database setup +### PostgreSQL database setup -Assuming that Postgres 9.6 (or later) is installed: +Assuming that PostgreSQL 9.6 (or later) is installed: * Create role, choosing a new password when prompted: @@ -118,17 +108,31 @@ Assuming that Postgres 9.6 (or later) is installed: ### Server key generation -Each Dendrite server requires unique server keys. +Each Dendrite installation requires: -In order for an instance to federate correctly, you should have a valid -certificate issued by a trusted authority, and private key to match. If you -don't and just want to test locally, generate the self-signed SSL certificate -for federation and the server signing key: +- A unique Matrix signing private key +- A valid and trusted TLS certificate and private key + +To generate a Matrix signing private key: ```bash -./bin/generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key +./bin/generate-keys --private-key matrix_key.pem ``` +**Warning:** Make sure take a safe backup of this key! You will likely need it if you want to reinstall Dendrite, or +any other Matrix homeserver, on the same domain name in the future. If you lose this key, you may have trouble joining +federated rooms. + +For testing, you can generate a self-signed certificate and key, although this will not work for public federation: + +```bash +./bin/generate-keys --tls-cert server.crt --tls-key server.key +``` + +If you have server keys from an older Synapse instance, +[convert them](serverkeyformat.md#converting-synapse-keys) to Dendrite's PEM +format and configure them as `old_private_keys` in your config. + ### Configuration file Create config file, based on `dendrite-config.yaml`. Call it `dendrite.yaml`. Things that will need editing include *at least*: @@ -136,9 +140,9 @@ Create config file, based on `dendrite-config.yaml`. Call it `dendrite.yaml`. Th * The `server_name` entry to reflect the hostname of your Dendrite server * The `database` lines with an updated connection string based on your desired setup, e.g. replacing `database` with the name of the database: - * For Postgres: `postgres://dendrite:password@localhost/database` - * For SQLite on disk: `file:component.db` or `file:///path/to/component.db` - * Postgres and SQLite can be mixed and matched. + * For Postgres: `postgres://dendrite:password@localhost/database`, e.g. `postgres://dendrite:password@localhost/dendrite_userapi_account.db` + * For SQLite on disk: `file:component.db` or `file:///path/to/component.db`, e.g. `file:userapi_account.db` + * Postgres and SQLite can be mixed and matched on different components as desired. * The `use_naffka` option if using Naffka in a monolith deployment There are other options which may be useful so review them all. In particular, @@ -148,7 +152,7 @@ help to improve reliability considerably by allowing your homeserver to fetch public keys for dead homeservers from somewhere else. **WARNING:** Dendrite supports running all components from the same database in -Postgres mode, but this is **NOT** a supported configuration with SQLite. When +PostgreSQL mode, but this is **NOT** a supported configuration with SQLite. When using SQLite, all components **MUST** use their own database file. ## Starting a monolith server @@ -160,8 +164,14 @@ Be sure to update the database username and password if needed. The monolith server can be started as shown below. By default it listens for HTTP connections on port 8008, so you can configure your Matrix client to use -`http://localhost:8008` as the server. If you set `--tls-cert` and `--tls-key` -as shown below, it will also listen for HTTPS connections on port 8448. +`http://servername:8008` as the server: + +```bash +./bin/dendrite-monolith-server +``` + +If you set `--tls-cert` and `--tls-key` as shown below, it will also listen +for HTTPS connections on port 8448: ```bash ./bin/dendrite-monolith-server --tls-cert=server.crt --tls-key=server.key diff --git a/docs/hiawatha/polylith-sample.conf b/docs/hiawatha/polylith-sample.conf new file mode 100644 index 000000000..99730efdb --- /dev/null +++ b/docs/hiawatha/polylith-sample.conf @@ -0,0 +1,16 @@ +VirtualHost { + ... + # route requests to: + # /_matrix/client/.*/sync + # /_matrix/client/.*/user/{userId}/filter + # /_matrix/client/.*/user/{userId}/filter/{filterID} + # /_matrix/client/.*/keys/changes + # /_matrix/client/.*/rooms/{roomId}/messages + # to sync_api + ReverseProxy = /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/messages) http://localhost:8073 + ReverseProxy = /_matrix/client http://localhost:8071 + ReverseProxy = /_matrix/federation http://localhost:8072 + ReverseProxy = /_matrix/key http://localhost:8072 + ReverseProxy = /_matrix/media http://localhost:8074 + ... +} diff --git a/docs/nginx/monolith-sample.conf b/docs/nginx/monolith-sample.conf index 9ee5e1ac1..350e83489 100644 --- a/docs/nginx/monolith-sample.conf +++ b/docs/nginx/monolith-sample.conf @@ -1,5 +1,6 @@ server { - listen 443 ssl; + listen 443 ssl; # IPv4 + listen [::]:443; # IPv6 server_name my.hostname.com; ssl_certificate /path/to/fullchain.pem; diff --git a/docs/nginx/polylith-sample.conf b/docs/nginx/polylith-sample.conf index ab3461848..d0d3c98d5 100644 --- a/docs/nginx/polylith-sample.conf +++ b/docs/nginx/polylith-sample.conf @@ -1,5 +1,6 @@ server { - listen 443 ssl; + listen 443 ssl; # IPv4 + listen [::]:443; # IPv6 server_name my.hostname.com; ssl_certificate /path/to/fullchain.pem; diff --git a/docs/serverkeyformat.md b/docs/serverkeyformat.md new file mode 100644 index 000000000..feda93454 --- /dev/null +++ b/docs/serverkeyformat.md @@ -0,0 +1,29 @@ +# Server Key Format + +Dendrite stores the server signing key in the PEM format with the following structure. + +``` +-----BEGIN MATRIX PRIVATE KEY----- +Key-ID: ed25519: + + +-----END MATRIX PRIVATE KEY----- +``` + +## Converting Synapse Keys + +If you have signing keys from a previous synapse server, you should ideally configure them as `old_private_keys` in your Dendrite config file. Synapse stores signing keys in the following format. + +``` +ed25519 +``` + +To convert this key to Dendrite's PEM format, use the following template. **It is important to include the equals sign, as the key data needs to be padded to 32 bytes.** + +``` +-----BEGIN MATRIX PRIVATE KEY----- +Key-ID: ed25519: + += +-----END MATRIX PRIVATE KEY----- +``` \ No newline at end of file diff --git a/docs/sytest.md b/docs/sytest.md index 03954f135..0d42013ec 100644 --- a/docs/sytest.md +++ b/docs/sytest.md @@ -85,6 +85,7 @@ Set up the database: ```sh sudo -u postgres psql -c "CREATE USER dendrite PASSWORD 'itsasecret'" +sudo -u postgres psql -c "ALTER USER dendrite CREATEDB" for i in dendrite0 dendrite1 sytest_template; do sudo -u postgres psql -c "CREATE DATABASE $i OWNER dendrite;"; done mkdir -p "server-0" cat > "server-0/database.yaml" << EOF diff --git a/eduserver/api/input.go b/eduserver/api/input.go index 0d0d21f33..f8599e1cc 100644 --- a/eduserver/api/input.go +++ b/eduserver/api/input.go @@ -59,6 +59,22 @@ type InputSendToDeviceEventRequest struct { // InputSendToDeviceEventResponse is a response to InputSendToDeviceEventRequest type InputSendToDeviceEventResponse struct{} +type InputReceiptEvent struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + Type string `json:"type"` + Timestamp gomatrixserverlib.Timestamp `json:"timestamp"` +} + +// InputReceiptEventRequest is a request to EDUServerInputAPI +type InputReceiptEventRequest struct { + InputReceiptEvent InputReceiptEvent `json:"input_receipt_event"` +} + +// InputReceiptEventResponse is a response to InputReceiptEventRequest +type InputReceiptEventResponse struct{} + // EDUServerInputAPI is used to write events to the typing server. type EDUServerInputAPI interface { InputTypingEvent( @@ -72,4 +88,10 @@ type EDUServerInputAPI interface { request *InputSendToDeviceEventRequest, response *InputSendToDeviceEventResponse, ) error + + InputReceiptEvent( + ctx context.Context, + request *InputReceiptEventRequest, + response *InputReceiptEventResponse, + ) error } diff --git a/eduserver/api/output.go b/eduserver/api/output.go index e6ded8413..650458a29 100644 --- a/eduserver/api/output.go +++ b/eduserver/api/output.go @@ -49,3 +49,39 @@ type OutputSendToDeviceEvent struct { DeviceID string `json:"device_id"` gomatrixserverlib.SendToDeviceEvent } + +type ReceiptEvent struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + Type string `json:"type"` + Timestamp gomatrixserverlib.Timestamp `json:"timestamp"` +} + +// OutputReceiptEvent is an entry in the receipt output kafka log +type OutputReceiptEvent struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + Type string `json:"type"` + Timestamp gomatrixserverlib.Timestamp `json:"timestamp"` +} + +// Helper structs for receipts json creation +type ReceiptMRead struct { + User map[string]ReceiptTS `json:"m.read"` +} + +type ReceiptTS struct { + TS gomatrixserverlib.Timestamp `json:"ts"` +} + +// FederationSender output +type FederationReceiptMRead struct { + User map[string]FederationReceiptData `json:"m.read"` +} + +type FederationReceiptData struct { + Data ReceiptTS `json:"data"` + EventIDs []string `json:"event_ids"` +} diff --git a/eduserver/api/wrapper.go b/eduserver/api/wrapper.go index c2c4596de..7907f4d39 100644 --- a/eduserver/api/wrapper.go +++ b/eduserver/api/wrapper.go @@ -67,3 +67,22 @@ func SendToDevice( response := InputSendToDeviceEventResponse{} return eduAPI.InputSendToDeviceEvent(ctx, &request, &response) } + +// SendReceipt sends a receipt event to EDU Server +func SendReceipt( + ctx context.Context, + eduAPI EDUServerInputAPI, userID, roomID, eventID, receiptType string, + timestamp gomatrixserverlib.Timestamp, +) error { + request := InputReceiptEventRequest{ + InputReceiptEvent: InputReceiptEvent{ + UserID: userID, + RoomID: roomID, + EventID: eventID, + Type: receiptType, + Timestamp: timestamp, + }, + } + response := InputReceiptEventResponse{} + return eduAPI.InputReceiptEvent(ctx, &request, &response) +} diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index 098ac0248..d5ab36818 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -49,8 +49,9 @@ func NewInternalAPI( Cache: eduCache, UserAPI: userAPI, Producer: producer, - OutputTypingEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent)), - OutputSendToDeviceEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent)), + OutputTypingEventTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent), + OutputSendToDeviceEventTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent), + OutputReceiptEventTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputReceiptEvent), ServerName: cfg.Matrix.ServerName, } } diff --git a/eduserver/input/input.go b/eduserver/input/input.go index e3d2c55e3..c54fb9de8 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -37,6 +37,8 @@ type EDUServerInputAPI struct { OutputTypingEventTopic string // The kafka topic to output new send to device events to. OutputSendToDeviceEventTopic string + // The kafka topic to output new receipt events to + OutputReceiptEventTopic string // kafka producer Producer sarama.SyncProducer // Internal user query API @@ -173,3 +175,31 @@ func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) e return nil } + +// InputReceiptEvent implements api.EDUServerInputAPI +// TODO: Intelligently batch requests sent by the same user (e.g wait a few milliseconds before emitting output events) +func (t *EDUServerInputAPI) InputReceiptEvent( + ctx context.Context, + request *api.InputReceiptEventRequest, + response *api.InputReceiptEventResponse, +) error { + logrus.WithFields(logrus.Fields{}).Infof("Producing to topic '%s'", t.OutputReceiptEventTopic) + output := &api.OutputReceiptEvent{ + UserID: request.InputReceiptEvent.UserID, + RoomID: request.InputReceiptEvent.RoomID, + EventID: request.InputReceiptEvent.EventID, + Type: request.InputReceiptEvent.Type, + Timestamp: request.InputReceiptEvent.Timestamp, + } + js, err := json.Marshal(output) + if err != nil { + return err + } + m := &sarama.ProducerMessage{ + Topic: t.OutputReceiptEventTopic, + Key: sarama.StringEncoder(request.InputReceiptEvent.RoomID + ":" + request.InputReceiptEvent.UserID), + Value: sarama.ByteEncoder(js), + } + _, _, err = t.Producer.SendMessage(m) + return err +} diff --git a/eduserver/inthttp/client.go b/eduserver/inthttp/client.go index 7d0bc1603..0690ed827 100644 --- a/eduserver/inthttp/client.go +++ b/eduserver/inthttp/client.go @@ -14,6 +14,7 @@ import ( const ( EDUServerInputTypingEventPath = "/eduserver/input" EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice" + EDUServerInputReceiptEventPath = "/eduserver/receipt" ) // NewEDUServerClient creates a EDUServerInputAPI implemented by talking to a HTTP POST API. @@ -54,3 +55,16 @@ func (h *httpEDUServerInputAPI) InputSendToDeviceEvent( apiURL := h.eduServerURL + EDUServerInputSendToDeviceEventPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +// InputSendToDeviceEvent implements EDUServerInputAPI +func (h *httpEDUServerInputAPI) InputReceiptEvent( + ctx context.Context, + request *api.InputReceiptEventRequest, + response *api.InputReceiptEventResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputReceiptEventPath") + defer span.Finish() + + apiURL := h.eduServerURL + EDUServerInputReceiptEventPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/eduserver/inthttp/server.go b/eduserver/inthttp/server.go index e374513a3..a34943750 100644 --- a/eduserver/inthttp/server.go +++ b/eduserver/inthttp/server.go @@ -38,4 +38,17 @@ func AddRoutes(t api.EDUServerInputAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(EDUServerInputReceiptEventPath, + httputil.MakeInternalAPI("inputReceiptEvent", func(req *http.Request) util.JSONResponse { + var request api.InputReceiptEventRequest + var response api.InputReceiptEventResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := t.InputReceiptEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 3c2e5bbb0..00101fe95 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -72,7 +72,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { t.Errorf("failed to parse event: %s", err) } he := ev.Headered(tc.roomVer) - invReq, err := gomatrixserverlib.NewInviteV2Request(&he, nil) + invReq, err := gomatrixserverlib.NewInviteV2Request(he, nil) if err != nil { t.Errorf("failed to create invite v2 request: %s", err) continue diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index ea77c947f..717289a2a 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -93,9 +93,9 @@ func Backfill( } // Filter any event that's not from the requested room out. - evs := make([]gomatrixserverlib.Event, 0) + evs := make([]*gomatrixserverlib.Event, 0) - var ev gomatrixserverlib.HeaderedEvent + var ev *gomatrixserverlib.HeaderedEvent for _, ev = range res.Events { if ev.RoomID() == roomID { evs = append(evs, ev.Event) diff --git a/federationapi/routing/events.go b/federationapi/routing/events.go index 6fa28f69d..312ef9f8e 100644 --- a/federationapi/routing/events.go +++ b/federationapi/routing/events.go @@ -98,5 +98,5 @@ func fetchEvent(ctx context.Context, rsAPI api.RoomserverInternalAPI, eventID st return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} } - return &eventsResponse.Events[0].Event, nil + return eventsResponse.Events[0].Event, nil } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 16c0441b9..54f9e6848 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -97,7 +97,7 @@ func InviteV1( func processInvite( ctx context.Context, isInviteV2 bool, - event gomatrixserverlib.Event, + event *gomatrixserverlib.Event, roomVer gomatrixserverlib.RoomVersion, strippedState []gomatrixserverlib.InviteV2StrippedState, roomID string, @@ -171,12 +171,12 @@ func processInvite( if isInviteV2 { return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespInviteV2{Event: signedEvent}, + JSON: gomatrixserverlib.RespInviteV2{Event: &signedEvent}, } } else { return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespInvite{Event: signedEvent}, + JSON: gomatrixserverlib.RespInvite{Event: &signedEvent}, } } default: diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 12f205366..c8e7e627b 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -138,7 +138,7 @@ func MakeJoin( // Check that the join is allowed or not stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) for i := range queryRes.StateEvents { - stateEvents[i] = &queryRes.StateEvents[i].Event + stateEvents[i] = queryRes.StateEvents[i].Event } provider := gomatrixserverlib.NewAuthEvents(stateEvents) @@ -291,7 +291,7 @@ func SendJoin( if err = api.SendEvents( httpReq.Context(), rsAPI, api.KindNew, - []gomatrixserverlib.HeaderedEvent{ + []*gomatrixserverlib.HeaderedEvent{ event.Headered(stateAndAuthChainResponse.RoomVersion), }, cfg.Matrix.ServerName, @@ -319,7 +319,7 @@ func SendJoin( } } -type eventsByDepth []gomatrixserverlib.HeaderedEvent +type eventsByDepth []*gomatrixserverlib.HeaderedEvent func (e eventsByDepth) Len() int { return len(e) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index fb81d9319..498532de0 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -98,7 +98,7 @@ func MakeLeave( // Check that the leave is allowed or not stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) for i := range queryRes.StateEvents { - stateEvents[i] = &queryRes.StateEvents[i].Event + stateEvents[i] = queryRes.StateEvents[i].Event } provider := gomatrixserverlib.NewAuthEvents(stateEvents) if err = gomatrixserverlib.Allowed(event.Event, &provider); err != nil { @@ -257,7 +257,7 @@ func SendLeave( if err = api.SendEvents( httpReq.Context(), rsAPI, api.KindNew, - []gomatrixserverlib.HeaderedEvent{ + []*gomatrixserverlib.HeaderedEvent{ event.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index 5118b34e5..f79a2d2d8 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -73,8 +73,8 @@ func GetMissingEvents( // filterEvents returns only those events with matching roomID func filterEvents( - events []gomatrixserverlib.HeaderedEvent, roomID string, -) []gomatrixserverlib.HeaderedEvent { + events []*gomatrixserverlib.HeaderedEvent, roomID string, +) []*gomatrixserverlib.HeaderedEvent { ref := events[:0] for _, ev := range events { if ev.RoomID() == roomID { diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 76dc3a2ee..a3011b500 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -111,7 +111,8 @@ type txnReq struct { // which the roomserver is unaware of. haveEvents map[string]*gomatrixserverlib.HeaderedEvent // new events which the roomserver does not know about - newEvents map[string]bool + newEvents map[string]bool + newEventsMutex sync.RWMutex } // A subset of FederationClient functionality that txn requires. Useful for testing. @@ -128,7 +129,7 @@ type txnFederationClient interface { func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { results := make(map[string]gomatrixserverlib.PDUResult) - pdus := []gomatrixserverlib.HeaderedEvent{} + pdus := []*gomatrixserverlib.HeaderedEvent{} for _, pdu := range t.PDUs { var header struct { RoomID string `json:"room_id"` @@ -171,7 +172,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res } continue } - if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil { + if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []*gomatrixserverlib.Event{event}, t.keys); err != nil { util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: err.Error(), @@ -264,6 +265,8 @@ func (e missingPrevEventsError) Error() string { } func (t *txnReq) haveEventIDs() map[string]bool { + t.newEventsMutex.RLock() + defer t.newEventsMutex.RUnlock() result := make(map[string]bool, len(t.haveEvents)) for eventID := range t.haveEvents { if t.newEvents[eventID] { @@ -322,12 +325,69 @@ func (t *txnReq) processEDUs(ctx context.Context) { } case gomatrixserverlib.MDeviceListUpdate: t.processDeviceListUpdate(ctx, e) + case gomatrixserverlib.MReceipt: + // https://matrix.org/docs/spec/server_server/r0.1.4#receipts + payload := map[string]eduserverAPI.FederationReceiptMRead{} + + if err := json.Unmarshal(e.Content, &payload); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal receipt event") + continue + } + + for roomID, receipt := range payload { + for userID, mread := range receipt.User { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to split domain from receipt event sender") + continue + } + if t.Origin != domain { + util.GetLogger(ctx).Warnf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + continue + } + if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": t.Origin, + "user_id": userID, + "room_id": roomID, + "events": mread.EventIDs, + }).Error("Failed to send receipt event to edu server") + continue + } + } + } default: util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") } } } +// processReceiptEvent sends receipt events to the edu server +func (t *txnReq) processReceiptEvent(ctx context.Context, + userID, roomID, receiptType string, + timestamp gomatrixserverlib.Timestamp, + eventIDs []string, +) error { + // store every event + for _, eventID := range eventIDs { + req := eduserverAPI.InputReceiptEventRequest{ + InputReceiptEvent: eduserverAPI.InputReceiptEvent{ + UserID: userID, + RoomID: roomID, + EventID: eventID, + Type: receiptType, + Timestamp: timestamp, + }, + } + resp := eduserverAPI.InputReceiptEventResponse{} + if err := t.eduAPI.InputReceiptEvent(ctx, &req, &resp); err != nil { + return fmt.Errorf("unable to set receipt event: %w", err) + } + } + + return nil +} + func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverlib.EDU) { var payload gomatrixserverlib.DeviceListUpdateEvent if err := json.Unmarshal(e.Content, &payload); err != nil { @@ -356,7 +416,7 @@ func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserver return servers } -func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) error { +func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) error { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) // Work out if the roomserver knows everything it needs to know to auth @@ -404,7 +464,7 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) er context.Background(), t.rsAPI, api.KindNew, - []gomatrixserverlib.HeaderedEvent{ + []*gomatrixserverlib.HeaderedEvent{ e.Headered(stateResp.RoomVersion), }, api.DoNotSendToOtherServers, @@ -413,7 +473,7 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event) er } func (t *txnReq) retrieveMissingAuthEvents( - ctx context.Context, e gomatrixserverlib.Event, stateResp *api.QueryMissingAuthPrevEventsResponse, + ctx context.Context, e *gomatrixserverlib.Event, stateResp *api.QueryMissingAuthPrevEventsResponse, ) error { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) @@ -466,10 +526,10 @@ withNextEvent: return nil } -func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error { +func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { authUsingState := gomatrixserverlib.NewAuthEvents(nil) for i := range stateEvents { - err := authUsingState.AddEvent(&stateEvents[i]) + err := authUsingState.AddEvent(stateEvents[i]) if err != nil { return err } @@ -478,7 +538,7 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver } // nolint:gocyclo -func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error { +func (t *txnReq) processEventWithMissingState(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error { // Do this with a fresh context, so that we keep working even if the // original request times out. With any luck, by the time the remote // side retries, we'll have fetched the missing state. @@ -512,7 +572,7 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser return nil } - backwardsExtremity := &newEvents[0] + backwardsExtremity := newEvents[0] newEvents = newEvents[1:] type respState struct { @@ -600,7 +660,7 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser // than the backward extremity, into the roomserver without state. This way // they will automatically fast-forward based on the room state at the // extremity in the last step. - headeredNewEvents := make([]gomatrixserverlib.HeaderedEvent, len(newEvents)) + headeredNewEvents := make([]*gomatrixserverlib.HeaderedEvent, len(newEvents)) for i, newEvent := range newEvents { headeredNewEvents[i] = newEvent.Headered(roomVersion) } @@ -677,9 +737,9 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event return nil } for i, ev := range res.StateEvents { - t.haveEvents[ev.EventID()] = &res.StateEvents[i] + t.haveEvents[ev.EventID()] = res.StateEvents[i] } - var authEvents []gomatrixserverlib.Event + var authEvents []*gomatrixserverlib.Event missingAuthEvents := make(map[string]bool) for _, ev := range res.StateEvents { for _, ae := range ev.AuthEventIDs() { @@ -707,7 +767,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event } for i := range queryRes.Events { evID := queryRes.Events[i].EventID() - t.haveEvents[evID] = &queryRes.Events[i] + t.haveEvents[evID] = queryRes.Events[i] authEvents = append(authEvents, queryRes.Events[i].Unwrap()) } @@ -730,8 +790,8 @@ func (t *txnReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatri } func (t *txnReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { - var authEventList []gomatrixserverlib.Event - var stateEventList []gomatrixserverlib.Event + var authEventList []*gomatrixserverlib.Event + var stateEventList []*gomatrixserverlib.Event for _, state := range states { authEventList = append(authEventList, state.AuthEvents...) stateEventList = append(stateEventList, state.StateEvents...) @@ -742,7 +802,7 @@ func (t *txnReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrix } // apply the current event retryAllowedState: - if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil { + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: servers := t.getServers(ctx, backwardsExtremity.RoomID()) @@ -779,9 +839,9 @@ retryAllowedState: // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. // This means that we may recursively call this function, as we spider back up prev_events. // nolint:gocyclo -func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []gomatrixserverlib.Event, err error) { +func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) + needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{e}) // query latest events (our trusted forward extremities) req := api.QueryLatestEventsAndStateRequest{ RoomID: e.RoomID(), @@ -922,7 +982,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even } for i := range queryRes.Events { evID := queryRes.Events[i].EventID() - t.haveEvents[evID] = &queryRes.Events[i] + t.haveEvents[evID] = queryRes.Events[i] if missing[evID] { delete(missing, evID) } @@ -1059,10 +1119,10 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib. if err := t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) } else if len(queryRes.Events) == 1 { - return &queryRes.Events[0], nil + return queryRes.Events[0], nil } } - var event gomatrixserverlib.Event + var event *gomatrixserverlib.Event found := false for _, serverName := range servers { txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) @@ -1082,11 +1142,13 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib. util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(servers)) } - if err := gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil { + if err := gomatrixserverlib.VerifyAllEventSignatures(ctx, []*gomatrixserverlib.Event{event}, t.keys); err != nil { util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) return nil, verifySigError{event.EventID(), err} } h := event.Headered(roomVersion) + t.newEventsMutex.Lock() t.newEvents[h.EventID()] = true - return &h, nil + t.newEventsMutex.Unlock() + return h, nil } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 0a462433c..9b9db873b 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -33,8 +33,8 @@ var ( []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), } - testEvents = []gomatrixserverlib.HeaderedEvent{} - testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]gomatrixserverlib.HeaderedEvent) + testEvents = []*gomatrixserverlib.HeaderedEvent{} + testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) ) func init() { @@ -76,6 +76,14 @@ func (p *testEDUProducer) InputSendToDeviceEvent( return nil } +func (o *testEDUProducer) InputReceiptEvent( + ctx context.Context, + request *eduAPI.InputReceiptEventRequest, + response *eduAPI.InputReceiptEventResponse, +) error { + return nil +} + type testRoomserverAPI struct { inputRoomEvents []api.InputRoomEvent queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse @@ -84,6 +92,10 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse } +func (t *testRoomserverAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, resp *api.PerformForgetResponse) error { + return nil +} + func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} func (t *testRoomserverAPI) InputRoomEvents( @@ -433,7 +445,7 @@ NextPDU: } } -func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []gomatrixserverlib.HeaderedEvent) { +func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []*gomatrixserverlib.HeaderedEvent) { NextTuple: for _, t := range tuples { for _, o := range omitTuples { @@ -449,7 +461,7 @@ NextTuple: return } -func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []gomatrixserverlib.HeaderedEvent) { +func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { for _, g := range got { fmt.Println("GOT ", g.Event.EventID()) } @@ -481,7 +493,7 @@ func TestBasicTransaction(t *testing.T) { } txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) } // The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver @@ -502,7 +514,7 @@ func TestTransactionFailAuthChecks(t *testing.T) { txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) mustProcessTransaction(t, txn, []string{}) // expect message to be sent to the roomserver - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) } // The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, @@ -574,7 +586,7 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, inputEvent.EventID()) } return gomatrixserverlib.RespMissingEvents{ - Events: []gomatrixserverlib.Event{ + Events: []*gomatrixserverlib.Event{ prevEvent.Unwrap(), }, }, nil @@ -586,7 +598,7 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { } txn := mustCreateTransaction(rsAPI, cli, pdus) mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) } // The purpose of this test is to check that when there are missing prev_events and we still haven't been able to fill @@ -641,7 +653,7 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { } else if askingForEvent == eventB.EventID() { prevEventExists = haveEventB } - var stateEvents []gomatrixserverlib.HeaderedEvent + var stateEvents []*gomatrixserverlib.HeaderedEvent if prevEventExists { stateEvents = fromStateTuples(req.StateToFetch, omitTuples) } @@ -759,7 +771,7 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { } // just return event C, not event B so /state_ids logic kicks in as there will STILL be missing prev_events return gomatrixserverlib.RespMissingEvents{ - Events: []gomatrixserverlib.Event{ + Events: []*gomatrixserverlib.Event{ eventC.Unwrap(), }, }, nil @@ -771,5 +783,5 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { } txn := mustCreateTransaction(rsAPI, cli, pdus) mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) } diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 28dfad846..128df6187 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -136,7 +136,7 @@ func getState( }, nil } -func getIDsFromEvent(events []gomatrixserverlib.Event) []string { +func getIDsFromEvent(events []*gomatrixserverlib.Event) []string { IDs := make([]string, len(events)) for i := range events { IDs[i] = events[i].EventID() diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 4db5273af..ed6af526e 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -65,7 +65,7 @@ func CreateInvitesFrom3PIDInvites( return *reqErr } - evs := []gomatrixserverlib.HeaderedEvent{} + evs := []*gomatrixserverlib.HeaderedEvent{} for _, inv := range body.Invites { verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} verRes := api.QueryRoomVersionForRoomResponse{} @@ -84,7 +84,7 @@ func CreateInvitesFrom3PIDInvites( return jsonerror.InternalServerError() } if event != nil { - evs = append(evs, (*event).Headered(verRes.RoomVersion)) + evs = append(evs, event.Headered(verRes.RoomVersion)) } } @@ -165,7 +165,7 @@ func ExchangeThirdPartyInvite( // Ask the requesting server to sign the newly created event so we know it // acknowledged it - signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), *event) + signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), event) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() @@ -175,7 +175,7 @@ func ExchangeThirdPartyInvite( if err = api.SendEvents( httpReq.Context(), rsAPI, api.KindNew, - []gomatrixserverlib.HeaderedEvent{ + []*gomatrixserverlib.HeaderedEvent{ signedEvent.Event.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, @@ -297,7 +297,7 @@ func buildMembershipEvent( authEvents := gomatrixserverlib.NewAuthEvents(nil) for i := range queryRes.StateEvents { - err = authEvents.AddEvent(&queryRes.StateEvents[i].Event) + err = authEvents.AddEvent(queryRes.StateEvents[i].Event) if err != nil { return nil, err } @@ -318,7 +318,7 @@ func buildMembershipEvent( cfg.Matrix.PrivateKey, queryRes.RoomVersion, ) - return &event, err + return event, err } // sendToRemoteServer uses federation to send an invite provided by an identity diff --git a/federationsender/api/api.go b/federationsender/api/api.go index 15bd857ae..ee5e375fb 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -48,6 +48,7 @@ type FederationSenderInternalAPI interface { // Query the server names of the joined hosts in a room. // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice // containing only the server names (without information for membership events). + // The response will include this server if they are joined to the room. QueryJoinedHostServerNamesInRoom( ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, @@ -110,6 +111,7 @@ type PerformJoinRequest struct { } type PerformJoinResponse struct { + JoinedVia gomatrixserverlib.ServerName LastError *gomatrix.HTTPError } @@ -134,12 +136,12 @@ type PerformLeaveResponse struct { type PerformInviteRequest struct { RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - Event gomatrixserverlib.HeaderedEvent `json:"event"` + Event *gomatrixserverlib.HeaderedEvent `json:"event"` InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"` } type PerformInviteResponse struct { - Event gomatrixserverlib.HeaderedEvent `json:"event"` + Event *gomatrixserverlib.HeaderedEvent `json:"event"` } type PerformServersAliveRequest struct { diff --git a/federationsender/consumers/eduserver.go b/federationsender/consumers/eduserver.go index d9ac41b3b..9d7574e68 100644 --- a/federationsender/consumers/eduserver.go +++ b/federationsender/consumers/eduserver.go @@ -34,6 +34,7 @@ import ( type OutputEDUConsumer struct { typingConsumer *internal.ContinualConsumer sendToDeviceConsumer *internal.ContinualConsumer + receiptConsumer *internal.ContinualConsumer db storage.Database queues *queue.OutgoingQueues ServerName gomatrixserverlib.ServerName @@ -51,24 +52,31 @@ func NewOutputEDUConsumer( c := &OutputEDUConsumer{ typingConsumer: &internal.ContinualConsumer{ ComponentName: "eduserver/typing", - Topic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent)), + Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent), Consumer: kafkaConsumer, PartitionStore: store, }, sendToDeviceConsumer: &internal.ContinualConsumer{ ComponentName: "eduserver/sendtodevice", - Topic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent)), + Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + }, + receiptConsumer: &internal.ContinualConsumer{ + ComponentName: "eduserver/receipt", + Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputReceiptEvent), Consumer: kafkaConsumer, PartitionStore: store, }, queues: queues, db: store, ServerName: cfg.Matrix.ServerName, - TypingTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent)), - SendToDeviceTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent)), + TypingTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent), + SendToDeviceTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent), } c.typingConsumer.ProcessMessage = c.onTypingEvent c.sendToDeviceConsumer.ProcessMessage = c.onSendToDeviceEvent + c.receiptConsumer.ProcessMessage = c.onReceiptEvent return c } @@ -81,6 +89,9 @@ func (t *OutputEDUConsumer) Start() error { if err := t.sendToDeviceConsumer.Start(); err != nil { return fmt.Errorf("t.sendToDeviceConsumer.Start: %w", err) } + if err := t.receiptConsumer.Start(); err != nil { + return fmt.Errorf("t.receiptConsumer.Start: %w", err) + } return nil } @@ -177,3 +188,58 @@ func (t *OutputEDUConsumer) onTypingEvent(msg *sarama.ConsumerMessage) error { return t.queues.SendEDU(edu, t.ServerName, names) } + +// onReceiptEvent is called in response to a message received on the receipt +// events topic from the EDU server. +func (t *OutputEDUConsumer) onReceiptEvent(msg *sarama.ConsumerMessage) error { + // Extract the typing event from msg. + var receipt api.OutputReceiptEvent + if err := json.Unmarshal(msg.Value, &receipt); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") + return nil + } + + // only send receipt events which originated from us + _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) + if err != nil { + log.WithError(err).WithField("user_id", receipt.UserID).Error("Failed to extract domain from receipt sender") + return nil + } + if receiptServerName != t.ServerName { + log.WithField("other_server", receiptServerName).Info("Suppressing receipt notif: originated elsewhere") + return nil + } + + joined, err := t.db.GetJoinedHosts(context.TODO(), receipt.RoomID) + if err != nil { + return err + } + + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } + + content := map[string]api.FederationReceiptMRead{} + content[receipt.RoomID] = api.FederationReceiptMRead{ + User: map[string]api.FederationReceiptData{ + receipt.UserID: { + Data: api.ReceiptTS{ + TS: receipt.Timestamp, + }, + EventIDs: []string{receipt.EventID}, + }, + }, + } + + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MReceipt, + Origin: string(t.ServerName), + } + if edu.Content, err = json.Marshal(content); err != nil { + return err + } + + return t.queues.SendEDU(edu, t.ServerName, names) +} diff --git a/federationsender/consumers/roomserver.go b/federationsender/consumers/roomserver.go index 3f35925a1..4b0741cfd 100644 --- a/federationsender/consumers/roomserver.go +++ b/federationsender/consumers/roomserver.go @@ -85,7 +85,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { switch output.Type { case api.OutputTypeNewRoomEvent: - ev := &output.NewRoomEvent.Event + ev := output.NewRoomEvent.Event if output.NewRoomEvent.RewritesState { if err := s.db.PurgeRoomState(context.TODO(), ev.RoomID()); err != nil { @@ -187,7 +187,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err // Send the event. return s.queues.SendEvent( - &ore.Event, gomatrixserverlib.ServerName(ore.SendAsServer), joinedHostsAtEvent, + ore.Event, gomatrixserverlib.ServerName(ore.SendAsServer), joinedHostsAtEvent, ) } @@ -264,7 +264,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( // joinedHostsFromEvents turns a list of state events into a list of joined hosts. // This errors if one of the events was invalid. // It should be impossible for an invalid event to get this far in the pipeline. -func joinedHostsFromEvents(evs []gomatrixserverlib.Event) ([]types.JoinedHost, error) { +func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) { var joinedHosts []types.JoinedHost for _, ev := range evs { if ev.Type() != "m.room.member" || ev.StateKey() == nil { @@ -329,8 +329,8 @@ func combineDeltas(adds1, removes1, adds2, removes2 []string) (adds, removes []s // lookupStateEvents looks up the state events that are added by a new event. func (s *OutputRoomEventConsumer) lookupStateEvents( - addsStateEventIDs []string, event gomatrixserverlib.Event, -) ([]gomatrixserverlib.Event, error) { + addsStateEventIDs []string, event *gomatrixserverlib.Event, +) ([]*gomatrixserverlib.Event, error) { // Fast path if there aren't any new state events. if len(addsStateEventIDs) == 0 { return nil, nil @@ -338,11 +338,11 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( // Fast path if the only state event added is the event itself. if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() { - return []gomatrixserverlib.Event{event}, nil + return []*gomatrixserverlib.Event{event}, nil } missing := addsStateEventIDs - var result []gomatrixserverlib.Event + var result []*gomatrixserverlib.Event // Check if event itself is being added. for _, eventID := range missing { @@ -381,7 +381,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( return result, nil } -func missingEventsFrom(events []gomatrixserverlib.Event, required []string) []string { +func missingEventsFrom(events []*gomatrixserverlib.Event, required []string) []string { have := map[string]bool{} for _, event := range events { have[event.EventID()] = true diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index 4d706c3e4..02ea32df5 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -105,6 +105,7 @@ func (r *FederationSenderInternalAPI) PerformJoin( } // We're all good. + response.JoinedVia = serverName return } @@ -534,7 +535,7 @@ func (r *FederationSenderInternalAPI) PerformInvite( "destination": destination, }).Info("Sending invite") - inviteReq, err := gomatrixserverlib.NewInviteV2Request(&request.Event, request.InviteRoomState) + inviteReq, err := gomatrixserverlib.NewInviteV2Request(request.Event, request.InviteRoomState) if err != nil { return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } diff --git a/federationsender/internal/perform/join.go b/federationsender/internal/perform/join.go index d03a57df6..c23f6fa3e 100644 --- a/federationsender/internal/perform/join.go +++ b/federationsender/internal/perform/join.go @@ -40,19 +40,19 @@ func JoinContext(f *gomatrixserverlib.FederationClient, k *gomatrixserverlib.Key // and that the join is allowed by the supplied state. func (r joinContext) CheckSendJoinResponse( ctx context.Context, - event gomatrixserverlib.Event, + event *gomatrixserverlib.Event, server gomatrixserverlib.ServerName, respSendJoin gomatrixserverlib.RespSendJoin, ) (*gomatrixserverlib.RespState, error) { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. - retries := map[string][]gomatrixserverlib.Event{} + retries := map[string][]*gomatrixserverlib.Event{} // Define a function which we can pass to Check to retrieve missing // auth events inline. This greatly increases our chances of not having // to repeat the entire set of checks just for a missing event or two. - missingAuth := func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) { - returning := []gomatrixserverlib.Event{} + missingAuth := func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) { + returning := []*gomatrixserverlib.Event{} // See if we have retry entries for each of the supplied event IDs. for _, eventID := range eventIDs { @@ -88,7 +88,7 @@ func (r joinContext) CheckSendJoinResponse( } // Check the signatures of the event. - if res, err := gomatrixserverlib.VerifyEventSignatures(ctx, []gomatrixserverlib.Event{ev}, r.keyRing); err != nil { + if res, err := gomatrixserverlib.VerifyEventSignatures(ctx, []*gomatrixserverlib.Event{ev}, r.keyRing); err != nil { return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) } else { for _, err := range res { diff --git a/federationsender/internal/query.go b/federationsender/internal/query.go index 253400a2d..8ba228d1b 100644 --- a/federationsender/internal/query.go +++ b/federationsender/internal/query.go @@ -4,7 +4,6 @@ import ( "context" "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/gomatrixserverlib" ) // QueryJoinedHostServerNamesInRoom implements api.FederationSenderInternalAPI @@ -13,17 +12,11 @@ func (f *FederationSenderInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) (err error) { - joinedHosts, err := f.db.GetJoinedHosts(ctx, request.RoomID) + joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}) if err != nil { return } - - response.ServerNames = make([]gomatrixserverlib.ServerName, 0, len(joinedHosts)) - for _, host := range joinedHosts { - response.ServerNames = append(response.ServerNames, host.ServerName) - } - - // TODO: remove duplicates? + response.ServerNames = joinedHosts return } diff --git a/go.mod b/go.mod index f785dd391..42eeaf04e 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd - github.com/matrix-org/gomatrixserverlib v0.0.0-20201020162226-22169fe9cda7 + github.com/matrix-org/gomatrixserverlib v0.0.0-20201116151724-6e7b24e4956c github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.2 @@ -33,7 +33,7 @@ require ( github.com/pressly/goose v2.7.0-rc5+incompatible github.com/prometheus/client_golang v1.7.1 github.com/sirupsen/logrus v1.6.0 - github.com/tidwall/gjson v1.6.1 + github.com/tidwall/gjson v1.6.3 github.com/tidwall/sjson v1.1.1 github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-lib v2.2.0+incompatible diff --git a/go.sum b/go.sum index 7c24516d2..1341894df 100644 --- a/go.sum +++ b/go.sum @@ -569,8 +569,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20201020162226-22169fe9cda7 h1:YPuewGCKaJh08NslYAhyGiLw2tg6ew9LtkW7Xr+4uTU= -github.com/matrix-org/gomatrixserverlib v0.0.0-20201020162226-22169fe9cda7/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20201116151724-6e7b24e4956c h1:iiloytJig9EmlKwuSulIbNvoPz1BFZ1QdyPWpuy85XM= +github.com/matrix-org/gomatrixserverlib v0.0.0-20201116151724-6e7b24e4956c/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= @@ -812,8 +812,8 @@ github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpP github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc= github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= -github.com/tidwall/gjson v1.6.1 h1:LRbvNuNuvAiISWg6gxLEFuCe72UKy5hDqhxW/8183ws= -github.com/tidwall/gjson v1.6.1/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0= +github.com/tidwall/gjson v1.6.3 h1:aHoiiem0dr7GHkW001T1SMTJ7X5PvyekH5WX0whWGnI= +github.com/tidwall/gjson v1.6.3/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0= github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc= github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index 7cb312c95..cac595494 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -1,6 +1,8 @@ package caching import ( + "strconv" + "github.com/matrix-org/dendrite/roomserver/types" ) @@ -83,11 +85,11 @@ func (c Caches) GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) { func (c Caches) StoreRoomServerRoomNID(roomID string, roomNID types.RoomNID) { c.RoomServerRoomNIDs.Set(roomID, roomNID) - c.RoomServerRoomIDs.Set(string(roomNID), roomID) + c.RoomServerRoomIDs.Set(strconv.Itoa(int(roomNID)), roomID) } func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { - val, found := c.RoomServerRoomIDs.Get(string(roomNID)) + val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID))) if found && val != nil { if roomID, ok := val.(string); ok { return roomID, true diff --git a/internal/config/config.go b/internal/config/config.go index 9d9e2414f..b8b12d0c1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -66,6 +66,8 @@ type Dendrite struct { SyncAPI SyncAPI `yaml:"sync_api"` UserAPI UserAPI `yaml:"user_api"` + MSCs MSCs `yaml:"mscs"` + // The config for tracing the dendrite servers. Tracing struct { // Set to true to enable tracer hooks. If false, no tracing is set up. @@ -306,6 +308,7 @@ func (c *Dendrite) Defaults() { c.SyncAPI.Defaults() c.UserAPI.Defaults() c.AppServiceAPI.Defaults() + c.MSCs.Defaults() c.Wiring() } @@ -319,7 +322,7 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { &c.EDUServer, &c.FederationAPI, &c.FederationSender, &c.KeyServer, &c.MediaAPI, &c.RoomServer, &c.SigningKeyServer, &c.SyncAPI, &c.UserAPI, - &c.AppServiceAPI, + &c.AppServiceAPI, &c.MSCs, } { c.Verify(configErrs, isMonolith) } @@ -337,6 +340,7 @@ func (c *Dendrite) Wiring() { c.SyncAPI.Matrix = &c.Global c.UserAPI.Matrix = &c.Global c.AppServiceAPI.Matrix = &c.Global + c.MSCs.Matrix = &c.Global c.ClientAPI.Derived = &c.Derived c.AppServiceAPI.Derived = &c.Derived diff --git a/internal/config/config_kafka.go b/internal/config/config_kafka.go index e2bd6538e..aa91e5589 100644 --- a/internal/config/config_kafka.go +++ b/internal/config/config_kafka.go @@ -9,6 +9,7 @@ const ( TopicOutputKeyChangeEvent = "OutputKeyChangeEvent" TopicOutputRoomEvent = "OutputRoomEvent" TopicOutputClientData = "OutputClientData" + TopicOutputReceiptEvent = "OutputReceiptEvent" ) type Kafka struct { @@ -24,6 +25,9 @@ type Kafka struct { UseNaffka bool `yaml:"use_naffka"` // The Naffka database is used internally by the naffka library, if used. Database DatabaseOptions `yaml:"naffka_database"` + // The max size a Kafka message passed between consumer/producer can have + // Equals roughly max.message.bytes / fetch.message.max.bytes in Kafka + MaxMessageBytes *int `yaml:"max_message_bytes"` } func (k *Kafka) TopicFor(name string) string { @@ -36,6 +40,9 @@ func (c *Kafka) Defaults() { c.Addresses = []string{"localhost:2181"} c.Database.ConnectionString = DataSource("file:naffka.db") c.TopicPrefix = "Dendrite" + + maxBytes := 1024 * 1024 * 8 // about 8MB + c.MaxMessageBytes = &maxBytes } func (c *Kafka) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -50,4 +57,5 @@ func (c *Kafka) Verify(configErrs *ConfigErrors, isMonolith bool) { checkNotZero(configErrs, "global.kafka.addresses", int64(len(c.Addresses))) } checkNotEmpty(configErrs, "global.kafka.topic_prefix", string(c.TopicPrefix)) + checkPositive(configErrs, "global.kafka.max_message_bytes", int64(*c.MaxMessageBytes)) } diff --git a/internal/config/config_mscs.go b/internal/config/config_mscs.go new file mode 100644 index 000000000..776d0b641 --- /dev/null +++ b/internal/config/config_mscs.go @@ -0,0 +1,19 @@ +package config + +type MSCs struct { + Matrix *Global `yaml:"-"` + + // The MSCs to enable, currently only `msc2836` is supported. + MSCs []string `yaml:"mscs"` + + Database DatabaseOptions `yaml:"database"` +} + +func (c *MSCs) Defaults() { + c.Database.Defaults() + c.Database.ConnectionString = "file:mscs.db" +} + +func (c *MSCs) Verify(configErrs *ConfigErrors, isMonolith bool) { + checkNotEmpty(configErrs, "mscs.database.connection_string", string(c.Database.ConnectionString)) +} diff --git a/internal/config/config_syncapi.go b/internal/config/config_syncapi.go index 0a96e41ca..fc08f7380 100644 --- a/internal/config/config_syncapi.go +++ b/internal/config/config_syncapi.go @@ -7,6 +7,8 @@ type SyncAPI struct { ExternalAPI ExternalAPIOptions `yaml:"external_api"` Database DatabaseOptions `yaml:"database"` + + RealIPHeader string `yaml:"real_ip_header"` } func (c *SyncAPI) Defaults() { diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index 0b878961e..a650b4b3a 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -73,8 +73,7 @@ func BuildEvent( return nil, err } - h := event.Headered(queryRes.RoomVersion) - return &h, nil + return event.Headered(queryRes.RoomVersion), nil } // queryRequiredEventsForBuilder queries the roomserver for auth/prev events needed for this builder. @@ -120,7 +119,7 @@ func addPrevEventsToEvent( authEvents := gomatrixserverlib.NewAuthEvents(nil) for i := range queryRes.StateEvents { - err = authEvents.AddEvent(&queryRes.StateEvents[i].Event) + err = authEvents.AddEvent(queryRes.StateEvents[i].Event) if err != nil { return fmt.Errorf("authEvents.AddEvent: %w", err) } @@ -186,5 +185,5 @@ func RedactEvent(redactionEvent, redactedEvent *gomatrixserverlib.Event) (*gomat if err != nil { return nil, err } - return &r, nil + return r, nil } diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go new file mode 100644 index 000000000..223282a25 --- /dev/null +++ b/internal/hooks/hooks.go @@ -0,0 +1,74 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package hooks exposes places in Dendrite where custom code can be executed, useful for MSCs. +// Hooks can only be run in monolith mode. +package hooks + +import "sync" + +const ( + // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent + // It is run when a new event is persisted in the roomserver. + // Usage: + // hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... }) + KindNewEventPersisted = "new_event_persisted" + // KindNewEventReceived is a hook which is called with *gomatrixserverlib.HeaderedEvent + // It is run before a new event is processed by the roomserver. This hook can be used + // to modify the event before it is persisted by adding data to `unsigned`. + // Usage: + // hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { + // ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent) + // _ = ev.SetUnsignedField("key", "val") + // }) + KindNewEventReceived = "new_event_received" +) + +var ( + hookMap = make(map[string][]func(interface{})) + hookMu = sync.Mutex{} + enabled = false +) + +// Enable all hooks. This may slow down the server slightly. Required for MSCs to work. +func Enable() { + enabled = true +} + +// Run any hooks +func Run(kind string, data interface{}) { + if !enabled { + return + } + cbs := callbacks(kind) + for _, cb := range cbs { + cb(data) + } +} + +// Attach a hook +func Attach(kind string, callback func(interface{})) { + if !enabled { + return + } + hookMu.Lock() + defer hookMu.Unlock() + hookMap[kind] = append(hookMap[kind], callback) +} + +func callbacks(kind string) []func(interface{}) { + hookMu.Lock() + defer hookMu.Unlock() + return hookMap[kind] +} diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go new file mode 100644 index 000000000..865bc3111 --- /dev/null +++ b/internal/mscs/msc2836/msc2836.go @@ -0,0 +1,530 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package msc2836 'Threading' implements https://github.com/matrix-org/matrix-doc/pull/2836 +package msc2836 + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + fs "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/hooks" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/setup" + roomserver "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +const ( + constRelType = "m.reference" + constRoomIDKey = "relationship_room_id" + constRoomServers = "relationship_servers" +) + +type EventRelationshipRequest struct { + EventID string `json:"event_id"` + MaxDepth int `json:"max_depth"` + MaxBreadth int `json:"max_breadth"` + Limit int `json:"limit"` + DepthFirst bool `json:"depth_first"` + RecentFirst bool `json:"recent_first"` + IncludeParent bool `json:"include_parent"` + IncludeChildren bool `json:"include_children"` + Direction string `json:"direction"` + Batch string `json:"batch"` + AutoJoin bool `json:"auto_join"` +} + +func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) { + var relation EventRelationshipRequest + relation.Defaults() + if err := json.NewDecoder(body).Decode(&relation); err != nil { + return nil, err + } + return &relation, nil +} + +func (r *EventRelationshipRequest) Defaults() { + r.Limit = 100 + r.MaxBreadth = 10 + r.MaxDepth = 3 + r.DepthFirst = false + r.RecentFirst = true + r.IncludeParent = false + r.IncludeChildren = false + r.Direction = "down" +} + +type EventRelationshipResponse struct { + Events []gomatrixserverlib.ClientEvent `json:"events"` + NextBatch string `json:"next_batch"` + Limited bool `json:"limited"` +} + +// Enable this MSC +// nolint:gocyclo +func Enable( + base *setup.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationSenderInternalAPI, + userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, +) error { + db, err := NewDatabase(&base.Cfg.MSCs.Database) + if err != nil { + return fmt.Errorf("Cannot enable MSC2836: %w", err) + } + hooks.Enable() + hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { + he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) + hookErr := db.StoreRelation(context.Background(), he) + if hookErr != nil { + util.GetLogger(context.Background()).WithError(hookErr).Error( + "failed to StoreRelation", + ) + } + }) + hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { + he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) + ctx := context.Background() + // we only inject metadata for events our server sends + userID := he.Sender() + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return + } + if domain != base.Cfg.Global.ServerName { + return + } + // if this event has an m.relationship, add on the room_id and servers to unsigned + parent, child, relType := parentChildEventIDs(he) + if parent == "" || child == "" || relType == "" { + return + } + event, joinedToRoom := getEventIfVisible(ctx, rsAPI, parent, userID) + if !joinedToRoom { + return + } + err = he.SetUnsignedField(constRoomIDKey, event.RoomID()) + if err != nil { + util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField") + return + } + + var servers []gomatrixserverlib.ServerName + if fsAPI != nil { + var res fs.QueryJoinedHostServerNamesInRoomResponse + err = fsAPI.QueryJoinedHostServerNamesInRoom(ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{ + RoomID: event.RoomID(), + }, &res) + if err != nil { + util.GetLogger(context.Background()).WithError(err).Warn("Failed to QueryJoinedHostServerNamesInRoom") + return + } + servers = res.ServerNames + } else { + servers = []gomatrixserverlib.ServerName{ + base.Cfg.Global.ServerName, + } + } + err = he.SetUnsignedField(constRoomServers, servers) + if err != nil { + util.GetLogger(context.Background()).WithError(err).Warn("Failed to SetUnsignedField") + return + } + }) + + base.PublicClientAPIMux.Handle("/unstable/event_relationships", + httputil.MakeAuthAPI("eventRelationships", userAPI, eventRelationshipHandler(db, rsAPI)), + ).Methods(http.MethodPost, http.MethodOptions) + + base.PublicFederationAPIMux.Handle("/unstable/event_relationships", httputil.MakeExternalAPI( + "msc2836_event_relationships", func(req *http.Request) util.JSONResponse { + fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + req, time.Now(), base.Cfg.Global.ServerName, keyRing, + ) + if fedReq == nil { + return errResp + } + return federatedEventRelationship(req.Context(), fedReq, db, rsAPI) + }, + )).Methods(http.MethodPost, http.MethodOptions) + return nil +} + +type reqCtx struct { + ctx context.Context + rsAPI roomserver.RoomserverInternalAPI + db Database + req *EventRelationshipRequest + userID string + isFederatedRequest bool +} + +func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAPI) func(*http.Request, *userapi.Device) util.JSONResponse { + return func(req *http.Request, device *userapi.Device) util.JSONResponse { + relation, err := NewEventRelationshipRequest(req.Body) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to decode HTTP request as JSON") + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } + rc := reqCtx{ + ctx: req.Context(), + req: relation, + userID: device.UserID, + rsAPI: rsAPI, + isFederatedRequest: false, + db: db, + } + res, resErr := rc.process() + if resErr != nil { + return *resErr + } + + return util.JSONResponse{ + Code: 200, + JSON: res, + } + } +} + +func federatedEventRelationship(ctx context.Context, fedReq *gomatrixserverlib.FederationRequest, db Database, rsAPI roomserver.RoomserverInternalAPI) util.JSONResponse { + relation, err := NewEventRelationshipRequest(bytes.NewBuffer(fedReq.Content())) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to decode HTTP request as JSON") + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON(fmt.Sprintf("invalid json: %s", err)), + } + } + rc := reqCtx{ + ctx: ctx, + req: relation, + userID: "", + rsAPI: rsAPI, + isFederatedRequest: true, + db: db, + } + res, resErr := rc.process() + if resErr != nil { + return *resErr + } + + return util.JSONResponse{ + Code: 200, + JSON: res, + } +} + +func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) { + var res EventRelationshipResponse + var returnEvents []*gomatrixserverlib.HeaderedEvent + // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. + // We should have the event being referenced so don't give any claimed room ID / servers + event := rc.getEventIfVisible(rc.req.EventID, "", nil) + if event == nil { + return nil, &util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("Event does not exist or you are not authorised to see it"), + } + } + + // Retrieve the event. Add it to response array. + returnEvents = append(returnEvents, event) + + if rc.req.IncludeParent { + if parentEvent := rc.includeParent(event); parentEvent != nil { + returnEvents = append(returnEvents, parentEvent) + } + } + + if rc.req.IncludeChildren { + remaining := rc.req.Limit - len(returnEvents) + if remaining > 0 { + children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, rc.req.RecentFirst) + if resErr != nil { + return nil, resErr + } + returnEvents = append(returnEvents, children...) + } + } + + remaining := rc.req.Limit - len(returnEvents) + var walkLimited bool + if remaining > 0 { + included := make(map[string]bool, len(returnEvents)) + for _, ev := range returnEvents { + included[ev.EventID()] = true + } + var events []*gomatrixserverlib.HeaderedEvent + events, walkLimited = walkThread( + rc.ctx, rc.db, rc, included, remaining, + ) + returnEvents = append(returnEvents, events...) + } + res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents)) + for i, ev := range returnEvents { + res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll) + } + res.Limited = remaining == 0 || walkLimited + return &res, nil +} + +// If include_parent: true and there is a valid m.relationship field in the event, +// retrieve the referenced event. Apply history visibility check to that event and if it passes, add it to the response array. +func (rc *reqCtx) includeParent(event *gomatrixserverlib.HeaderedEvent) (parent *gomatrixserverlib.HeaderedEvent) { + parentID, _, _ := parentChildEventIDs(event) + if parentID == "" { + return nil + } + claimedRoomID, claimedServers := roomIDAndServers(event) + return rc.getEventIfVisible(parentID, claimedRoomID, claimedServers) +} + +// If include_children: true, lookup all events which have event_id as an m.relationship +// Apply history visibility checks to all these events and add the ones which pass into the response array, +// honouring the recent_first flag and the limit. +func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recentFirst bool) ([]*gomatrixserverlib.HeaderedEvent, *util.JSONResponse) { + children, err := db.ChildrenForParent(rc.ctx, parentID, constRelType, recentFirst) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Error("failed to get ChildrenForParent") + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + var childEvents []*gomatrixserverlib.HeaderedEvent + for _, child := range children { + // in order for us to even know about the children the server must be joined to those rooms, hence pass no claimed room ID or servers. + childEvent := rc.getEventIfVisible(child.EventID, "", nil) + if childEvent != nil { + childEvents = append(childEvents, childEvent) + } + } + if len(childEvents) > limit { + return childEvents[:limit], nil + } + return childEvents, nil +} + +// Begin to walk the thread DAG in the direction specified, either depth or breadth first according to the depth_first flag, +// honouring the limit, max_depth and max_breadth values according to the following rules +// nolint: unparam +func walkThread( + ctx context.Context, db Database, rc *reqCtx, included map[string]bool, limit int, +) ([]*gomatrixserverlib.HeaderedEvent, bool) { + if rc.req.Direction != "down" { + util.GetLogger(ctx).Error("not implemented: direction=up") + return nil, false + } + var result []*gomatrixserverlib.HeaderedEvent + eventWalker := walker{ + ctx: ctx, + req: rc.req, + db: db, + fn: func(wi *walkInfo) bool { + // If already processed event, skip. + if included[wi.EventID] { + return false + } + + // If the response array is >= limit, stop. + if len(result) >= limit { + return true + } + + // Process the event. + // TODO: Include edge information: room ID and servers + event := rc.getEventIfVisible(wi.EventID, "", nil) + if event != nil { + result = append(result, event) + } + included[wi.EventID] = true + return false + }, + } + limited, err := eventWalker.WalkFrom(rc.req.EventID) + if err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Failed to WalkFrom %s", rc.req.EventID) + } + return result, limited +} + +func (rc *reqCtx) getEventIfVisible(eventID string, claimedRoomID string, claimedServers []string) *gomatrixserverlib.HeaderedEvent { + event, joinedToRoom := getEventIfVisible(rc.ctx, rc.rsAPI, eventID, rc.userID) + if event != nil && joinedToRoom { + return event + } + // either we don't have the event or we aren't joined to the room, regardless we should try joining if auto join is enabled + if !rc.req.AutoJoin { + return nil + } + // if we're doing this on behalf of a random server don't auto-join rooms regardless of what the request says + if rc.isFederatedRequest { + return nil + } + roomID := claimedRoomID + var servers []gomatrixserverlib.ServerName + if event != nil { + roomID = event.RoomID() + } + for _, s := range claimedServers { + servers = append(servers, gomatrixserverlib.ServerName(s)) + } + var joinRes roomserver.PerformJoinResponse + rc.rsAPI.PerformJoin(rc.ctx, &roomserver.PerformJoinRequest{ + UserID: rc.userID, + Content: map[string]interface{}{}, + RoomIDOrAlias: roomID, + ServerNames: servers, + }, &joinRes) + if joinRes.Error != nil { + util.GetLogger(rc.ctx).WithError(joinRes.Error).WithField("room_id", roomID).Error("Failed to auto-join room") + return nil + } + if event != nil { + return event + } + // TODO: hit /event_relationships on the server we joined via + util.GetLogger(rc.ctx).Infof("joined room but need to fetch event TODO") + return nil +} + +func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalAPI, eventID, userID string) (*gomatrixserverlib.HeaderedEvent, bool) { + var queryEventsRes roomserver.QueryEventsByIDResponse + err := rsAPI.QueryEventsByID(ctx, &roomserver.QueryEventsByIDRequest{ + EventIDs: []string{eventID}, + }, &queryEventsRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryEventsByID") + return nil, false + } + if len(queryEventsRes.Events) == 0 { + util.GetLogger(ctx).Infof("event does not exist") + return nil, false // event does not exist + } + event := queryEventsRes.Events[0] + + // Allow events if the member is in the room + // TODO: This does not honour history_visibility + // TODO: This does not honour m.room.create content + var queryMembershipRes roomserver.QueryMembershipForUserResponse + err = rsAPI.QueryMembershipForUser(ctx, &roomserver.QueryMembershipForUserRequest{ + RoomID: event.RoomID(), + UserID: userID, + }, &queryMembershipRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser") + return nil, false + } + return event, queryMembershipRes.IsInRoom +} + +type walkInfo struct { + eventInfo + SiblingNumber int + Depth int +} + +type walker struct { + ctx context.Context + req *EventRelationshipRequest + db Database + fn func(wi *walkInfo) bool // callback invoked for each event walked, return true to terminate the walk +} + +// WalkFrom the event ID given +func (w *walker) WalkFrom(eventID string) (limited bool, err error) { + children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst) + if err != nil { + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + return false, err + } + var next *walkInfo + toWalk := w.addChildren(nil, children, 1) + next, toWalk = w.nextChild(toWalk) + for next != nil { + stop := w.fn(next) + if stop { + return true, nil + } + // find the children's children + children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst) + if err != nil { + util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") + return false, err + } + toWalk = w.addChildren(toWalk, children, next.Depth+1) + next, toWalk = w.nextChild(toWalk) + } + + return false, nil +} + +// addChildren adds an event's children to the to walk data structure +func (w *walker) addChildren(toWalk []walkInfo, children []eventInfo, depthOfChildren int) []walkInfo { + // Check what number child this event is (ordered by recent_first) compared to its parent, does it exceed (greater than) max_breadth? If yes, skip. + if len(children) > w.req.MaxBreadth { + children = children[:w.req.MaxBreadth] + } + // Check how deep the event is compared to event_id, does it exceed (greater than) max_depth? If yes, skip. + if depthOfChildren > w.req.MaxDepth { + return toWalk + } + + if w.req.DepthFirst { + // the slice is a stack so push them in reverse order so we pop them in the correct order + // e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3 + for i := len(children) - 1; i >= 0; i-- { + toWalk = append(toWalk, walkInfo{ + eventInfo: children[i], + SiblingNumber: i + 1, // index from 1 + Depth: depthOfChildren, + }) + } + } else { + // the slice is a queue so push them in normal order to we dequeue them in the correct order + // e.g [1,2,3] => 1, [2, 3] => 2 , [3] => 3, [] + for i := range children { + toWalk = append(toWalk, walkInfo{ + eventInfo: children[i], + SiblingNumber: i + 1, // index from 1 + Depth: depthOfChildren, + }) + } + } + return toWalk +} + +func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) { + if len(toWalk) == 0 { + return nil, nil + } + var child walkInfo + if w.req.DepthFirst { + // toWalk is a stack so pop the child off + child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1] + return &child, toWalk + } + // toWalk is a queue so shift the child off + child, toWalk = toWalk[0], toWalk[1:] + return &child, toWalk +} diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go new file mode 100644 index 000000000..265d6ee67 --- /dev/null +++ b/internal/mscs/msc2836/msc2836_test.go @@ -0,0 +1,577 @@ +package msc2836_test + +import ( + "bytes" + "context" + "crypto/ed25519" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/hooks" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/mscs/msc2836" + "github.com/matrix-org/dendrite/internal/setup" + roomserver "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +var ( + client = &http.Client{ + Timeout: 10 * time.Second, + } +) + +// Basic sanity check of MSC2836 logic. Injects a thread that looks like: +// A +// | +// B +// / \ +// C D +// /|\ +// E F G +// | +// H +// And makes sure POST /event_relationships works with various parameters +func TestMSC2836(t *testing.T) { + alice := "@alice:localhost" + bob := "@bob:localhost" + charlie := "@charlie:localhost" + roomIDA := "!alice:localhost" + roomIDB := "!bob:localhost" + roomIDC := "!charlie:localhost" + // give access tokens to all three users + nopUserAPI := &testUserAPI{ + accessTokens: make(map[string]userapi.Device), + } + nopUserAPI.accessTokens["alice"] = userapi.Device{ + AccessToken: "alice", + DisplayName: "Alice", + UserID: alice, + } + nopUserAPI.accessTokens["bob"] = userapi.Device{ + AccessToken: "bob", + DisplayName: "Bob", + UserID: bob, + } + nopUserAPI.accessTokens["charlie"] = userapi.Device{ + AccessToken: "charlie", + DisplayName: "Charles", + UserID: charlie, + } + eventA := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDA, + Sender: alice, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[A] Do you know shelties?", + }, + }) + eventB := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDB, + Sender: bob, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[B] I <3 shelties", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventA.EventID(), + }, + }, + }) + eventC := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDB, + Sender: bob, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[C] like so much", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventB.EventID(), + }, + }, + }) + eventD := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDA, + Sender: alice, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[D] but what are shelties???", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventB.EventID(), + }, + }, + }) + eventE := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDB, + Sender: bob, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[E] seriously???", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventD.EventID(), + }, + }, + }) + eventF := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDC, + Sender: charlie, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[F] omg how do you not know what shelties are", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventD.EventID(), + }, + }, + }) + eventG := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDA, + Sender: alice, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[G] looked it up, it's a sheltered person?", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventD.EventID(), + }, + }, + }) + eventH := mustCreateEvent(t, fledglingEvent{ + RoomID: roomIDB, + Sender: bob, + Type: "m.room.message", + Content: map[string]interface{}{ + "body": "[H] it's a dog!!!!!", + "m.relationship": map[string]string{ + "rel_type": "m.reference", + "event_id": eventE.EventID(), + }, + }, + }) + // make everyone joined to each other's rooms + nopRsAPI := &testRoomserverAPI{ + userToJoinedRooms: map[string][]string{ + alice: []string{roomIDA, roomIDB, roomIDC}, + bob: []string{roomIDA, roomIDB, roomIDC}, + charlie: []string{roomIDA, roomIDB, roomIDC}, + }, + events: map[string]*gomatrixserverlib.HeaderedEvent{ + eventA.EventID(): eventA, + eventB.EventID(): eventB, + eventC.EventID(): eventC, + eventD.EventID(): eventD, + eventE.EventID(): eventE, + eventF.EventID(): eventF, + eventG.EventID(): eventG, + eventH.EventID(): eventH, + }, + } + router := injectEvents(t, nopUserAPI, nopRsAPI, []*gomatrixserverlib.HeaderedEvent{ + eventA, eventB, eventC, eventD, eventE, eventF, eventG, eventH, + }) + cancel := runServer(t, router) + defer cancel() + + t.Run("returns 403 on invalid event IDs", func(t *testing.T) { + _ = postRelationships(t, 403, "alice", newReq(t, map[string]interface{}{ + "event_id": "$invalid", + })) + }) + t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) { + nopUserAPI.accessTokens["frank"] = userapi.Device{ + AccessToken: "frank", + DisplayName: "Frank Not In Room", + UserID: "@frank:localhost", + } + _ = postRelationships(t, 403, "frank", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "limit": 1, + "include_parent": true, + })) + }) + t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) { + nopUserAPI.accessTokens["frank2"] = userapi.Device{ + AccessToken: "frank2", + DisplayName: "Frank2 Not In Room", + UserID: "@frank2:localhost", + } + // Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB + nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB} + body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "limit": 1, + "include_parent": true, + })) + assertContains(t, body, []string{eventB.EventID()}) + }) + t.Run("returns the parent if include_parent is true", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "include_parent": true, + "limit": 2, + })) + assertContains(t, body, []string{eventB.EventID(), eventA.EventID()}) + }) + t.Run("returns the children in the right order if include_children is true", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventD.EventID(), + "include_children": true, + "recent_first": true, + "limit": 4, + })) + assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) + body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventD.EventID(), + "include_children": true, + "recent_first": false, + "limit": 4, + })) + assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) + }) + t.Run("walks the graph depth first", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": true, + "limit": 6, + })) + // Oldest first so: + // A + // | + // B1 + // / \ + // C2 D3 + // /| \ + // 4E 6F G + // | + // 5H + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()}) + body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": true, + "depth_first": true, + "limit": 6, + })) + // Recent first so: + // A + // | + // B1 + // / \ + // C D2 + // /| \ + // E5 F4 G3 + // | + // H6 + assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()}) + }) + t.Run("walks the graph breadth first", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 6, + })) + // Oldest first so: + // A + // | + // B1 + // / \ + // C2 D3 + // /| \ + // E4 F5 G6 + // | + // H + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) + body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": true, + "depth_first": false, + "limit": 6, + })) + // Recent first so: + // A + // | + // B1 + // / \ + // C3 D2 + // /| \ + // E6 F5 G4 + // | + // H + assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) + }) + t.Run("caps via max_breadth", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "max_breadth": 2, + "limit": 10, + })) + // Event G gets omitted because of max_breadth + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()}) + }) + t.Run("caps via max_depth", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "max_depth": 2, + "limit": 10, + })) + // Event H gets omitted because of max_depth + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) + }) + t.Run("terminates when reaching the limit", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 4, + })) + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()}) + }) + t.Run("returns all events with a high enough limit", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 400, + })) + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()}) + }) +} + +// TODO: TestMSC2836TerminatesLoops (short and long) +// TODO: TestMSC2836UnknownEventsSkipped +// TODO: TestMSC2836SkipEventIfNotInRoom + +func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2836.EventRelationshipRequest { + t.Helper() + b, err := json.Marshal(jsonBody) + if err != nil { + t.Fatalf("Failed to marshal request: %s", err) + } + r, err := msc2836.NewEventRelationshipRequest(bytes.NewBuffer(b)) + if err != nil { + t.Fatalf("Failed to NewEventRelationshipRequest: %s", err) + } + return r +} + +func runServer(t *testing.T, router *mux.Router) func() { + t.Helper() + externalServ := &http.Server{ + Addr: string(":8009"), + WriteTimeout: 60 * time.Second, + Handler: router, + } + go func() { + externalServ.ListenAndServe() + }() + // wait to listen on the port + time.Sleep(500 * time.Millisecond) + return func() { + externalServ.Shutdown(context.TODO()) + } +} + +func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse { + t.Helper() + var r msc2836.EventRelationshipRequest + r.Defaults() + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("failed to marshal request: %s", err) + } + httpReq, err := http.NewRequest( + "POST", "http://localhost:8009/_matrix/client/unstable/event_relationships", + bytes.NewBuffer(data), + ) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + if err != nil { + t.Fatalf("failed to prepare request: %s", err) + } + res, err := client.Do(httpReq) + if err != nil { + t.Fatalf("failed to do request: %s", err) + } + if res.StatusCode != expectCode { + body, _ := ioutil.ReadAll(res.Body) + t.Fatalf("wrong response code, got %d want %d - body: %s", res.StatusCode, expectCode, string(body)) + } + if res.StatusCode == 200 { + var result msc2836.EventRelationshipResponse + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + t.Fatalf("response 200 OK but failed to deserialise JSON : %s", err) + } + return &result + } + return nil +} + +func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wantEventIDs []string) { + t.Helper() + gotEventIDs := make([]string, len(result.Events)) + for i, ev := range result.Events { + gotEventIDs[i] = ev.EventID + } + if len(gotEventIDs) != len(wantEventIDs) { + t.Fatalf("length mismatch: got %v want %v", gotEventIDs, wantEventIDs) + } + for i := range gotEventIDs { + if gotEventIDs[i] != wantEventIDs[i] { + t.Errorf("wrong item in position %d - got %s want %s", i, gotEventIDs[i], wantEventIDs[i]) + } + } +} + +type testUserAPI struct { + accessTokens map[string]userapi.Device +} + +func (u *testUserAPI) InputAccountData(ctx context.Context, req *userapi.InputAccountDataRequest, res *userapi.InputAccountDataResponse) error { + return nil +} +func (u *testUserAPI) PerformAccountCreation(ctx context.Context, req *userapi.PerformAccountCreationRequest, res *userapi.PerformAccountCreationResponse) error { + return nil +} +func (u *testUserAPI) PerformPasswordUpdate(ctx context.Context, req *userapi.PerformPasswordUpdateRequest, res *userapi.PerformPasswordUpdateResponse) error { + return nil +} +func (u *testUserAPI) PerformDeviceCreation(ctx context.Context, req *userapi.PerformDeviceCreationRequest, res *userapi.PerformDeviceCreationResponse) error { + return nil +} +func (u *testUserAPI) PerformDeviceDeletion(ctx context.Context, req *userapi.PerformDeviceDeletionRequest, res *userapi.PerformDeviceDeletionResponse) error { + return nil +} +func (u *testUserAPI) PerformDeviceUpdate(ctx context.Context, req *userapi.PerformDeviceUpdateRequest, res *userapi.PerformDeviceUpdateResponse) error { + return nil +} +func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { + return nil +} +func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error { + return nil +} +func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error { + return nil +} +func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { + dev, ok := u.accessTokens[req.AccessToken] + if !ok { + res.Err = fmt.Errorf("unknown token") + return nil + } + res.Device = &dev + return nil +} +func (u *testUserAPI) QueryDevices(ctx context.Context, req *userapi.QueryDevicesRequest, res *userapi.QueryDevicesResponse) error { + return nil +} +func (u *testUserAPI) QueryAccountData(ctx context.Context, req *userapi.QueryAccountDataRequest, res *userapi.QueryAccountDataResponse) error { + return nil +} +func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDeviceInfosRequest, res *userapi.QueryDeviceInfosResponse) error { + return nil +} +func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error { + return nil +} + +type testRoomserverAPI struct { + // use a trace API as it implements method stubs so we don't need to have them here. + // We'll override the functions we care about. + roomserver.RoomserverInternalAPITrace + userToJoinedRooms map[string][]string + events map[string]*gomatrixserverlib.HeaderedEvent +} + +func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { + for _, eventID := range req.EventIDs { + ev := r.events[eventID] + if ev != nil { + res.Events = append(res.Events, ev) + } + } + return nil +} + +func (r *testRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *roomserver.QueryMembershipForUserRequest, res *roomserver.QueryMembershipForUserResponse) error { + rooms := r.userToJoinedRooms[req.UserID] + for _, roomID := range rooms { + if roomID == req.RoomID { + res.IsInRoom = true + res.HasBeenInRoom = true + res.Membership = "join" + break + } + } + return nil +} + +func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserver.RoomserverInternalAPI, events []*gomatrixserverlib.HeaderedEvent) *mux.Router { + t.Helper() + cfg := &config.Dendrite{} + cfg.Defaults() + cfg.Global.ServerName = "localhost" + cfg.MSCs.Database.ConnectionString = "file:msc2836_test.db" + cfg.MSCs.MSCs = []string{"msc2836"} + base := &setup.BaseDendrite{ + Cfg: cfg, + PublicClientAPIMux: mux.NewRouter().PathPrefix(httputil.PublicClientPathPrefix).Subrouter(), + PublicFederationAPIMux: mux.NewRouter().PathPrefix(httputil.PublicFederationPathPrefix).Subrouter(), + } + + err := msc2836.Enable(base, rsAPI, nil, userAPI, nil) + if err != nil { + t.Fatalf("failed to enable MSC2836: %s", err) + } + for _, ev := range events { + hooks.Run(hooks.KindNewEventPersisted, ev) + } + return base.PublicClientAPIMux +} + +type fledglingEvent struct { + Type string + StateKey *string + Content interface{} + Sender string + RoomID string +} + +func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib.HeaderedEvent) { + t.Helper() + roomVer := gomatrixserverlib.RoomVersionV6 + seed := make([]byte, ed25519.SeedSize) // zero seed + key := ed25519.NewKeyFromSeed(seed) + eb := gomatrixserverlib.EventBuilder{ + Sender: ev.Sender, + Depth: 999, + Type: ev.Type, + StateKey: ev.StateKey, + RoomID: ev.RoomID, + } + err := eb.SetContent(ev.Content) + if err != nil { + t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) + } + // make sure the origin_server_ts changes so we can test recency + time.Sleep(1 * time.Millisecond) + signedEvent, err := eb.Build(time.Now(), gomatrixserverlib.ServerName("localhost"), "ed25519:test", key, roomVer) + if err != nil { + t.Fatalf("mustCreateEvent: failed to sign event: %s", err) + } + h := signedEvent.Headered(roomVer) + return h +} diff --git a/internal/mscs/msc2836/storage.go b/internal/mscs/msc2836/storage.go new file mode 100644 index 000000000..f524165fa --- /dev/null +++ b/internal/mscs/msc2836/storage.go @@ -0,0 +1,226 @@ +package msc2836 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +type eventInfo struct { + EventID string + OriginServerTS gomatrixserverlib.Timestamp + RoomID string + Servers []string +} + +type Database interface { + // StoreRelation stores the parent->child and child->parent relationship for later querying. + // Also stores the event metadata e.g timestamp + StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error + // ChildrenForParent returns the events who have the given `eventID` as an m.relationship with the + // provided `relType`. The returned slice is sorted by origin_server_ts according to whether + // `recentFirst` is true or false. + ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) +} + +type DB struct { + db *sql.DB + writer sqlutil.Writer + insertEdgeStmt *sql.Stmt + insertNodeStmt *sql.Stmt + selectChildrenForParentOldestFirstStmt *sql.Stmt + selectChildrenForParentRecentFirstStmt *sql.Stmt +} + +// NewDatabase loads the database for msc2836 +func NewDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + if dbOpts.ConnectionString.IsPostgres() { + return newPostgresDatabase(dbOpts) + } + return newSQLiteDatabase(dbOpts) +} + +func newPostgresDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{ + writer: sqlutil.NewDummyWriter(), + } + var err error + if d.db, err = sqlutil.Open(dbOpts); err != nil { + return nil, err + } + _, err = d.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2836_edges ( + parent_event_id TEXT NOT NULL, + child_event_id TEXT NOT NULL, + rel_type TEXT NOT NULL, + parent_room_id TEXT NOT NULL, + parent_servers TEXT NOT NULL, + CONSTRAINT msc2836_edges_uniq UNIQUE (parent_event_id, child_event_id, rel_type) + ); + + CREATE TABLE IF NOT EXISTS msc2836_nodes ( + event_id TEXT PRIMARY KEY NOT NULL, + origin_server_ts BIGINT NOT NULL, + room_id TEXT NOT NULL + ); + `) + if err != nil { + return nil, err + } + if d.insertEdgeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + if d.insertNodeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + selectChildrenQuery := ` + SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges + LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id + WHERE parent_event_id = $1 AND rel_type = $2 + ORDER BY origin_server_ts + ` + if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil { + return nil, err + } + if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { + return nil, err + } + return &d, err +} + +func newSQLiteDatabase(dbOpts *config.DatabaseOptions) (Database, error) { + d := DB{ + writer: sqlutil.NewExclusiveWriter(), + } + var err error + if d.db, err = sqlutil.Open(dbOpts); err != nil { + return nil, err + } + _, err = d.db.Exec(` + CREATE TABLE IF NOT EXISTS msc2836_edges ( + parent_event_id TEXT NOT NULL, + child_event_id TEXT NOT NULL, + rel_type TEXT NOT NULL, + parent_room_id TEXT NOT NULL, + parent_servers TEXT NOT NULL, + UNIQUE (parent_event_id, child_event_id, rel_type) + ); + + CREATE TABLE IF NOT EXISTS msc2836_nodes ( + event_id TEXT PRIMARY KEY NOT NULL, + origin_server_ts BIGINT NOT NULL, + room_id TEXT NOT NULL + ); + `) + if err != nil { + return nil, err + } + if d.insertEdgeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers) VALUES($1, $2, $3, $4, $5) ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING + `); err != nil { + return nil, err + } + if d.insertNodeStmt, err = d.db.Prepare(` + INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id) VALUES($1, $2, $3) ON CONFLICT DO NOTHING + `); err != nil { + return nil, err + } + selectChildrenQuery := ` + SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges + LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id + WHERE parent_event_id = $1 AND rel_type = $2 + ORDER BY origin_server_ts + ` + if d.selectChildrenForParentOldestFirstStmt, err = d.db.Prepare(selectChildrenQuery + "ASC"); err != nil { + return nil, err + } + if d.selectChildrenForParentRecentFirstStmt, err = d.db.Prepare(selectChildrenQuery + "DESC"); err != nil { + return nil, err + } + return &d, nil +} + +func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error { + parent, child, relType := parentChildEventIDs(ev) + if parent == "" || child == "" { + return nil + } + relationRoomID, relationServers := roomIDAndServers(ev) + relationServersJSON, err := json.Marshal(relationServers) + if err != nil { + return err + } + return p.writer.Do(p.db, nil, func(txn *sql.Tx) error { + _, err := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parent, child, relType, relationRoomID, string(relationServersJSON)) + if err != nil { + return err + } + _, err = txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID()) + return err + }) +} + +func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) { + var rows *sql.Rows + var err error + if recentFirst { + rows, err = p.selectChildrenForParentRecentFirstStmt.QueryContext(ctx, eventID, relType) + } else { + rows, err = p.selectChildrenForParentOldestFirstStmt.QueryContext(ctx, eventID, relType) + } + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + var children []eventInfo + for rows.Next() { + var evInfo eventInfo + if err := rows.Scan(&evInfo.EventID, &evInfo.OriginServerTS, &evInfo.RoomID); err != nil { + return nil, err + } + children = append(children, evInfo) + } + return children, nil +} + +func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) { + if ev == nil { + return + } + body := struct { + Relationship struct { + RelType string `json:"rel_type"` + EventID string `json:"event_id"` + } `json:"m.relationship"` + }{} + if err := json.Unmarshal(ev.Content(), &body); err != nil { + return + } + if body.Relationship.EventID == "" || body.Relationship.RelType == "" { + return + } + return body.Relationship.EventID, ev.EventID(), body.Relationship.RelType +} + +func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, servers []string) { + servers = []string{} + if ev == nil { + return + } + body := struct { + RoomID string `json:"relationship_room_id"` + Servers []string `json:"relationship_servers"` + }{} + if err := json.Unmarshal(ev.Unsigned(), &body); err != nil { + return + } + return body.RoomID, body.Servers +} diff --git a/internal/mscs/mscs.go b/internal/mscs/mscs.go new file mode 100644 index 000000000..0a896ab0f --- /dev/null +++ b/internal/mscs/mscs.go @@ -0,0 +1,42 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mscs implements Matrix Spec Changes from https://github.com/matrix-org/matrix-doc +package mscs + +import ( + "fmt" + + "github.com/matrix-org/dendrite/internal/mscs/msc2836" + "github.com/matrix-org/dendrite/internal/setup" +) + +// Enable MSCs - returns an error on unknown MSCs +func Enable(base *setup.BaseDendrite, monolith *setup.Monolith) error { + for _, msc := range base.Cfg.MSCs.MSCs { + if err := EnableMSC(base, monolith, msc); err != nil { + return err + } + } + return nil +} + +func EnableMSC(base *setup.BaseDendrite, monolith *setup.Monolith, msc string) error { + switch msc { + case "msc2836": + return msc2836.Enable(base, monolith.RoomserverAPI, monolith.FederationSenderAPI, monolith.UserAPI, monolith.KeyRing) + default: + return fmt.Errorf("EnableMSC: unknown msc '%s'", msc) + } +} diff --git a/internal/setup/flags.go b/internal/setup/flags.go index e4fc58d60..c6ecb5cd1 100644 --- a/internal/setup/flags.go +++ b/internal/setup/flags.go @@ -16,18 +16,28 @@ package setup import ( "flag" + "fmt" + "os" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" - "github.com/sirupsen/logrus" ) -var configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.") +var ( + configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.") + version = flag.Bool("version", false, "Shows the current version and exits immediately.") +) // ParseFlags parses the commandline flags and uses them to create a config. func ParseFlags(monolith bool) *config.Dendrite { flag.Parse() + if *version { + fmt.Println(internal.VersionString()) + os.Exit(0) + } + if *configPath == "" { logrus.Fatal("--config must be supplied") } diff --git a/internal/setup/kafka/kafka.go b/internal/setup/kafka/kafka.go index 9855ae156..091025ec7 100644 --- a/internal/setup/kafka/kafka.go +++ b/internal/setup/kafka/kafka.go @@ -17,12 +17,17 @@ func SetupConsumerProducer(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProdu // setupKafka creates kafka consumer/producer pair from the config. func setupKafka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) { - consumer, err := sarama.NewConsumer(cfg.Addresses, nil) + sCfg := sarama.NewConfig() + sCfg.Producer.MaxMessageBytes = *cfg.MaxMessageBytes + sCfg.Producer.Return.Successes = true + sCfg.Consumer.Fetch.Default = int32(*cfg.MaxMessageBytes) + + consumer, err := sarama.NewConsumer(cfg.Addresses, sCfg) if err != nil { logrus.WithError(err).Panic("failed to start kafka consumer") } - producer, err := sarama.NewSyncProducer(cfg.Addresses, nil) + producer, err := sarama.NewSyncProducer(cfg.Addresses, sCfg) if err != nil { logrus.WithError(err).Panic("failed to setup kafka producers") } diff --git a/internal/transactions/transactions_test.go b/internal/transactions/transactions_test.go index f565e4846..aa837f76c 100644 --- a/internal/transactions/transactions_test.go +++ b/internal/transactions/transactions_test.go @@ -14,6 +14,7 @@ package transactions import ( "net/http" + "strconv" "testing" "github.com/matrix-org/util" @@ -44,8 +45,8 @@ func TestCache(t *testing.T) { for i := 1; i <= 100; i++ { fakeTxnCache.AddTransaction( fakeAccessToken, - fakeTxnID+string(i), - &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: string(i)}}, + fakeTxnID+strconv.Itoa(i), + &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}}, ) } diff --git a/internal/version.go b/internal/version.go index 040ffa32a..f4356b235 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 2 - VersionPatch = 0 + VersionMinor = 3 + VersionPatch = 1 VersionTag = "" // example: "rc1" ) diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index 775b6c73a..b18daa3de 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -61,7 +61,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { continue } if state != nil { - acls.OnServerACLUpdate(&state.Event) + acls.OnServerACLUpdate(state.Event) } } return acls diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 732675407..42e416914 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -153,6 +153,9 @@ type RoomserverInternalAPI interface { response *PerformBackfillResponse, ) error + // PerformForget forgets a rooms history for a specific user + PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error + // Asks for the default room version as preferred by the server. QueryRoomVersionCapabilities( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 5a47a940c..868a41aa7 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -204,6 +204,16 @@ func (t *RoomserverInternalAPITrace) PerformBackfill( return err } +func (t *RoomserverInternalAPITrace) PerformForget( + ctx context.Context, + req *PerformForgetRequest, + res *PerformForgetResponse, +) error { + err := t.Impl.PerformForget(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformForget req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) QueryRoomVersionCapabilities( ctx context.Context, req *QueryRoomVersionCapabilitiesRequest, diff --git a/roomserver/api/input.go b/roomserver/api/input.go index e1a8afa00..8e6e4ac7b 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -53,7 +53,7 @@ type InputRoomEvent struct { // This controls how the event is processed. Kind Kind `json:"kind"` // The event JSON for the event to add. - Event gomatrixserverlib.HeaderedEvent `json:"event"` + Event *gomatrixserverlib.HeaderedEvent `json:"event"` // List of state event IDs that authenticate this event. // These are likely derived from the "auth_events" JSON key of the event. // But can be different because the "auth_events" key can be incomplete or wrong. diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 1d9510bf2..27b5226f8 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -99,7 +99,7 @@ const ( // prev_events. type OutputNewRoomEvent struct { // The Event. - Event gomatrixserverlib.HeaderedEvent `json:"event"` + Event *gomatrixserverlib.HeaderedEvent `json:"event"` // Does the event completely rewrite the room state? If so, then AddsStateEventIDs // will contain the entire room state. RewritesState bool `json:"rewrites_state"` @@ -116,7 +116,7 @@ type OutputNewRoomEvent struct { // may decide a bunch of state events on one branch are now valid, so they will be // present in this list. This is useful when trying to maintain the current state of a room // as to do so you need to include both these events and `Event`. - AddStateEvents []gomatrixserverlib.HeaderedEvent `json:"adds_state_events"` + AddStateEvents []*gomatrixserverlib.HeaderedEvent `json:"adds_state_events"` // The state event IDs that were removed from the state of the room by this event. RemovesStateEventIDs []string `json:"removes_state_event_ids"` @@ -173,7 +173,7 @@ type OutputNewRoomEvent struct { // the original event to save space, so you cannot use that slice alone. // Instead, use this function which will add the original event if it is present // in `AddsStateEventIDs`. -func (ore *OutputNewRoomEvent) AddsState() []gomatrixserverlib.HeaderedEvent { +func (ore *OutputNewRoomEvent) AddsState() []*gomatrixserverlib.HeaderedEvent { includeOutputEvent := false for _, id := range ore.AddsStateEventIDs { if id == ore.Event.EventID() { @@ -198,7 +198,7 @@ func (ore *OutputNewRoomEvent) AddsState() []gomatrixserverlib.HeaderedEvent { // should build their current room state up from OutputNewRoomEvents only. type OutputOldRoomEvent struct { // The Event. - Event gomatrixserverlib.HeaderedEvent `json:"event"` + Event *gomatrixserverlib.HeaderedEvent `json:"event"` } // An OutputNewInviteEvent is written whenever an invite becomes active. @@ -208,7 +208,7 @@ type OutputNewInviteEvent struct { // The room version of the invited room. RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` // The "m.room.member" invite event. - Event gomatrixserverlib.HeaderedEvent `json:"event"` + Event *gomatrixserverlib.HeaderedEvent `json:"event"` } // An OutputRetireInviteEvent is written whenever an existing invite is no longer @@ -235,7 +235,7 @@ type OutputRedactedEvent struct { // The event ID that was redacted RedactedEventID string // The value of `unsigned.redacted_because` - the redaction event itself - RedactedBecause gomatrixserverlib.HeaderedEvent + RedactedBecause *gomatrixserverlib.HeaderedEvent } // An OutputNewPeek is written whenever a user starts peeking into a room diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 07614cdf8..3f5dcfb3e 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -83,7 +83,8 @@ type PerformJoinRequest struct { type PerformJoinResponse struct { // The room ID, populated on success. - RoomID string `json:"room_id"` + RoomID string `json:"room_id"` + JoinedVia gomatrixserverlib.ServerName // If non-nil, the join request failed. Contains more information why it failed. Error *PerformError } @@ -98,7 +99,7 @@ type PerformLeaveResponse struct { type PerformInviteRequest struct { RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - Event gomatrixserverlib.HeaderedEvent `json:"event"` + Event *gomatrixserverlib.HeaderedEvent `json:"event"` InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"` SendAsServer string `json:"send_as_server"` TransactionID *TransactionID `json:"transaction_id"` @@ -147,7 +148,7 @@ func (r *PerformBackfillRequest) PrevEventIDs() []string { // PerformBackfillResponse is a response to PerformBackfill. type PerformBackfillResponse struct { // Missing events, arbritrary order. - Events []gomatrixserverlib.HeaderedEvent `json:"events"` + Events []*gomatrixserverlib.HeaderedEvent `json:"events"` } type PerformPublishRequest struct { @@ -181,3 +182,12 @@ type PerformInboundPeekResponse struct { // The event at which this state was captured LatestEvent gomatrixserverlib.HeaderedEvent `json:"latest_event"` } + +// PerformForgetRequest is a request to PerformForget +type PerformForgetRequest struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` +} + +type PerformForgetResponse struct{} + diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 3f606509a..9d2346f64 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -50,7 +50,7 @@ type QueryLatestEventsAndStateResponse struct { // This list will be in an arbitrary order. // These are used to set the auth_events when sending an event. // These are used to check whether the event is allowed. - StateEvents []gomatrixserverlib.HeaderedEvent `json:"state_events"` + StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` // The depth of the latest events. // This is one greater than the maximum depth of the latest events. // This is used to set the depth when sending an event. @@ -80,7 +80,7 @@ type QueryStateAfterEventsResponse struct { PrevEventsExist bool `json:"prev_events_exist"` // The state events requested. // This list will be in an arbitrary order. - StateEvents []gomatrixserverlib.HeaderedEvent `json:"state_events"` + StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` } type QueryMissingAuthPrevEventsRequest struct { @@ -119,7 +119,7 @@ type QueryEventsByIDResponse struct { // fails to read it from the database then it will fail // the entire request. // This list will be in an arbitrary order. - Events []gomatrixserverlib.HeaderedEvent `json:"events"` + Events []*gomatrixserverlib.HeaderedEvent `json:"events"` } // QueryMembershipForUserRequest is a request to QueryMembership @@ -140,7 +140,9 @@ type QueryMembershipForUserResponse struct { // True if the user is in room. IsInRoom bool `json:"is_in_room"` // The current membership - Membership string + Membership string `json:"membership"` + // True if the user asked to forget this room. + IsRoomForgotten bool `json:"is_room_forgotten"` } // QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom @@ -160,6 +162,8 @@ type QueryMembershipsForRoomResponse struct { // True if the user has been in room before and has either stayed in it or // left it. HasBeenInRoom bool `json:"has_been_in_room"` + // True if the user asked to forget this room. + IsRoomForgotten bool `json:"is_room_forgotten"` } // QueryServerJoinedToRoomRequest is a request to QueryServerJoinedToRoom @@ -209,7 +213,7 @@ type QueryMissingEventsRequest struct { // QueryMissingEventsResponse is a response to QueryMissingEvents type QueryMissingEventsResponse struct { // Missing events, arbritrary order. - Events []gomatrixserverlib.HeaderedEvent `json:"events"` + Events []*gomatrixserverlib.HeaderedEvent `json:"events"` } // QueryStateAndAuthChainRequest is a request to QueryStateAndAuthChain @@ -238,8 +242,8 @@ type QueryStateAndAuthChainResponse struct { PrevEventsExist bool `json:"prev_events_exist"` // The state and auth chain events that were requested. // The lists will be in an arbitrary order. - StateEvents []gomatrixserverlib.HeaderedEvent `json:"state_events"` - AuthChainEvents []gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"` + StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` + AuthChainEvents []*gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"` } // QueryRoomVersionCapabilitiesRequest asks for the default room version diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 80b8f837c..a6ef735ce 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -25,7 +25,7 @@ import ( // SendEvents to the roomserver The events are written with KindNew. func SendEvents( ctx context.Context, rsAPI RoomserverInternalAPI, - kind Kind, events []gomatrixserverlib.HeaderedEvent, + kind Kind, events []*gomatrixserverlib.HeaderedEvent, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, ) error { ires := make([]InputRoomEvent, len(events)) @@ -46,7 +46,7 @@ func SendEvents( // marked as `true` in haveEventIDs. func SendEventWithState( ctx context.Context, rsAPI RoomserverInternalAPI, kind Kind, - state *gomatrixserverlib.RespState, event gomatrixserverlib.HeaderedEvent, + state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, haveEventIDs map[string]bool, ) error { outliers, err := state.Events() @@ -97,7 +97,7 @@ func SendInputRoomEvents( // If we are in the room then the event should be sent using the SendEvents method. func SendInvite( ctx context.Context, - rsAPI RoomserverInternalAPI, inviteEvent gomatrixserverlib.HeaderedEvent, + rsAPI RoomserverInternalAPI, inviteEvent *gomatrixserverlib.HeaderedEvent, inviteRoomState []gomatrixserverlib.InviteV2StrippedState, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, ) error { @@ -134,7 +134,7 @@ func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string) if len(res.Events) != 1 { return nil } - return &res.Events[0] + return res.Events[0] } // GetStateEvent returns the current state event in the room or nil. diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index fdcf9f062..aa1d5bc25 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -25,7 +25,7 @@ import ( func IsServerAllowed( serverName gomatrixserverlib.ServerName, serverCurrentlyInRoom bool, - authEvents []gomatrixserverlib.Event, + authEvents []*gomatrixserverlib.Event, ) bool { historyVisibility := HistoryVisibilityForRoom(authEvents) @@ -52,7 +52,7 @@ func IsServerAllowed( return false } -func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.Event) string { +func HistoryVisibilityForRoom(authEvents []*gomatrixserverlib.Event) string { // https://matrix.org/docs/spec/client_server/r0.6.0#id87 // By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared. visibility := "shared" @@ -78,7 +78,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.Event) string { return visibility } -func IsAnyUserOnServerWithMembership(serverName gomatrixserverlib.ServerName, authEvents []gomatrixserverlib.Event, wantMembership string) bool { +func IsAnyUserOnServerWithMembership(serverName gomatrixserverlib.ServerName, authEvents []*gomatrixserverlib.Event, wantMembership string) bool { for _, ev := range authEvents { membership, err := ev.Membership() if err != nil || membership != wantMembership { diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 3e023d2a7..97b2ddf58 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -229,7 +229,7 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( // Add auth events authEvents := gomatrixserverlib.NewAuthEvents(nil) for i := range res.StateEvents { - err = authEvents.AddEvent(&res.StateEvents[i].Event) + err = authEvents.AddEvent(res.StateEvents[i].Event) if err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index aae3390d1..59ccef3ee 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -27,6 +27,7 @@ type RoomserverInternalAPI struct { *perform.Leaver *perform.Publisher *perform.Backfiller + *perform.Forgetter DB storage.Database Cfg *config.RoomServer Producer sarama.SyncProducer @@ -117,6 +118,9 @@ func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSen // than trying random servers PreferServers: r.PerspectiveServerNames, } + r.Forgetter = &perform.Forgetter{ + DB: r.DB, + } } func (r *RoomserverInternalAPI) PerformInvite( @@ -148,3 +152,11 @@ func (r *RoomserverInternalAPI) PerformLeave( } return r.WriteOutputEvents(req.RoomID, outputEvents) } + +func (r *RoomserverInternalAPI) PerformForget( + ctx context.Context, + req *api.PerformForgetRequest, + resp *api.PerformForgetResponse, +) error { + return r.Forgetter.PerformForget(ctx, req, resp) +} diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 0fa89d9c4..1f4215e74 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -31,7 +31,7 @@ import ( func CheckForSoftFail( ctx context.Context, db storage.Database, - event gomatrixserverlib.HeaderedEvent, + event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { rewritesState := len(stateEventIDs) > 1 @@ -72,7 +72,7 @@ func CheckForSoftFail( } // Work out which of the state events we actually need. - stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) + stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) @@ -93,7 +93,7 @@ func CheckForSoftFail( func CheckAuthEvents( ctx context.Context, db storage.Database, - event gomatrixserverlib.HeaderedEvent, + event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { // Grab the numeric IDs for the supplied auth state events from the database. @@ -104,7 +104,7 @@ func CheckAuthEvents( authStateEntries = types.DeduplicateStateEntries(authStateEntries) // Work out which of the state events we actually need. - stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) + stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) @@ -168,7 +168,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) * if !ok { return nil } - return &event.Event + return event.Event } func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event { @@ -187,7 +187,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * if !ok { return nil } - return &event.Event + return event.Event } // loadAuthEvents loads the events needed for authentication from the supplied room state. diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 4c072e44a..036c717a2 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "strings" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" @@ -67,7 +68,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam if err != nil { return false, err } - gmslEvents := make([]gomatrixserverlib.Event, len(events)) + gmslEvents := make([]*gomatrixserverlib.Event, len(events)) for i := range events { gmslEvents[i] = events[i].Event } @@ -190,13 +191,13 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomI func LoadEvents( ctx context.Context, db storage.Database, eventNIDs []types.EventNID, -) ([]gomatrixserverlib.Event, error) { +) ([]*gomatrixserverlib.Event, error) { stateEvents, err := db.Events(ctx, eventNIDs) if err != nil { return nil, err } - result := make([]gomatrixserverlib.Event, len(stateEvents)) + result := make([]*gomatrixserverlib.Event, len(stateEvents)) for i := range stateEvents { result[i] = stateEvents[i].Event } @@ -205,7 +206,7 @@ func LoadEvents( func LoadStateEvents( ctx context.Context, db storage.Database, stateEntries []types.StateEntry, -) ([]gomatrixserverlib.Event, error) { +) ([]*gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID @@ -222,12 +223,45 @@ func CheckServerAllowedToSeeEvent( if errors.Is(err, sql.ErrNoRows) { return false, nil } - return false, err + return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err) } - // TODO: We probably want to make it so that we don't have to pull - // out all the state if possible. - stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries) + // Extract all of the event state key NIDs from the room state. + var stateKeyNIDs []types.EventStateKeyNID + for _, entry := range stateEntries { + stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID) + } + + // Then request those state key NIDs from the database. + stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs) + if err != nil { + return false, fmt.Errorf("db.EventStateKeys: %w", err) + } + + // If the event state key doesn't match the given servername + // then we'll filter it out. This does preserve state keys that + // are "" since these will contain history visibility etc. + for nid, key := range stateKeys { + if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { + delete(stateKeys, nid) + } + } + + // Now filter through all of the state events for the room. + // If the state key NID appears in the list of valid state + // keys then we'll add it to the list of filtered entries. + var filteredEntries []types.StateEntry + for _, entry := range stateEntries { + if _, ok := stateKeys[entry.EventStateKeyNID]; ok { + filteredEntries = append(filteredEntries, entry) + } + } + + if len(filteredEntries) == 0 { + return false, nil + } + + stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries) if err != nil { return false, err } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index d340ac218..79dc2fe14 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -22,6 +22,7 @@ import ( "time" "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" @@ -61,7 +62,11 @@ func (w *inputWorker) start() { for { select { case task := <-w.input: + hooks.Run(hooks.KindNewEventReceived, &task.event.Event) _, task.err = w.r.processRoomEvent(task.ctx, task.event) + if task.err == nil { + hooks.Run(hooks.KindNewEventPersisted, &task.event.Event) + } task.wg.Done() case <-time.After(time.Second * 5): return @@ -92,7 +97,7 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er }) if updates[i].NewRoomEvent.Event.Type() == "m.room.server_acl" && updates[i].NewRoomEvent.Event.StateKeyEquals("") { ev := updates[i].NewRoomEvent.Event.Unwrap() - defer r.ACLs.OnServerACLUpdate(&ev) + defer r.ACLs.OnServerACLUpdate(ev) } } logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic) @@ -102,7 +107,13 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er Value: sarama.ByteEncoder(value), } } - return r.Producer.SendMessages(messages) + errs := r.Producer.SendMessages(messages) + if errs != nil { + for _, err := range errs.(sarama.ProducerErrors) { + log.WithError(err).WithField("message_bytes", err.Msg.Value.Length()).Error("Write to kafka failed") + } + } + return errs } // InputRoomEvents implements api.RoomserverInternalAPI diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index c055289c9..d62621c24 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -111,11 +111,11 @@ func (r *Inputer) processRoomEvent( // if storing this event results in it being redacted then do so. if !isRejected && redactedEventID == event.EventID() { - r, rerr := eventutil.RedactEvent(redactionEvent, &event) + r, rerr := eventutil.RedactEvent(redactionEvent, event) if rerr != nil { return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr) } - event = *r + event = r } // For outliers we can stop after we've stored the event itself as it @@ -215,7 +215,7 @@ func (r *Inputer) calculateAndSetState( input *api.InputRoomEvent, roomInfo types.RoomInfo, stateAtEvent *types.StateAtEvent, - event gomatrixserverlib.Event, + event *gomatrixserverlib.Event, isRejected bool, ) error { var err error diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 2bf6b9f8a..3608ef4b0 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -50,7 +50,7 @@ func (r *Inputer) updateLatestEvents( ctx context.Context, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, - event gomatrixserverlib.Event, + event *gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, rewritesState bool, @@ -92,7 +92,7 @@ type latestEventsUpdater struct { updater *shared.LatestEventsUpdater roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent - event gomatrixserverlib.Event + event *gomatrixserverlib.Event transactionID *api.TransactionID rewritesState bool // Which server to send this event as. @@ -140,7 +140,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // Work out what the latest events are. This will include the new // event if it is not already referenced. extremitiesChanged, err := u.calculateLatest( - oldLatest, &u.event, + oldLatest, u.event, types.StateAtEventAndReference{ EventReference: u.event.EventReference(), StateAtEvent: u.stateAtEvent, @@ -373,7 +373,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) // extraEventsForIDs returns the full events for the event IDs given, but does not include the current event being // updated. -func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { +func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { var extraEventIDs []string for _, e := range eventIDs { if e == u.event.EventID() { @@ -388,7 +388,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro if err != nil { return nil, err } - var h []gomatrixserverlib.HeaderedEvent + var h []*gomatrixserverlib.HeaderedEvent for _, e := range extraEvents { h = append(h, e.Headered(roomVersion)) } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 8befcd647..692d8147a 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -62,13 +62,13 @@ func (r *Inputer) updateMemberships( if change.removedEventNID != 0 { ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID) if ev != nil { - re = &ev.Event + re = ev.Event } } if change.addedEventNID != 0 { ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID) if ev != nil { - ae = &ev.Event + ae = ev.Event } } if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index d90ac8fcc..eb47ac218 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -83,7 +83,7 @@ func (r *Backfiller) PerformBackfill( } // Retrieve events from the list that was filled previously. - var loadedEvents []gomatrixserverlib.Event + var loadedEvents []*gomatrixserverlib.Event loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs) if err != nil { return err @@ -211,10 +211,10 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom } } - var newEvents []gomatrixserverlib.HeaderedEvent + var newEvents []*gomatrixserverlib.HeaderedEvent for _, ev := range missingMap { if ev != nil { - newEvents = append(newEvents, *ev) + newEvents = append(newEvents, ev) } } util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) @@ -232,7 +232,7 @@ type backfillRequester struct { // per-request state servers []gomatrixserverlib.ServerName eventIDToBeforeStateIDs map[string][]string - eventIDMap map[string]gomatrixserverlib.Event + eventIDMap map[string]*gomatrixserverlib.Event } func newBackfillRequester( @@ -248,13 +248,13 @@ func newBackfillRequester( fsAPI: fsAPI, thisServer: thisServer, eventIDToBeforeStateIDs: make(map[string][]string), - eventIDMap: make(map[string]gomatrixserverlib.Event), + eventIDMap: make(map[string]*gomatrixserverlib.Event), bwExtrems: bwExtrems, preferServer: preferServer, } } -func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.HeaderedEvent) ([]string, error) { +func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent *gomatrixserverlib.HeaderedEvent) ([]string, error) { b.eventIDMap[targetEvent.EventID()] = targetEvent.Unwrap() if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { return ids, nil @@ -305,7 +305,7 @@ FederationHit: return nil, lastErr } -func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.Event, prevEventStateIDs []string) []string { +func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent *gomatrixserverlib.Event, prevEventStateIDs []string) []string { newStateIDs := prevEventStateIDs[:] if prevEvent.StateKey() == nil { // state is the same as the previous event @@ -343,7 +343,7 @@ func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrix } func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { + event *gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { // try to fetch the events from the database first events, err := b.ProvideEvents(roomVer, eventIDs) @@ -355,7 +355,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr if len(events) == len(eventIDs) { result := make(map[string]*gomatrixserverlib.Event) for i := range events { - result[events[i].EventID()] = &events[i] + result[events[i].EventID()] = events[i] b.eventIDMap[events[i].EventID()] = events[i] } return result, nil @@ -372,7 +372,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr return nil, err } for eventID, ev := range result { - b.eventIDMap[eventID] = *ev + b.eventIDMap[eventID] = ev } return result, nil } @@ -426,7 +426,7 @@ FindSuccessor: } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries) + memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") return nil @@ -476,7 +476,7 @@ func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverl return tx, err } -func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) { +func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) { ctx := context.Background() nidMap, err := b.db.EventNIDs(ctx, eventIDs) if err != nil { @@ -494,18 +494,19 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err } - events := make([]gomatrixserverlib.Event, len(eventsWithNids)) + events := make([]*gomatrixserverlib.Event, len(eventsWithNids)) for i := range eventsWithNids { events[i] = eventsWithNids[i].Event } return events, nil } -// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility. +// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if our server can read the room history // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // pull all events and then filter by that table. func joinEventsFromHistoryVisibility( - ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) { + ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry, + thisServer gomatrixserverlib.ServerName) ([]types.Event, error) { var eventNIDs []types.EventNID for _, entry := range stateEntries { @@ -521,13 +522,15 @@ func joinEventsFromHistoryVisibility( if err != nil { return nil, err } - events := make([]gomatrixserverlib.Event, len(stateEvents)) + events := make([]*gomatrixserverlib.Event, len(stateEvents)) for i := range stateEvents { events[i] = stateEvents[i].Event } - visibility := auth.HistoryVisibilityForRoom(events) - if visibility != "shared" { - logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility) + + // Can we see events in the room? + canSeeEvents := auth.IsServerAllowed(thisServer, true, events) + if !canSeeEvents { + logrus.Infof("ServersAtEvent history not visible to us: %s", auth.HistoryVisibilityForRoom(events)) return nil, nil } // get joined members @@ -542,7 +545,7 @@ func joinEventsFromHistoryVisibility( return db.Events(ctx, joinEventNIDs) } -func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { +func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { var roomNID types.RoomNID backfilledEventMap := make(map[string]types.Event) for j, ev := range events { @@ -570,7 +573,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse // redacted, which we don't care about since we aren't returning it in this backfill. if redactedEventID == ev.EventID() { eventToRedact := ev.Unwrap() - redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact) + redactedEvent, err := eventutil.RedactEvent(redactionEvent, eventToRedact) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") continue diff --git a/roomserver/internal/perform/perform_forget.go b/roomserver/internal/perform/perform_forget.go new file mode 100644 index 000000000..e970d9a88 --- /dev/null +++ b/roomserver/internal/perform/perform_forget.go @@ -0,0 +1,35 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform + +import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" +) + +type Forgetter struct { + DB storage.Database +} + +// PerformForget implements api.RoomServerQueryAPI +func (f *Forgetter) PerformForget( + ctx context.Context, + request *api.PerformForgetRequest, + response *api.PerformForgetResponse, +) error { + return f.DB.ForgetRoom(ctx, request.UserID, request.RoomID, true) +} diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 734e73d43..0fb6ddd44 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -86,7 +86,7 @@ func (r *Inviter) PerformInvite( var isAlreadyJoined bool if info != nil { - _, isAlreadyJoined, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) + _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) if err != nil { return nil, fmt.Errorf("r.DB.GetMembership: %w", err) } @@ -198,7 +198,7 @@ func (r *Inviter) PerformInvite( } unwrapped := event.Unwrap() - outputUpdates, err := helpers.UpdateToInviteMembership(updater, &unwrapped, nil, req.Event.RoomVersion) + outputUpdates, err := helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion) if err != nil { return nil, fmt.Errorf("updateToInviteMembership: %w", err) } @@ -248,11 +248,11 @@ func buildInviteStrippedState( return nil, err } inviteState := []gomatrixserverlib.InviteV2StrippedState{ - gomatrixserverlib.NewInviteV2StrippedState(&input.Event.Event), + gomatrixserverlib.NewInviteV2StrippedState(input.Event.Event), } stateEvents = append(stateEvents, types.Event{Event: input.Event.Unwrap()}) for _, event := range stateEvents { - inviteState = append(inviteState, gomatrixserverlib.NewInviteV2StrippedState(&event.Event)) + inviteState = append(inviteState, gomatrixserverlib.NewInviteV2StrippedState(event.Event)) } return inviteState, nil } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 56ae6d0b1..f3745a7ff 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -47,7 +47,7 @@ func (r *Joiner) PerformJoin( req *api.PerformJoinRequest, res *api.PerformJoinResponse, ) { - roomID, err := r.performJoin(ctx, req) + roomID, joinedVia, err := r.performJoin(ctx, req) if err != nil { perr, ok := err.(*api.PerformError) if ok { @@ -59,21 +59,22 @@ func (r *Joiner) PerformJoin( } } res.RoomID = roomID + res.JoinedVia = joinedVia } func (r *Joiner) performJoin( ctx context.Context, req *api.PerformJoinRequest, -) (string, error) { +) (string, gomatrixserverlib.ServerName, error) { _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), } } if domain != r.Cfg.Matrix.ServerName { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), } @@ -84,7 +85,7 @@ func (r *Joiner) performJoin( if strings.HasPrefix(req.RoomIDOrAlias, "#") { return r.performJoinRoomByAlias(ctx, req) } - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias), } @@ -93,11 +94,11 @@ func (r *Joiner) performJoin( func (r *Joiner) performJoinRoomByAlias( ctx context.Context, req *api.PerformJoinRequest, -) (string, error) { +) (string, gomatrixserverlib.ServerName, error) { // Get the domain part of the room alias. _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) if err != nil { - return "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias) + return "", "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias) } req.ServerNames = append(req.ServerNames, domain) @@ -115,7 +116,7 @@ func (r *Joiner) performJoinRoomByAlias( err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes) if err != nil { logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias) - return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err) + return "", "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err) } roomID = dirRes.RoomID req.ServerNames = append(req.ServerNames, dirRes.ServerNames...) @@ -123,13 +124,13 @@ func (r *Joiner) performJoinRoomByAlias( // Otherwise, look up if we know this room alias locally. roomID, err = r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias) if err != nil { - return "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err) + return "", "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err) } } // If the room ID is empty then we failed to look up the alias. if roomID == "" { - return "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias) + return "", "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias) } // If we do, then pluck out the room ID and continue the join. @@ -142,11 +143,11 @@ func (r *Joiner) performJoinRoomByAlias( func (r *Joiner) performJoinRoomByID( ctx context.Context, req *api.PerformJoinRequest, -) (string, error) { +) (string, gomatrixserverlib.ServerName, error) { // Get the domain part of the room ID. _, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias) if err != nil { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err), } @@ -169,7 +170,7 @@ func (r *Joiner) performJoinRoomByID( Redacts: "", } if err = eb.SetUnsigned(struct{}{}); err != nil { - return "", fmt.Errorf("eb.SetUnsigned: %w", err) + return "", "", fmt.Errorf("eb.SetUnsigned: %w", err) } // It is possible for the request to include some "content" for the @@ -180,7 +181,7 @@ func (r *Joiner) performJoinRoomByID( } req.Content["membership"] = gomatrixserverlib.Join if err = eb.SetContent(req.Content); err != nil { - return "", fmt.Errorf("eb.SetContent: %w", err) + return "", "", fmt.Errorf("eb.SetContent: %w", err) } // Force a federated join if we aren't in the room and we've been @@ -194,7 +195,7 @@ func (r *Joiner) performJoinRoomByID( if err == nil && isInvitePending { _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) if ierr != nil { - return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + return "", "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } // If we were invited by someone from another server then we can @@ -206,8 +207,10 @@ func (r *Joiner) performJoinRoomByID( } // If we should do a forced federated join then do that. + var joinedVia gomatrixserverlib.ServerName if forceFederatedJoin { - return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) + return req.RoomIDOrAlias, joinedVia, err } // Try to construct an actual join event from the template. @@ -249,7 +252,7 @@ func (r *Joiner) performJoinRoomByID( inputRes := api.InputRoomEventsResponse{} r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) if err = inputRes.Err(); err != nil { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorNotAllowed, Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err), } @@ -265,7 +268,7 @@ func (r *Joiner) performJoinRoomByID( // Otherwise we'll try a federated join as normal, since it's quite // possible that the room still exists on other servers. if len(req.ServerNames) == 0 { - return "", &api.PerformError{ + return "", "", &api.PerformError{ Code: api.PerformErrorNoRoom, Msg: fmt.Sprintf("Room ID %q does not exist", req.RoomIDOrAlias), } @@ -273,24 +276,25 @@ func (r *Joiner) performJoinRoomByID( } // Perform a federated room join. - return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + joinedVia, err = r.performFederatedJoinRoomByID(ctx, req) + return req.RoomIDOrAlias, joinedVia, err default: // Something else went wrong. - return "", fmt.Errorf("Error joining local room: %q", err) + return "", "", fmt.Errorf("Error joining local room: %q", err) } // By this point, if req.RoomIDOrAlias contained an alias, then // it will have been overwritten with a room ID by performJoinRoomByAlias. // We should now include this in the response so that the CS API can // return the right room ID. - return req.RoomIDOrAlias, nil + return req.RoomIDOrAlias, r.Cfg.Matrix.ServerName, nil } func (r *Joiner) performFederatedJoinRoomByID( ctx context.Context, req *api.PerformJoinRequest, -) error { +) (gomatrixserverlib.ServerName, error) { // Try joining by all of the supplied server names. fedReq := fsAPI.PerformJoinRequest{ RoomID: req.RoomIDOrAlias, // the room ID to try and join @@ -301,13 +305,13 @@ func (r *Joiner) performFederatedJoinRoomByID( fedRes := fsAPI.PerformJoinResponse{} r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) if fedRes.LastError != nil { - return &api.PerformError{ + return "", &api.PerformError{ Code: api.PerformErrRemote, Msg: fedRes.LastError.Message, RemoteCode: fedRes.LastError.Code, } } - return nil + return fedRes.JoinedVia, nil } func buildEvent( diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 8123353ce..aef6fa3f3 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -204,11 +204,13 @@ func (r *Queryer) QueryMembershipForUser( return fmt.Errorf("QueryMembershipForUser: unknown room %s", request.RoomID) } - membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) if err != nil { return err } + response.IsRoomForgotten = isRoomforgotten + if membershipEventNID == 0 { response.HasBeenInRoom = false return nil @@ -241,11 +243,13 @@ func (r *Queryer) QueryMembershipsForRoom( return err } - membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) if err != nil { return err } + response.IsRoomForgotten = isRoomforgotten + if membershipEventNID == 0 { response.HasBeenInRoom = false response.JoinEvents = nil @@ -412,7 +416,7 @@ func (r *Queryer) QueryMissingEvents( return err } - response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) + response.Events = make([]*gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) for _, event := range loadedEvents { if !eventsToFilter[event.EventID()] { roomVersion, verr := r.roomVersion(event.RoomID()) @@ -483,7 +487,7 @@ func (r *Queryer) QueryStateAndAuthChain( return err } -func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) { +func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) { roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { @@ -514,13 +518,13 @@ type eventsFromIDs func(context.Context, []string) ([]types.Event, error) // given events. Will *not* error if we don't have all auth events. func GetAuthChain( ctx context.Context, fn eventsFromIDs, authEventIDs []string, -) ([]gomatrixserverlib.Event, error) { +) ([]*gomatrixserverlib.Event, error) { // List of event IDs to fetch. On each pass, these events will be requested // from the database and the `eventsToFetch` will be updated with any new // events that we have learned about and need to find. When `eventsToFetch` // is eventually empty, we should have reached the end of the chain. eventsToFetch := authEventIDs - authEventsMap := make(map[string]gomatrixserverlib.Event) + authEventsMap := make(map[string]*gomatrixserverlib.Event) for len(eventsToFetch) > 0 { // Try to retrieve the events from the database. @@ -551,7 +555,7 @@ func GetAuthChain( // We've now retrieved all of the events we can. Flatten them down into an // array and return them. - var authEvents []gomatrixserverlib.Event + var authEvents []*gomatrixserverlib.Event for _, event := range authEventsMap { authEvents = append(authEvents, event) } diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index b4cb99b85..4e761d8ec 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -26,12 +26,12 @@ import ( // used to implement RoomserverInternalAPIEventDB to test getAuthChain type getEventDB struct { - eventMap map[string]gomatrixserverlib.Event + eventMap map[string]*gomatrixserverlib.Event } func createEventDB() *getEventDB { return &getEventDB{ - eventMap: make(map[string]gomatrixserverlib.Event), + eventMap: make(map[string]*gomatrixserverlib.Event), } } diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 756a10d18..96fe84536 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -32,6 +32,7 @@ const ( RoomserverPerformBackfillPath = "/roomserver/performBackfill" RoomserverPerformPublishPath = "/roomserver/performPublish" RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" + RoomserverPerformForgetPath = "/roomserver/performForget" // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" @@ -506,3 +507,12 @@ func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformForgetPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 566e4eb48..1649060c2 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -264,6 +264,20 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle( + RoomserverPerformForgetPath, + httputil.MakeInternalAPI("PerformForget", func(req *http.Request) util.JSONResponse { + var request api.PerformForgetRequest + var response api.PerformForgetResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.PerformForget(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle( RoomserverQueryRoomVersionCapabilitiesPath, httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 41cbd2637..7638f3f3e 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -94,7 +94,7 @@ type fledglingEvent struct { RoomID string } -func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, events []fledglingEvent) (result []gomatrixserverlib.HeaderedEvent) { +func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, events []fledglingEvent) (result []*gomatrixserverlib.HeaderedEvent) { t.Helper() depth := int64(1) seed := make([]byte, ed25519.SeedSize) // zero seed @@ -143,16 +143,15 @@ func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, event return } -func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { +func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []*gomatrixserverlib.HeaderedEvent { t.Helper() - hs := make([]gomatrixserverlib.HeaderedEvent, len(events)) + hs := make([]*gomatrixserverlib.HeaderedEvent, len(events)) for i := range events { e, err := gomatrixserverlib.NewEventFromTrustedJSON(events[i], false, ver) if err != nil { t.Fatalf("cannot load test data: " + err.Error()) } - h := e.Headered(ver) - hs[i] = h + hs[i] = e.Headered(ver) } return hs } @@ -187,7 +186,7 @@ func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyPro ), dp } -func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) { +func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []*gomatrixserverlib.HeaderedEvent) { t.Helper() rsAPI, dp := mustCreateRoomserverAPI(t) hevents := mustLoadRawEvents(t, ver, events) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index d23f14c84..87715af42 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -522,7 +522,7 @@ func init() { // Returns a numeric ID for the snapshot of the state before the event. func (v StateResolution) CalculateAndStoreStateBeforeEvent( ctx context.Context, - event gomatrixserverlib.Event, + event *gomatrixserverlib.Event, isRejected bool, ) (types.StateSnapshotNID, error) { // Load the state at the prev events. @@ -689,17 +689,17 @@ func (v StateResolution) calculateStateAfterManyEvents( // TODO: Some of this can possibly be deduplicated func ResolveConflictsAdhoc( version gomatrixserverlib.RoomVersion, - events []gomatrixserverlib.Event, - authEvents []gomatrixserverlib.Event, -) ([]gomatrixserverlib.Event, error) { + events []*gomatrixserverlib.Event, + authEvents []*gomatrixserverlib.Event, +) ([]*gomatrixserverlib.Event, error) { type stateKeyTuple struct { Type string StateKey string } // Prepare our data structures. - eventMap := make(map[stateKeyTuple][]gomatrixserverlib.Event) - var conflicted, notConflicted, resolved []gomatrixserverlib.Event + eventMap := make(map[stateKeyTuple][]*gomatrixserverlib.Event) + var conflicted, notConflicted, resolved []*gomatrixserverlib.Event // Run through all of the events that we were given and sort them // into a map, sorted by (event_type, state_key) tuple. This means @@ -868,15 +868,15 @@ func (v StateResolution) resolveConflictsV2( // For each conflicted event, we will add a new set of auth events. Auth // events may be duplicated across these sets but that's OK. - authSets := make(map[string][]gomatrixserverlib.Event) - var authEvents []gomatrixserverlib.Event - var authDifference []gomatrixserverlib.Event + authSets := make(map[string][]*gomatrixserverlib.Event) + var authEvents []*gomatrixserverlib.Event + var authDifference []*gomatrixserverlib.Event // For each conflicted event, let's try and get the needed auth events. for _, conflictedEvent := range conflictedEvents { // Work out which auth events we need to load. key := conflictedEvent.EventID() - needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{conflictedEvent}) + needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{conflictedEvent}) // Find the numeric IDs for the necessary state keys. var neededStateKeys []string @@ -909,7 +909,7 @@ func (v StateResolution) resolveConflictsV2( // This function helps us to work out whether an event exists in one of the // auth sets. - isInAuthList := func(k string, event gomatrixserverlib.Event) bool { + isInAuthList := func(k string, event *gomatrixserverlib.Event) bool { for _, e := range authSets[k] { if e.EventID() == event.EventID() { return true @@ -919,7 +919,7 @@ func (v StateResolution) resolveConflictsV2( } // This function works out if an event exists in all of the auth sets. - isInAllAuthLists := func(event gomatrixserverlib.Event) bool { + isInAllAuthLists := func(event *gomatrixserverlib.Event) bool { found := true for k := range authSets { found = found && isInAuthList(k, event) @@ -1006,7 +1006,7 @@ func (v StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.Ev // Returns an error if there was a problem talking to the database. func (v StateResolution) loadStateEvents( ctx context.Context, entries []types.StateEntry, -) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { +) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { eventNIDs := make([]types.EventNID, len(entries)) for i := range entries { eventNIDs[i] = entries[i].EventNID @@ -1016,7 +1016,7 @@ func (v StateResolution) loadStateEvents( return nil, nil, err } eventIDMap := map[string]types.StateEntry{} - result := make([]gomatrixserverlib.Event, len(entries)) + result := make([]*gomatrixserverlib.Event, len(entries)) for i := range entries { event, ok := eventMap(events).lookup(entries[i].EventNID) if !ok { diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 10a380e85..d2b0e75c9 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -69,7 +69,7 @@ type Database interface { SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. StoreEvent( - ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, + ctx context.Context, event *gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, isRejected bool, ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) // Look up the state entries for a list of string event IDs @@ -126,7 +126,7 @@ type Database interface { // in this room, along a boolean set to true if the user is still in this room, // false if not. // Returns an error if there was a problem talking to the database. - GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) + GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) // Lookup the membership event numeric IDs for all user that are or have // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. @@ -158,4 +158,6 @@ type Database interface { GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) + // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room + ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error } diff --git a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go new file mode 100644 index 000000000..733f0fa14 --- /dev/null +++ b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go @@ -0,0 +1,47 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func LoadAddForgottenColumn(m *sqlutil.Migrations) { + m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func UpAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE roomserver_membership ADD COLUMN IF NOT EXISTS forgotten BOOLEAN NOT NULL DEFAULT false;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE roomserver_membership DROP COLUMN IF EXISTS forgotten;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 5164f654f..e392a4fbb 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -60,13 +60,15 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( -- a federated one. This is an optimisation for resetting state on federated -- room joins. target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT FALSE, UNIQUE (room_nid, target_nid) ); ` var selectJoinedUsersSetForRoomsSQL = "" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + - " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + " GROUP BY target_nid" // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE @@ -76,37 +78,41 @@ const insertMembershipSQL = "" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + - "SELECT membership_nid, event_nid FROM roomserver_membership" + + "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1 AND membership_nid = $2" + " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" const selectLocalMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" + - " AND target_local = true" + " AND target_local = true and forgotten = false" const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1" + " WHERE room_nid = $1 and forgotten = false" const selectLocalMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" + - " AND target_local = true" + " AND target_local = true and forgotten = false" const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" const updateMembershipSQL = "" + - "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5, forgotten = $6" + + " WHERE room_nid = $1 AND target_nid = $2" + +const updateMembershipForgetRoom = "" + + "UPDATE roomserver_membership SET forgotten = $3" + " WHERE room_nid = $1 AND target_nid = $2" const selectRoomsWithMembershipSQL = "" + - "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // joined to. Since this information is used to populate the user directory, we will @@ -130,6 +136,7 @@ type membershipStatements struct { selectRoomsWithMembershipStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt + updateMembershipForgetRoomStmt *sql.Stmt } func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -151,9 +158,15 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, + {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, }.Prepare(db) } +func (s *membershipStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(membershipSchema) + return err +} + func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, @@ -177,10 +190,10 @@ func (s *membershipStatements) SelectMembershipForUpdate( func (s *membershipStatements) SelectMembershipFromRoomAndTarget( ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership tables.MembershipState, err error) { +) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, - ).Scan(&membership, &eventNID) + ).Scan(&membership, &eventNID, &forgotten) return } @@ -238,12 +251,11 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) UpdateMembership( ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership tables.MembershipState, - eventNID types.EventNID, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, + eventNID types.EventNID, forgotten bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( - ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, + ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, forgotten, ) return err } @@ -305,3 +317,14 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } return result, rows.Err() } + +func (s *membershipStatements) UpdateForgetMembership( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + forget bool, +) error { + _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( + ctx, roomNID, targetUserNID, forget, + ) + return err +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 02ff072d7..37aca647c 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -18,12 +18,13 @@ package postgres import ( "database/sql" + // Import the postgres database driver. + _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" - - // Import the postgres database driver. - _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/shared" ) @@ -33,7 +34,6 @@ type Database struct { } // Open a postgres database. -// nolint: gocyclo func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database var db *sql.DB @@ -41,61 +41,82 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) if db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + ms := membershipStatements{} + if err := ms.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() + deltas.LoadAddForgottenColumn(m) + if err := m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + if err := d.prepare(db, cache); err != nil { + return nil, err + } + + return &d, nil +} + +// nolint: gocyclo +func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err error) { eventStateKeys, err := NewPostgresEventStateKeysTable(db) if err != nil { - return nil, err + return err } eventTypes, err := NewPostgresEventTypesTable(db) if err != nil { - return nil, err + return err } eventJSON, err := NewPostgresEventJSONTable(db) if err != nil { - return nil, err + return err } events, err := NewPostgresEventsTable(db) if err != nil { - return nil, err + return err } rooms, err := NewPostgresRoomsTable(db) if err != nil { - return nil, err + return err } transactions, err := NewPostgresTransactionsTable(db) if err != nil { - return nil, err + return err } stateBlock, err := NewPostgresStateBlockTable(db) if err != nil { - return nil, err + return err } stateSnapshot, err := NewPostgresStateSnapshotTable(db) if err != nil { - return nil, err + return err } roomAliases, err := NewPostgresRoomAliasesTable(db) if err != nil { - return nil, err + return err } prevEvents, err := NewPostgresPreviousEventsTable(db) if err != nil { - return nil, err + return err } invites, err := NewPostgresInvitesTable(db) if err != nil { - return nil, err + return err } membership, err := NewPostgresMembershipTable(db) if err != nil { - return nil, err + return err } published, err := NewPostgresPublishedTable(db) if err != nil { - return nil, err + return err } redactions, err := NewPostgresRedactionsTable(db) if err != nil { - return nil, err + return err } d.Database = shared.Database{ DB: db, @@ -116,5 +137,5 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) PublishedTable: published, RedactionsTable: redactions, } - return &d, nil + return nil } diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 7abddd018..57f3a520a 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -101,9 +101,7 @@ func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) } if u.membership != tables.MembershipStateInvite { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, - ); err != nil { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -139,10 +137,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } if u.membership != tables.MembershipStateJoin || isUpdate { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - tables.MembershipStateJoin, nIDs[eventID], - ); err != nil { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -176,10 +171,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } if u.membership != tables.MembershipStateLeaveOrBan { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - tables.MembershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index aec15ab22..2548980d4 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -258,30 +258,28 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { }) } -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { +func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { var requestSenderUserNID types.EventStateKeyNID err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID) return err }) if err != nil { - return 0, false, fmt.Errorf("d.assignStateKeyNID: %w", err) + return 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err) } - senderMembershipEventNID, senderMembership, err := + senderMembershipEventNID, senderMembership, isRoomforgotten, err := d.MembershipTable.SelectMembershipFromRoomAndTarget( ctx, roomNID, requestSenderUserNID, ) if err == sql.ErrNoRows { // The user has never been a member of that room - return 0, false, nil + return 0, false, false, nil } else if err != nil { return } - return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, nil + return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil } func (d *Database) GetMembershipEventNIDsForRoom( @@ -390,7 +388,7 @@ func (d *Database) GetLatestEventsForUpdate( // nolint:gocyclo func (d *Database) StoreEvent( - ctx context.Context, event gomatrixserverlib.Event, + ctx context.Context, event *gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, isRejected bool, ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { var ( @@ -613,7 +611,7 @@ func (d *Database) assignStateKeyNID( return eventStateKeyNID, err } -func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( +func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) ( gomatrixserverlib.RoomVersion, error, ) { var err error @@ -653,7 +651,7 @@ func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) ( // Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction. // nolint:gocyclo func (d *Database) handleRedactions( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event, + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event, ) (*gomatrixserverlib.Event, string, error) { var err error isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil @@ -705,12 +703,12 @@ func (d *Database) handleRedactions( err = fmt.Errorf("d.RedactionsTable.MarkRedactionValidated: %w", err) } - return &redactionEvent.Event, redactedEvent.EventID(), err + return redactionEvent.Event, redactedEvent.EventID(), err } // loadRedactionPair returns both the redaction event and the redacted event, else nil. func (d *Database) loadRedactionPair( - ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event gomatrixserverlib.Event, + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event, ) (*types.Event, *types.Event, bool, error) { var redactionEvent, redactedEvent *types.Event var info *tables.RedactionInfo @@ -816,8 +814,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s if err != nil { return nil, err } - h := ev.Headered(roomInfo.RoomVersion) - return &h, nil + return ev.Headered(roomInfo.RoomVersion), nil } } @@ -936,12 +933,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event NID %v : %w", events[i].EventNID, err) } - hev := ev.Headered(roomVer) result[i] = tables.StrippedEvent{ EventType: ev.Type(), RoomID: ev.RoomID(), StateKey: *ev.StateKey(), - ContentValue: tables.ExtractContentValue(&hev), + ContentValue: tables.ExtractContentValue(ev.Headered(roomVer)), } } @@ -992,6 +988,25 @@ func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDs(ctx) } +// ForgetRoom sets a users room to forgotten +func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID}) + if err != nil { + return err + } + if len(roomNIDs) > 1 { + return fmt.Errorf("expected one room, got %d", len(roomNIDs)) + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + if err != nil { + return err + } + + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.MembershipTable.UpdateForgetMembership(ctx, nil, roomNIDs[0], stateKeyNID, forget) + }) +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go new file mode 100644 index 000000000..33fe9e2a9 --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go @@ -0,0 +1,82 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func LoadAddForgottenColumn(m *sqlutil.Migrations) { + m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func UpAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; +CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT false, + UNIQUE (room_nid, target_nid) + ); +INSERT + INTO roomserver_membership ( + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + ) SELECT + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + FROM roomserver_membership_tmp +; +DROP TABLE roomserver_membership_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; +CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, + UNIQUE (room_nid, target_nid) + ); +INSERT + INTO roomserver_membership ( + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + ) SELECT + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + FROM roomserver_membership_tmp +; +DROP TABLE roomserver_membership_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index bb1ab39aa..d716ced04 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -36,13 +36,15 @@ const membershipSchema = ` membership_nid INTEGER NOT NULL DEFAULT 1, event_nid INTEGER NOT NULL DEFAULT 0, target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` var selectJoinedUsersSetForRoomsSQL = "" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + - " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + " GROUP BY target_nid" // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE @@ -52,37 +54,41 @@ const insertMembershipSQL = "" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + - "SELECT membership_nid, event_nid FROM roomserver_membership" + + "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1 AND membership_nid = $2" + " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" const selectLocalMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" + - " AND target_local = true" + " AND target_local = true and forgotten = false" const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1" + " WHERE room_nid = $1 and forgotten = false" const selectLocalMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" + - " AND target_local = true" + " AND target_local = true and forgotten = false" const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" const updateMembershipSQL = "" + - "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" + - " WHERE room_nid = $4 AND target_nid = $5" + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" + + " WHERE room_nid = $5 AND target_nid = $6" + +const updateMembershipForgetRoom = "" + + "UPDATE roomserver_membership SET forgotten = $1" + + " WHERE room_nid = $2 AND target_nid = $3" const selectRoomsWithMembershipSQL = "" + - "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // joined to. Since this information is used to populate the user directory, we will @@ -106,16 +112,13 @@ type membershipStatements struct { selectRoomsWithMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt + updateMembershipForgetRoomStmt *sql.Stmt } func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { s := &membershipStatements{ db: db, } - _, err := db.Exec(membershipSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, @@ -128,9 +131,15 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, + {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, }.Prepare(db) } +func (s *membershipStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(membershipSchema) + return err +} + func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, @@ -155,10 +164,10 @@ func (s *membershipStatements) SelectMembershipForUpdate( func (s *membershipStatements) SelectMembershipFromRoomAndTarget( ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership tables.MembershipState, err error) { +) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, - ).Scan(&membership, &eventNID) + ).Scan(&membership, &eventNID, &forgotten) return } @@ -216,13 +225,12 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership tables.MembershipState, - eventNID types.EventNID, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, + eventNID types.EventNID, forgotten bool, ) error { stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) _, err := stmt.ExecContext( - ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, + ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID, ) return err } @@ -285,3 +293,14 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } return result, rows.Err() } + +func (s *membershipStatements) UpdateForgetMembership( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + forget bool, +) error { + _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( + ctx, forget, roomNID, targetUserNID, + ) + return err +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 6d9b860f5..b36930206 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -19,127 +19,138 @@ import ( "context" "database/sql" + _ "github.com/mattn/go-sqlite3" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" - "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" - _ "github.com/mattn/go-sqlite3" ) // A Database is used to store room events and stream offsets. type Database struct { shared.Database - events tables.Events - eventJSON tables.EventJSON - eventTypes tables.EventTypes - eventStateKeys tables.EventStateKeys - rooms tables.Rooms - transactions tables.Transactions - prevEvents tables.PreviousEvents - invites tables.Invites - membership tables.Membership - db *sql.DB - writer sqlutil.Writer } // Open a sqlite database. -// nolint: gocyclo func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database + var db *sql.DB var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { + if db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - d.writer = sqlutil.NewExclusiveWriter() - //d.db.Exec("PRAGMA journal_mode=WAL;") - //d.db.Exec("PRAGMA read_uncommitted = true;") + + //db.Exec("PRAGMA journal_mode=WAL;") + //db.Exec("PRAGMA read_uncommitted = true;") // FIXME: We are leaking connections somewhere. Setting this to 2 will eventually // cause the roomserver to be unresponsive to new events because something will // acquire the global mutex and never unlock it because it is waiting for a connection // which it will never obtain. - d.db.SetMaxOpenConns(20) + db.SetMaxOpenConns(20) - d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) - if err != nil { + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + ms := membershipStatements{} + if err := ms.execSchema(db); err != nil { return nil, err } - d.eventTypes, err = NewSqliteEventTypesTable(d.db) - if err != nil { + m := sqlutil.NewMigrations() + deltas.LoadAddForgottenColumn(m) + if err := m.RunDeltas(db, dbProperties); err != nil { return nil, err } - d.eventJSON, err = NewSqliteEventJSONTable(d.db) - if err != nil { + if err := d.prepare(db, cache); err != nil { return nil, err } - d.events, err = NewSqliteEventsTable(d.db) + + return &d, nil +} + +// nolint: gocyclo +func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { + var err error + eventStateKeys, err := NewSqliteEventStateKeysTable(db) if err != nil { - return nil, err + return err } - d.rooms, err = NewSqliteRoomsTable(d.db) + eventTypes, err := NewSqliteEventTypesTable(db) if err != nil { - return nil, err + return err } - d.transactions, err = NewSqliteTransactionsTable(d.db) + eventJSON, err := NewSqliteEventJSONTable(db) if err != nil { - return nil, err + return err } - stateBlock, err := NewSqliteStateBlockTable(d.db) + events, err := NewSqliteEventsTable(db) if err != nil { - return nil, err + return err } - stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) + rooms, err := NewSqliteRoomsTable(db) if err != nil { - return nil, err + return err } - d.prevEvents, err = NewSqlitePrevEventsTable(d.db) + transactions, err := NewSqliteTransactionsTable(db) if err != nil { - return nil, err + return err } - roomAliases, err := NewSqliteRoomAliasesTable(d.db) + stateBlock, err := NewSqliteStateBlockTable(db) if err != nil { - return nil, err + return err } - d.invites, err = NewSqliteInvitesTable(d.db) + stateSnapshot, err := NewSqliteStateSnapshotTable(db) if err != nil { - return nil, err + return err } - d.membership, err = NewSqliteMembershipTable(d.db) + prevEvents, err := NewSqlitePrevEventsTable(db) if err != nil { - return nil, err + return err } - published, err := NewSqlitePublishedTable(d.db) + roomAliases, err := NewSqliteRoomAliasesTable(db) if err != nil { - return nil, err + return err } - redactions, err := NewSqliteRedactionsTable(d.db) + invites, err := NewSqliteInvitesTable(db) if err != nil { - return nil, err + return err + } + membership, err := NewSqliteMembershipTable(db) + if err != nil { + return err + } + published, err := NewSqlitePublishedTable(db) + if err != nil { + return err + } + redactions, err := NewSqliteRedactionsTable(db) + if err != nil { + return err } d.Database = shared.Database{ - DB: d.db, + DB: db, Cache: cache, - Writer: d.writer, - EventsTable: d.events, - EventTypesTable: d.eventTypes, - EventStateKeysTable: d.eventStateKeys, - EventJSONTable: d.eventJSON, - RoomsTable: d.rooms, - TransactionsTable: d.transactions, + Writer: sqlutil.NewExclusiveWriter(), + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + RoomsTable: rooms, + TransactionsTable: transactions, StateBlockTable: stateBlock, StateSnapshotTable: stateSnapshot, - PrevEventsTable: d.prevEvents, + PrevEventsTable: prevEvents, RoomAliasesTable: roomAliases, - InvitesTable: d.invites, - MembershipTable: d.membership, + InvitesTable: invites, + MembershipTable: membership, PublishedTable: published, RedactionsTable: redactions, GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, } - return &d, nil + return nil } func (d *Database) SupportsConcurrentRoomInputs() bool { diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index eba878ba5..d73445846 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -123,15 +123,16 @@ const ( type Membership interface { InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) - SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, error) + SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) - UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error + UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) + UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error } type Published interface { diff --git a/roomserver/types/types.go b/roomserver/types/types.go index c0fcef65e..e866f6cbe 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -126,7 +126,7 @@ type StateAtEventAndReference struct { // It is when performing bulk event lookup in the database. type Event struct { EventNID EventNID - gomatrixserverlib.Event + *gomatrixserverlib.Event } const ( diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go new file mode 100644 index 000000000..c5d17414a --- /dev/null +++ b/syncapi/consumers/eduserver_receipts.go @@ -0,0 +1,94 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/syncapi/types" + + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/sync" + log "github.com/sirupsen/logrus" +) + +// OutputReceiptEventConsumer consumes events that originated in the EDU server. +type OutputReceiptEventConsumer struct { + receiptConsumer *internal.ContinualConsumer + db storage.Database + notifier *sync.Notifier +} + +// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. +// Call Start() to begin consuming from the EDU server. +func NewOutputReceiptEventConsumer( + cfg *config.SyncAPI, + kafkaConsumer sarama.Consumer, + n *sync.Notifier, + store storage.Database, +) *OutputReceiptEventConsumer { + + consumer := internal.ContinualConsumer{ + ComponentName: "syncapi/eduserver/receipt", + Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputReceiptEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + } + + s := &OutputReceiptEventConsumer{ + receiptConsumer: &consumer, + db: store, + notifier: n, + } + + consumer.ProcessMessage = s.onMessage + + return s +} + +// Start consuming from EDU api +func (s *OutputReceiptEventConsumer) Start() error { + return s.receiptConsumer.Start() +} + +func (s *OutputReceiptEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { + var output api.OutputReceiptEvent + if err := json.Unmarshal(msg.Value, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + return nil + } + + streamPos, err := s.db.StoreReceipt( + context.TODO(), + output.RoomID, + output.Type, + output.UserID, + output.EventID, + output.Timestamp, + ) + if err != nil { + return err + } + // update stream position + s.notifier.OnNewReceipt(types.NewStreamToken(0, streamPos, nil)) + + return nil +} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index ac1128c11..05f32a83a 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -118,7 +118,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { func (s *OutputRoomEventConsumer) onRedactEvent( ctx context.Context, msg api.OutputRedactedEvent, ) error { - err := s.db.RedactEvent(ctx, msg.RedactedEventID, &msg.RedactedBecause) + err := s.db.RedactEvent(ctx, msg.RedactedEventID, msg.RedactedBecause) if err != nil { log.WithError(err).Error("RedactEvent error'd") return err @@ -156,7 +156,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( pduPos, err := s.db.WriteEvent( ctx, - &ev, + ev, addsStateEvents, msg.AddsStateEventIDs, msg.RemovesStateEventIDs, @@ -174,12 +174,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( return nil } - if pduPos, err = s.notifyJoinedPeeks(ctx, &ev, pduPos); err != nil { + if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { logrus.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) return err } - s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) + s.notifier.OnNewEvent(ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) return nil } @@ -197,8 +197,8 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( // from confusing clients into thinking they've joined/left rooms. pduPos, err := s.db.WriteEvent( ctx, - &ev, - []gomatrixserverlib.HeaderedEvent{}, + ev, + []*gomatrixserverlib.HeaderedEvent{}, []string{}, // adds no state []string{}, // removes no state nil, // no transaction @@ -213,12 +213,12 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( return nil } - if pduPos, err = s.notifyJoinedPeeks(ctx, &ev, pduPos); err != nil { + if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { logrus.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) return err } - s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) + s.notifier.OnNewEvent(ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) return nil } @@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( }).Panicf("roomserver output log: write invite failure") return nil } - s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0, nil)) + s.notifier.OnNewEvent(msg.Event, "", nil, types.NewStreamToken(pduPos, 0, nil)) return nil } @@ -309,7 +309,7 @@ func (s *OutputRoomEventConsumer) onNewPeek( return nil } -func (s *OutputRoomEventConsumer) updateStateEvent(event gomatrixserverlib.HeaderedEvent) (gomatrixserverlib.HeaderedEvent, error) { +func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.HeaderedEvent) (*gomatrixserverlib.HeaderedEvent, error) { if event.StateKey() == nil { return event, nil } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index e5299f200..81a7857fc 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -59,6 +59,7 @@ const defaultMessagesLimit = 10 // OnIncomingMessagesRequest implements the /messages endpoint from the // client-server API. // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages +// nolint:gocyclo func OnIncomingMessagesRequest( req *http.Request, db storage.Database, roomID string, device *userapi.Device, federation *gomatrixserverlib.FederationClient, @@ -67,6 +68,19 @@ func OnIncomingMessagesRequest( ) util.JSONResponse { var err error + // check if the user has already forgotten about this room + isForgotten, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI) + if err != nil { + return jsonerror.InternalServerError() + } + + if isForgotten { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("user already forgot about this room"), + } + } + // Extract parameters from the request's URL. // Pagination tokens. var fromStream *types.StreamingToken @@ -182,6 +196,19 @@ func OnIncomingMessagesRequest( } } +func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.RoomserverInternalAPI) (bool, error) { + req := api.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: userID, + } + resp := api.QueryMembershipForUserResponse{} + if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { + return false, err + } + + return resp.IsRoomForgotten, nil +} + // retrieveEvents retrieves events from the local database for a request on // /messages. If there's not enough events to retrieve, it asks another // homeserver in the room for older events. @@ -208,7 +235,7 @@ func (r *messagesReq) retrieveEvents() ( return } - var events []gomatrixserverlib.HeaderedEvent + var events []*gomatrixserverlib.HeaderedEvent util.GetLogger(r.ctx).WithField("start", start).WithField("end", end).Infof("Fetched %d events locally", len(streamEvents)) // There can be two reasons for streamEvents to be empty: either we've @@ -232,8 +259,8 @@ func (r *messagesReq) retrieveEvents() ( // Sort the events to ensure we send them in the right order. if r.backwardOrdering { // This reverses the array from old->new to new->old - reversed := func(in []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { - out := make([]gomatrixserverlib.HeaderedEvent, len(in)) + reversed := func(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { + out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) for i := 0; i < len(in); i++ { out[i] = in[len(in)-i-1] } @@ -260,7 +287,7 @@ func (r *messagesReq) retrieveEvents() ( } // nolint:gocyclo -func (r *messagesReq) filterHistoryVisible(events []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { +func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { // TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the // user shouldn't see, we check the recent events and remove any prior to the join event of the user // which is equiv to history_visibility: joined @@ -275,8 +302,8 @@ func (r *messagesReq) filterHistoryVisible(events []gomatrixserverlib.HeaderedEv } } - var result []gomatrixserverlib.HeaderedEvent - var eventsToCheck []gomatrixserverlib.HeaderedEvent + var result []*gomatrixserverlib.HeaderedEvent + var eventsToCheck []*gomatrixserverlib.HeaderedEvent if joinEventIndex != -1 { if r.backwardOrdering { result = events[:joinEventIndex+1] @@ -286,7 +313,7 @@ func (r *messagesReq) filterHistoryVisible(events []gomatrixserverlib.HeaderedEv eventsToCheck = append(eventsToCheck, result[len(result)-1]) } } else { - eventsToCheck = []gomatrixserverlib.HeaderedEvent{events[0], events[len(events)-1]} + eventsToCheck = []*gomatrixserverlib.HeaderedEvent{events[0], events[len(events)-1]} result = events } // make sure the user was in the room for both the earliest and latest events, we need this because @@ -310,9 +337,9 @@ func (r *messagesReq) filterHistoryVisible(events []gomatrixserverlib.HeaderedEv for i := range queryRes.StateEvents { switch queryRes.StateEvents[i].Type() { case gomatrixserverlib.MRoomMember: - membershipEvent = &queryRes.StateEvents[i] + membershipEvent = queryRes.StateEvents[i] case gomatrixserverlib.MRoomHistoryVisibility: - hisVisEvent = &queryRes.StateEvents[i] + hisVisEvent = queryRes.StateEvents[i] } } if hisVisEvent == nil { @@ -338,12 +365,12 @@ func (r *messagesReq) filterHistoryVisible(events []gomatrixserverlib.HeaderedEv } if !wasJoined { util.GetLogger(r.ctx).WithField("num_events", len(events)).Warnf("%s was not joined to room during these events, omitting them", r.device.UserID) - return []gomatrixserverlib.HeaderedEvent{} + return []*gomatrixserverlib.HeaderedEvent{} } return result } -func (r *messagesReq) getStartEnd(events []gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { +func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { start, err = r.db.EventPositionInTopology( r.ctx, events[0].EventID(), ) @@ -383,7 +410,7 @@ func (r *messagesReq) getStartEnd(events []gomatrixserverlib.HeaderedEvent) (sta // Returns an error if there was an issue talking with the database or // backfilling. func (r *messagesReq) handleEmptyEventsSlice() ( - events []gomatrixserverlib.HeaderedEvent, err error, + events []*gomatrixserverlib.HeaderedEvent, err error, ) { backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) @@ -397,7 +424,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( } else { // If not, it means the slice was empty because we reached the room's // creation, so return an empty slice. - events = []gomatrixserverlib.HeaderedEvent{} + events = []*gomatrixserverlib.HeaderedEvent{} } return @@ -409,7 +436,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( // through backfilling if needed. // Returns an error if there was an issue while backfilling. func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent) ( - events []gomatrixserverlib.HeaderedEvent, err error, + events []*gomatrixserverlib.HeaderedEvent, err error, ) { // Check if we have enough events. isSetLargeEnough := len(streamEvents) >= r.limit @@ -437,7 +464,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent // Backfill is needed if we've reached a backward extremity and need more // events. It's only needed if the direction is backward. if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { - var pdus []gomatrixserverlib.HeaderedEvent + var pdus []*gomatrixserverlib.HeaderedEvent // Only ask the remote server for enough events to reach the limit. pdus, err = r.backfill(r.roomID, backwardExtremities, r.limit-len(streamEvents)) if err != nil { @@ -455,7 +482,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent return } -type eventsByDepth []gomatrixserverlib.HeaderedEvent +type eventsByDepth []*gomatrixserverlib.HeaderedEvent func (e eventsByDepth) Len() int { return len(e) @@ -476,7 +503,7 @@ func (e eventsByDepth) Less(i, j int) bool { // event, or if there is no remote homeserver to contact. // Returns an error if there was an issue with retrieving the list of servers in // the room or sending the request. -func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]string, limit int) ([]gomatrixserverlib.HeaderedEvent, error) { +func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]string, limit int) ([]*gomatrixserverlib.HeaderedEvent, error) { var res api.PerformBackfillResponse err := r.rsAPI.PerformBackfill(context.Background(), &api.PerformBackfillRequest{ RoomID: roomID, @@ -504,8 +531,8 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] for i := range res.Events { _, err = r.db.WriteEvent( context.Background(), - &res.Events[i], - []gomatrixserverlib.HeaderedEvent{}, + res.Events[i], + []*gomatrixserverlib.HeaderedEvent{}, []string{}, []string{}, nil, true, diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index e12a1166e..eaa0f64f3 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -18,6 +18,8 @@ import ( "context" "time" + eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" @@ -37,11 +39,11 @@ type Database interface { // If an event is not found in the database then it will be omitted from the list. // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. - Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) + Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races // when generating the sync stream position for this event. Returns the sync stream position for the inserted event. // Returns an error if there was a problem inserting this event. - WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent, + WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent, addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error) // PurgeRoomState completely purges room state from the sync API. This is done when // receiving an output event that completely resets the state. @@ -53,7 +55,7 @@ type Database interface { // GetStateEventsForRoom fetches the state events for a given room. // Returns an empty slice if no state events could be found for this room. // Returns an error if there was an issue with the retrieval. - GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) + GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) // SyncPosition returns the latest positions for syncing. SyncPosition(ctx context.Context) (types.StreamingToken, error) // IncrementalSync returns all the data needed in order to create an incremental @@ -82,7 +84,7 @@ type Database interface { // AddInviteEvent stores a new invite event for a user. // If the invite was successfully stored this returns the stream ID it was stored at. // Returns an error if there was a problem communicating with the database. - AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error) + AddInviteEvent(ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent) (types.StreamPosition, error) // RetireInviteEvent removes an old invite event from the database. Returns the new position of the retired invite. // Returns an error if there was a problem communicating with the database. RetireInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error) @@ -114,7 +116,7 @@ type Database interface { // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. - StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent + StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent // AddSendToDevice increases the EDU position in the cache and returns the stream position. AddSendToDevice() types.StreamPosition // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists: @@ -147,4 +149,8 @@ type Database interface { PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) // RedactEvent wipes an event in the database and sets the unsigned.redacted_because key to the redaction event RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error + // StoreReceipt stores new receipt events + StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) + // GetRoomReceipts gets all receipts for a given roomID + GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) } diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 0ca9eed97..123272782 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -195,7 +195,7 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter, -) ([]gomatrixserverlib.HeaderedEvent, error) { +) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(stateFilter.Senders), @@ -231,7 +231,7 @@ func (s *currentRoomStateStatements) DeleteRoomStateForRoom( func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, - event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, + event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, ) error { // Parse content as JSON and search for an "url" key containsURL := false @@ -275,8 +275,8 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( return rowsToStreamEvents(rows) } -func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) { - result := []gomatrixserverlib.HeaderedEvent{} +func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { + result := []*gomatrixserverlib.HeaderedEvent{} for rows.Next() { var eventBytes []byte if err := rows.Scan(&eventBytes); err != nil { @@ -287,7 +287,7 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) { if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } - result = append(result, ev) + result = append(result, &ev) } return result, rows.Err() } diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index c0dd42c5a..48ad58c05 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -91,7 +91,7 @@ func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteEventsStatements) InsertInviteEvent( - ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, + ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { var headeredJSON []byte headeredJSON, err = json.Marshal(inviteEvent) @@ -121,15 +121,15 @@ func (s *inviteEventsStatements) DeleteInviteEvent( // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, -) (map[string]gomatrixserverlib.HeaderedEvent, map[string]gomatrixserverlib.HeaderedEvent, error) { +) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { return nil, nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") - result := map[string]gomatrixserverlib.HeaderedEvent{} - retired := map[string]gomatrixserverlib.HeaderedEvent{} + result := map[string]*gomatrixserverlib.HeaderedEvent{} + retired := map[string]*gomatrixserverlib.HeaderedEvent{} for rows.Next() { var ( roomID string @@ -148,7 +148,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( continue } - var event gomatrixserverlib.HeaderedEvent + var event *gomatrixserverlib.HeaderedEvent if err := json.Unmarshal(eventJSON, &event); err != nil { return nil, nil, err } diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 4b2101bbc..ce4b63350 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -247,7 +247,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( stateNeeded[ev.RoomID()] = needSet eventIDToEvent[ev.EventID()] = types.StreamEvent{ - HeaderedEvent: ev, + HeaderedEvent: &ev, StreamPosition: streamPos, ExcludeFromSync: excludeFromSync, } @@ -437,7 +437,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { } result = append(result, types.StreamEvent{ - HeaderedEvent: ev, + HeaderedEvent: &ev, StreamPosition: streamPos, TransactionID: transactionID, ExcludeFromSync: excludeFromSync, diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go new file mode 100644 index 000000000..c5ec6cbc6 --- /dev/null +++ b/syncapi/storage/postgres/receipt_table.go @@ -0,0 +1,106 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const receiptsSchema = ` +CREATE SEQUENCE IF NOT EXISTS syncapi_stream_id; +-- Stores data about receipts +CREATE TABLE IF NOT EXISTS syncapi_receipts ( + -- The ID + id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + receipt_ts BIGINT NOT NULL, + CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) +); +CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id ON syncapi_receipts(room_id); +` + +const upsertReceipt = "" + + "INSERT INTO syncapi_receipts" + + " (room_id, receipt_type, user_id, event_id, receipt_ts)" + + " VALUES ($1, $2, $3, $4, $5)" + + " ON CONFLICT (room_id, receipt_type, user_id)" + + " DO UPDATE SET id = nextval('syncapi_stream_id'), event_id = $4, receipt_ts = $5" + + " RETURNING id" + +const selectRoomReceipts = "" + + "SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + + " FROM syncapi_receipts" + + " WHERE room_id = ANY($1) AND id > $2" + +type receiptStatements struct { + db *sql.DB + upsertReceipt *sql.Stmt + selectRoomReceipts *sql.Stmt +} + +func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { + _, err := db.Exec(receiptsSchema) + if err != nil { + return nil, err + } + r := &receiptStatements{ + db: db, + } + if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { + return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) + } + if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { + return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) + } + return r, nil +} + +func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { + stmt := sqlutil.TxStmt(txn, r.upsertReceipt) + err = stmt.QueryRowContext(ctx, roomId, receiptType, userId, eventId, timestamp).Scan(&pos) + return +} + +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { + rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) + if err != nil { + return nil, fmt.Errorf("unable to query room receipts: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") + var res []api.OutputReceiptEvent + for rows.Next() { + r := api.OutputReceiptEvent{} + err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) + if err != nil { + return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) + } + res = append(res, r) + } + return res, rows.Err() +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 7f19722ae..979e19a0b 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -82,6 +82,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if err != nil { return nil, err } + receipts, err := NewPostgresReceiptsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, Writer: d.writer, @@ -94,6 +98,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e BackwardExtremities: backwardExtremities, Filter: filter, SendToDevice: sendToDevice, + Receipts: receipts, EDUCache: cache.New(), } return &d, nil diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index a7c07f943..fd8ca0412 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -21,6 +21,7 @@ import ( "fmt" "time" + eduAPI "github.com/matrix-org/dendrite/eduserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/eduserver/cache" @@ -47,6 +48,7 @@ type Database struct { BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice Filter tables.Filter + Receipts tables.Receipts EDUCache *cache.EDUCache } @@ -55,7 +57,7 @@ type Database struct { // If an event is not found in the database then it will be omitted from the list. // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. -func (d *Database) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { +func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs) if err != nil { return nil, err @@ -133,7 +135,7 @@ func (d *Database) GetStateEvent( func (d *Database) GetStateEventsForRoom( ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, -) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { +) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter) return } @@ -142,7 +144,7 @@ func (d *Database) GetStateEventsForRoom( // If the invite was successfully stored this returns the stream ID it was stored at. // Returns an error if there was a problem communicating with the database. func (d *Database) AddInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, + ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent, ) (sp types.StreamPosition, err error) { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) @@ -221,8 +223,8 @@ func (d *Database) UpsertAccountData( return } -func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent { - out := make([]gomatrixserverlib.HeaderedEvent, len(in)) +func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent { + out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) for i := 0; i < len(in); i++ { out[i] = in[i].HeaderedEvent if device != nil && in[i].TransactionID != nil { @@ -293,7 +295,7 @@ func (d *Database) PurgeRoomState( func (d *Database) WriteEvent( ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, - addStateEvents []gomatrixserverlib.HeaderedEvent, + addStateEvents []*gomatrixserverlib.HeaderedEvent, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool, ) (pduPosition types.StreamPosition, returnErr error) { @@ -330,7 +332,7 @@ func (d *Database) WriteEvent( func (d *Database) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, - addedEvents []gomatrixserverlib.HeaderedEvent, + addedEvents []*gomatrixserverlib.HeaderedEvent, pduPosition types.StreamPosition, ) error { // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. @@ -527,10 +529,10 @@ func (d *Database) addTypingDeltaToResponse( joinedRoomIDs []string, res *types.Response, ) error { - var jr types.JoinResponse var ok bool var err error for _, roomID := range joinedRoomIDs { + var jr types.JoinResponse if typingUsers, updated := d.EDUCache.GetTypingUsersIfUpdatedAfter( roomID, int64(since.EDUPosition()), ); updated { @@ -554,21 +556,84 @@ func (d *Database) addTypingDeltaToResponse( return nil } +// addReceiptDeltaToResponse adds all receipt information to a sync response +// since the specified position +func (d *Database) addReceiptDeltaToResponse( + since types.StreamingToken, + joinedRoomIDs []string, + res *types.Response, +) error { + receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.EDUPosition()) + if err != nil { + return fmt.Errorf("unable to select receipts for rooms: %w", err) + } + + // Group receipts by room, so we can create one ClientEvent for every room + receiptsByRoom := make(map[string][]eduAPI.OutputReceiptEvent) + for _, receipt := range receipts { + receiptsByRoom[receipt.RoomID] = append(receiptsByRoom[receipt.RoomID], receipt) + } + + for roomID, receipts := range receiptsByRoom { + var jr types.JoinResponse + var ok bool + + // Make sure we use an existing JoinResponse if there is one. + // If not, we'll create a new one + if jr, ok = res.Rooms.Join[roomID]; !ok { + jr = types.JoinResponse{} + } + + ev := gomatrixserverlib.ClientEvent{ + Type: gomatrixserverlib.MReceipt, + RoomID: roomID, + } + content := make(map[string]eduAPI.ReceiptMRead) + for _, receipt := range receipts { + var read eduAPI.ReceiptMRead + if read, ok = content[receipt.EventID]; !ok { + read = eduAPI.ReceiptMRead{ + User: make(map[string]eduAPI.ReceiptTS), + } + } + read.User[receipt.UserID] = eduAPI.ReceiptTS{TS: receipt.Timestamp} + content[receipt.EventID] = read + } + ev.Content, err = json.Marshal(content) + if err != nil { + return err + } + + jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) + res.Rooms.Join[roomID] = jr + } + + return nil +} + // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // the positions of that type are not equal in fromPos and toPos. func (d *Database) addEDUDeltaToResponse( fromPos, toPos types.StreamingToken, joinedRoomIDs []string, res *types.Response, -) (err error) { - +) error { if fromPos.EDUPosition() != toPos.EDUPosition() { - err = d.addTypingDeltaToResponse( - fromPos, joinedRoomIDs, res, - ) + // add typing deltas + if err := d.addTypingDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil { + return fmt.Errorf("unable to apply typing delta to response: %w", err) + } } - return + // Check on initial sync and if EDUPositions differ + if (fromPos.EDUPosition() == 0 && toPos.EDUPosition() == 0) || + fromPos.EDUPosition() != toPos.EDUPosition() { + if err := d.addReceiptDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil { + return fmt.Errorf("unable to apply receipts to response: %w", err) + } + } + + return nil } func (d *Database) GetFilter( @@ -644,14 +709,14 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda } eventToRedact := redactedEvents[0].Unwrap() redactionEvent := redactedBecause.Unwrap() - ev, err := eventutil.RedactEvent(&redactionEvent, &eventToRedact) + ev, err := eventutil.RedactEvent(redactionEvent, eventToRedact) if err != nil { return err } newEvent := ev.Headered(redactedBecause.RoomVersion) err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.OutputEvents.UpdateEventJSON(ctx, &newEvent) + return d.OutputEvents.UpdateEventJSON(ctx, newEvent) }) return err } @@ -744,7 +809,7 @@ func (d *Database) getJoinResponseForCompleteSync( stateFilter *gomatrixserverlib.StateFilter, numRecentEventsPerRoom int, device userapi.Device, ) (jr *types.JoinResponse, err error) { - var stateEvents []gomatrixserverlib.HeaderedEvent + var stateEvents []*gomatrixserverlib.HeaderedEvent stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter) if err != nil { return @@ -1115,7 +1180,7 @@ func (d *Database) getStateDeltas( // dupe join events will result in the entire room state coming down to the client again. This is added in // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to // the timeline. - if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { + if membership := getMembershipFromEvent(ev.Event, userID); membership != "" { if membership == gomatrixserverlib.Join { // send full room state down instead of a delta var s []types.StreamEvent @@ -1199,7 +1264,7 @@ func (d *Database) getStateDeltasForFullStateSync( for roomID, stateStreamEvents := range state { for _, ev := range stateStreamEvents { - if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { + if membership := getMembershipFromEvent(ev.Event, userID); membership != "" { if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. deltas[roomID] = stateDelta{ membership: membership, @@ -1361,7 +1426,7 @@ func (d *Database) CleanSendToDeviceUpdates( // There may be some overlap where events in stateEvents are already in recentEvents, so filter // them out so we don't include them twice in the /sync response. They should be in recentEvents // only, so clients get to the correct state once they have rolled forward. -func removeDuplicates(stateEvents, recentEvents []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { +func removeDuplicates(stateEvents, recentEvents []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { for _, recentEv := range recentEvents { if recentEv.StateKey() == nil { continue // not a state event @@ -1398,9 +1463,22 @@ func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string { type stateDelta struct { roomID string - stateEvents []gomatrixserverlib.HeaderedEvent + stateEvents []*gomatrixserverlib.HeaderedEvent membership string // The PDU stream position of the latest membership event for this user, if applicable. // Can be 0 if there is no membership event in this delta. membershipPos types.StreamPosition } + +// StoreReceipt stores user receipts +func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp) + return err + }) + return +} + +func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) { + return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) +} diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 13d23be5f..357d4282e 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -184,7 +184,7 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, -) ([]gomatrixserverlib.HeaderedEvent, error) { +) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) rows, err := stmt.QueryContext(ctx, roomID, nil, // FIXME: pq.StringArray(stateFilterPart.Senders), @@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) DeleteRoomStateForRoom( func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, - event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, + event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, ) error { // Parse content as JSON and search for an "url" key containsURL := false @@ -286,8 +286,8 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( return res, nil } -func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) { - result := []gomatrixserverlib.HeaderedEvent{} +func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { + result := []*gomatrixserverlib.HeaderedEvent{} for rows.Next() { var eventBytes []byte if err := rows.Scan(&eventBytes); err != nil { @@ -298,7 +298,7 @@ func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.HeaderedEvent, error) { if err := json.Unmarshal(eventBytes, &ev); err != nil { return nil, err } - result = append(result, ev) + result = append(result, &ev) } return result, nil } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 1a36ad40c..f9dcfdbcd 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -91,7 +91,7 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv } func (s *inviteEventsStatements) InsertInviteEvent( - ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, + ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) if err != nil { @@ -132,15 +132,15 @@ func (s *inviteEventsStatements) DeleteInviteEvent( // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, -) (map[string]gomatrixserverlib.HeaderedEvent, map[string]gomatrixserverlib.HeaderedEvent, error) { +) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { return nil, nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") - result := map[string]gomatrixserverlib.HeaderedEvent{} - retired := map[string]gomatrixserverlib.HeaderedEvent{} + result := map[string]*gomatrixserverlib.HeaderedEvent{} + retired := map[string]*gomatrixserverlib.HeaderedEvent{} for rows.Next() { var ( roomID string @@ -159,7 +159,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( continue } - var event gomatrixserverlib.HeaderedEvent + var event *gomatrixserverlib.HeaderedEvent if err := json.Unmarshal(eventJSON, &event); err != nil { return nil, nil, err } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 587a40726..064075824 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -246,7 +246,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( stateNeeded[ev.RoomID()] = needSet eventIDToEvent[ev.EventID()] = types.StreamEvent{ - HeaderedEvent: ev, + HeaderedEvent: &ev, StreamPosition: streamPos, ExcludeFromSync: excludeFromSync, } @@ -452,7 +452,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { } result = append(result, types.StreamEvent{ - HeaderedEvent: ev, + HeaderedEvent: &ev, StreamPosition: streamPos, TransactionID: transactionID, ExcludeFromSync: excludeFromSync, diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go new file mode 100644 index 000000000..b1770e801 --- /dev/null +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const receiptsSchema = ` +-- Stores data about receipts +CREATE TABLE IF NOT EXISTS syncapi_receipts ( + -- The ID + id BIGINT, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + receipt_ts BIGINT NOT NULL, + CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) +); +CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id_idx ON syncapi_receipts(room_id); +` + +const upsertReceipt = "" + + "INSERT INTO syncapi_receipts" + + " (id, room_id, receipt_type, user_id, event_id, receipt_ts)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (room_id, receipt_type, user_id)" + + " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" + +const selectRoomReceipts = "" + + "SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + + " FROM syncapi_receipts" + + " WHERE id > $1 and room_id in ($2)" + +type receiptStatements struct { + db *sql.DB + streamIDStatements *streamIDStatements + upsertReceipt *sql.Stmt + selectRoomReceipts *sql.Stmt +} + +func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { + _, err := db.Exec(receiptsSchema) + if err != nil { + return nil, err + } + r := &receiptStatements{ + db: db, + streamIDStatements: streamID, + } + if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { + return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) + } + if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { + return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) + } + return r, nil +} + +// UpsertReceipt creates new user receipts +func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { + pos, err = r.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + stmt := sqlutil.TxStmt(txn, r.upsertReceipt) + _, err = stmt.ExecContext(ctx, pos, roomId, receiptType, userId, eventId, timestamp, pos, eventId, timestamp) + return +} + +// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { + selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) + + params := make([]interface{}, len(roomIDs)+1) + params[0] = streamPos + for k, v := range roomIDs { + params[k+1] = v + } + rows, err := r.db.QueryContext(ctx, selectSQL, params...) + if err != nil { + return nil, fmt.Errorf("unable to query room receipts: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") + var res []api.OutputReceiptEvent + for rows.Next() { + r := api.OutputReceiptEvent{} + err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) + if err != nil { + return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) + } + res = append(res, r) + } + return res, rows.Err() +} diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 86d83ec98..036e2b2e5 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -95,6 +95,10 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } + receipts, err := NewSqliteReceiptsTable(d.db, &d.streamID) + if err != nil { + return err + } d.Database = shared.Database{ DB: d.db, Writer: d.writer, @@ -107,6 +111,7 @@ func (d *SyncServerDatasource) prepare() (err error) { Topology: topology, Filter: filter, SendToDevice: sendToDevice, + Receipts: receipts, EDUCache: cache.New(), } return nil diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 2869ac5d2..3930ab6b5 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -37,7 +37,7 @@ var ( }) ) -func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) gomatrixserverlib.HeaderedEvent { +func MustCreateEvent(t *testing.T, roomID string, prevs []*gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) *gomatrixserverlib.HeaderedEvent { b.RoomID = roomID if prevs != nil { prevIDs := make([]string, len(prevs)) @@ -70,8 +70,8 @@ func MustCreateDatabase(t *testing.T) storage.Database { } // Create a list of events which include a create event, join event and some messages. -func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []gomatrixserverlib.HeaderedEvent, state []gomatrixserverlib.HeaderedEvent) { - var events []gomatrixserverlib.HeaderedEvent +func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []*gomatrixserverlib.HeaderedEvent, state []*gomatrixserverlib.HeaderedEvent) { + var events []*gomatrixserverlib.HeaderedEvent events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{ Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, userA)), Type: "m.room.create", @@ -80,7 +80,7 @@ func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []gomatrixserve Depth: int64(len(events) + 1), })) state = append(state, events[len(events)-1]) - events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(`{"membership":"join"}`), Type: "m.room.member", StateKey: &userA, @@ -89,14 +89,14 @@ func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []gomatrixserve })) state = append(state, events[len(events)-1]) for i := 0; i < 10; i++ { - events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)), Type: "m.room.message", Sender: userA, Depth: int64(len(events) + 1), })) } - events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(`{"membership":"join"}`), Type: "m.room.member", StateKey: &userB, @@ -105,7 +105,7 @@ func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []gomatrixserve })) state = append(state, events[len(events)-1]) for i := 0; i < 10; i++ { - events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(fmt.Sprintf(`{"body":"Message B %d"}`, i+1)), Type: "m.room.message", Sender: userB, @@ -116,16 +116,16 @@ func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []gomatrixserve return events, state } -func MustWriteEvents(t *testing.T, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { +func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { for _, ev := range events { - var addStateEvents []gomatrixserverlib.HeaderedEvent + var addStateEvents []*gomatrixserverlib.HeaderedEvent var addStateEventIDs []string var removeStateEventIDs []string if ev.StateKey() != nil { addStateEvents = append(addStateEvents, ev) addStateEventIDs = append(addStateEventIDs, ev.EventID()) } - pos, err := db.WriteEvent(ctx, &ev, addStateEvents, addStateEventIDs, removeStateEventIDs, nil, false) + pos, err := db.WriteEvent(ctx, ev, addStateEvents, addStateEventIDs, removeStateEventIDs, nil, false) if err != nil { t.Fatalf("WriteEvent failed: %s", err) } @@ -156,8 +156,8 @@ func TestSyncResponse(t *testing.T) { testCases := []struct { Name string DoSync func() (*types.Response, error) - WantTimeline []gomatrixserverlib.HeaderedEvent - WantState []gomatrixserverlib.HeaderedEvent + WantTimeline []*gomatrixserverlib.HeaderedEvent + WantState []*gomatrixserverlib.HeaderedEvent }{ // The purpose of this test is to make sure that incremental syncs are including up to the latest events. // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. @@ -339,7 +339,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { t.Parallel() db := MustCreateDatabase(t) - var events []gomatrixserverlib.HeaderedEvent + var events []*gomatrixserverlib.HeaderedEvent events = append(events, MustCreateEvent(t, testRoomID, nil, &gomatrixserverlib.EventBuilder{ Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, testUserIDA)), Type: "m.room.create", @@ -347,7 +347,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { Sender: testUserIDA, Depth: int64(len(events) + 1), })) - events = append(events, MustCreateEvent(t, testRoomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + events = append(events, MustCreateEvent(t, testRoomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(`{"membership":"join"}`), Type: "m.room.member", StateKey: &testUserIDA, @@ -355,7 +355,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { Depth: int64(len(events) + 1), })) // fork the dag into three, same prev_events and depth - parent := []gomatrixserverlib.HeaderedEvent{events[len(events)-1]} + parent := []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]} depth := int64(len(events) + 1) for i := 0; i < 3; i++ { events = append(events, MustCreateEvent(t, testRoomID, parent, &gomatrixserverlib.EventBuilder{ @@ -388,7 +388,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { Name string From types.TopologyToken Limit int - Wants []gomatrixserverlib.HeaderedEvent + Wants []*gomatrixserverlib.HeaderedEvent }{ { Name: "Pagination over the whole fork", @@ -429,7 +429,7 @@ func TestGetEventsInTopologicalRangeMultiRoom(t *testing.T) { t.Parallel() db := MustCreateDatabase(t) - makeEvents := func(roomID string) (events []gomatrixserverlib.HeaderedEvent) { + makeEvents := func(roomID string) (events []*gomatrixserverlib.HeaderedEvent) { events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{ Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, testUserIDA)), Type: "m.room.create", @@ -437,7 +437,7 @@ func TestGetEventsInTopologicalRangeMultiRoom(t *testing.T) { Sender: testUserIDA, Depth: int64(len(events) + 1), })) - events = append(events, MustCreateEvent(t, roomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(`{"membership":"join"}`), Type: "m.room.member", StateKey: &testUserIDA, @@ -483,14 +483,14 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) { // "federation" join userC := fmt.Sprintf("@radiance:%s", testOrigin) - joinEvent := MustCreateEvent(t, testRoomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + joinEvent := MustCreateEvent(t, testRoomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ Content: []byte(`{"membership":"join"}`), Type: "m.room.member", StateKey: &userC, Sender: userC, Depth: int64(len(events) + 1), }) - MustWriteEvents(t, db, []gomatrixserverlib.HeaderedEvent{joinEvent}) + MustWriteEvents(t, db, []*gomatrixserverlib.HeaderedEvent{joinEvent}) // Sync will return this for the prev_batch from := topologyTokenBefore(t, db, joinEvent.EventID()) @@ -627,7 +627,7 @@ func TestInviteBehaviour(t *testing.T) { StateKey: &testUserIDA, Sender: "@inviteUser2:somewhere", }) - for _, ev := range []gomatrixserverlib.HeaderedEvent{inviteEvent1, inviteEvent2} { + for _, ev := range []*gomatrixserverlib.HeaderedEvent{inviteEvent1, inviteEvent2} { _, err := db.AddInviteEvent(ctx, ev) if err != nil { t.Fatalf("Failed to AddInviteEvent: %s", err) @@ -688,7 +688,7 @@ func assertInvitedToRooms(t *testing.T, res *types.Response, roomIDs []string) { } } -func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { +func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []*gomatrixserverlib.HeaderedEvent) { t.Helper() if len(gots) != len(wants) { t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) @@ -738,8 +738,8 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ return &tok } -func reversed(in []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { - out := make([]gomatrixserverlib.HeaderedEvent, len(in)) +func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { + out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) for i := 0; i < len(in); i++ { out[i] = in[len(in)-i-1] } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index da095be53..a2d8791b6 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" + eduAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -31,11 +32,11 @@ type AccountData interface { } type Invites interface { - InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) + InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) DeleteInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string) (types.StreamPosition, error) // SelectInviteEventsInRange returns a map of room ID to invite events. If multiple invite/retired invites exist in the given range, return the latest value // for the room. - SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]gomatrixserverlib.HeaderedEvent, retired map[string]gomatrixserverlib.HeaderedEvent, err error) + SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) } @@ -86,11 +87,11 @@ type Topology interface { type CurrentRoomState interface { SelectStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) - UpsertRoomState(ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error + UpsertRoomState(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error DeleteRoomStateForRoom(ctx context.Context, txn *sql.Tx, roomID string) error // SelectCurrentState returns all the current state events for the given room. - SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter) ([]gomatrixserverlib.HeaderedEvent, error) + SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error) // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. @@ -156,3 +157,8 @@ type Filter interface { SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) } + +type Receipts interface { + UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) + SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) +} diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index fcac3f16c..daa3a1d8c 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -149,6 +149,16 @@ func (n *Notifier) OnNewSendToDevice( n.wakeupUserDevice(userID, deviceIDs, latestPos) } +// OnNewReceipt updates the current position +func (n *Notifier) OnNewReceipt( + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + latestPos := n.currPos.WithUpdates(posUpdate) + n.currPos = latestPos +} + func (n *Notifier) OnNewKeyChange( posUpdate types.StreamingToken, wakeUserID, keyChangeUserID string, ) { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 8a79737aa..cb84c5bb6 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -19,10 +19,14 @@ package sync import ( "context" "fmt" + "net" "net/http" + "strings" + "sync" "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/config" keyapi "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" @@ -37,18 +41,62 @@ import ( // RequestPool manages HTTP long-poll connections for /sync type RequestPool struct { db storage.Database + cfg *config.SyncAPI userAPI userapi.UserInternalAPI notifier *Notifier keyAPI keyapi.KeyInternalAPI rsAPI roomserverAPI.RoomserverInternalAPI + lastseen sync.Map } // NewRequestPool makes a new RequestPool func NewRequestPool( - db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI, + db storage.Database, cfg *config.SyncAPI, n *Notifier, + userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) *RequestPool { - return &RequestPool{db, userAPI, n, keyAPI, rsAPI} + rp := &RequestPool{db, cfg, userAPI, n, keyAPI, rsAPI, sync.Map{}} + go rp.cleanLastSeen() + return rp +} + +func (rp *RequestPool) cleanLastSeen() { + for { + rp.lastseen.Range(func(key interface{}, _ interface{}) bool { + rp.lastseen.Delete(key) + return true + }) + time.Sleep(time.Minute) + } +} + +func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) { + if _, ok := rp.lastseen.LoadOrStore(device.UserID+device.ID, struct{}{}); ok { + return + } + + remoteAddr := req.RemoteAddr + if rp.cfg.RealIPHeader != "" { + if header := req.Header.Get(rp.cfg.RealIPHeader); header != "" { + // TODO: Maybe this isn't great but it will satisfy both X-Real-IP + // and X-Forwarded-For (which can be a list where the real client + // address is the first listed address). Make more intelligent? + addresses := strings.Split(header, ",") + if ip := net.ParseIP(addresses[0]); ip != nil { + remoteAddr = addresses[0] + } + } + } + + lsreq := &userapi.PerformLastSeenUpdateRequest{ + UserID: device.UserID, + DeviceID: device.ID, + RemoteAddr: remoteAddr, + } + lsres := &userapi.PerformLastSeenUpdateResponse{} + go rp.userAPI.PerformLastSeenUpdate(req.Context(), lsreq, lsres) // nolint:errcheck + + rp.lastseen.Store(device.UserID+device.ID, time.Now()) } // OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be @@ -74,6 +122,8 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. "limit": syncReq.limit, }) + rp.updateLastSeen(req, device) + currPos := rp.notifier.CurrentPosition() if rp.shouldReturnImmediately(syncReq) { @@ -278,7 +328,7 @@ func (rp *RequestPool) appendAccountData( // data keys were set between two message. This isn't a huge issue since the // duplicate data doesn't represent a huge quantity of data, but an optimisation // here would be making sure each data is sent only once to the client. - if req.since == nil { + if req.since == nil || (req.since.PDUPosition() == 0 && req.since.EDUPosition() == 0) { // If this is the initial sync, we don't need to check if a data has // already been sent. Instead, we send the whole batch. dataReq := &userapi.QueryAccountDataRequest{ diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index de0bb434b..7e277ba19 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -61,7 +61,7 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start notifier") } - requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, rsAPI) + requestPool := sync.NewRequestPool(syncDB, cfg, notifier, userAPI, keyAPI, rsAPI) keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputKeyChangeEvent)), @@ -99,5 +99,12 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start send-to-device consumer") } + receiptConsumer := consumers.NewOutputReceiptEventConsumer( + cfg, consumer, notifier, syncDB, + ) + if err = receiptConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start receipts consumer") + } + routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 9be83f5fa..91d33d4a6 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -59,7 +59,7 @@ func (p *LogPosition) IsAfter(lp *LogPosition) bool { // StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. type StreamEvent struct { - gomatrixserverlib.HeaderedEvent + *gomatrixserverlib.HeaderedEvent StreamPosition StreamPosition TransactionID *api.TransactionID ExcludeFromSync bool @@ -471,7 +471,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event gomatrixserverlib.HeaderedEvent) *InviteResponse { +func NewInviteResponse(event *gomatrixserverlib.HeaderedEvent) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} diff --git a/sytest-blacklist b/sytest-blacklist index f493f94fe..f42d5f822 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -60,4 +60,7 @@ Invited user can reject invite for empty room If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes # Blacklisted due to flakiness -A prev_batch token from incremental sync can be used in the v1 messages API \ No newline at end of file +A prev_batch token from incremental sync can be used in the v1 messages API + +# Blacklisted due to flakiness +Forgotten room messages cannot be paginated \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 1a12b591b..ffcb1785a 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -16,6 +16,13 @@ POST /register rejects registration of usernames with '£' POST /register rejects registration of usernames with 'é' POST /register rejects registration of usernames with '\n' POST /register rejects registration of usernames with ''' +POST /register allows registration of usernames with 'q' +POST /register allows registration of usernames with '3' +POST /register allows registration of usernames with '.' +POST /register allows registration of usernames with '_' +POST /register allows registration of usernames with '=' +POST /register allows registration of usernames with '-' +POST /register allows registration of usernames with '/' GET /login yields a set of flows POST /login can log in as a user POST /login returns the same device_id as that in the request @@ -483,6 +490,16 @@ POST rejects invalid utf-8 in JSON Users cannot kick users who have already left a room Event with an invalid signature in the send_join response should not cause room join to fail Inbound federation rejects typing notifications from wrong remote +POST /rooms/:room_id/receipt can create receipts +Receipts must be m.read +Read receipts appear in initial v2 /sync +New read receipts appear in incremental v2 /sync +Outbound federation sends receipts +Inbound federation rejects receipts from wrong remote Should not be able to take over the room by pretending there is no PL event Can get rooms/{roomId}/state for a departed room (SPEC-216) Users cannot set notifications powerlevel higher than their own +Forgetting room does not show up in v2 /sync +Can forget room you've been kicked from +Can re-join room if re-invited +/whois diff --git a/userapi/api/api.go b/userapi/api/api.go index 6c3f3c69c..809ba0476 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -29,6 +29,7 @@ type UserInternalAPI interface { PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error + PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -183,6 +184,17 @@ type PerformPasswordUpdateResponse struct { Account *Account } +// PerformLastSeenUpdateRequest is the request for PerformLastSeenUpdate. +type PerformLastSeenUpdateRequest struct { + UserID string + DeviceID string + RemoteAddr string +} + +// PerformLastSeenUpdateResponse is the response for PerformLastSeenUpdate. +type PerformLastSeenUpdateResponse struct { +} + // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 81d002414..3b5f4978a 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -172,6 +172,21 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er return nil } +func (a *UserInternalAPI) PerformLastSeenUpdate( + ctx context.Context, + req *api.PerformLastSeenUpdateRequest, + res *api.PerformLastSeenUpdateResponse, +) error { + localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { + return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) + } + return nil +} + func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) if err != nil { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 4d9dcc416..680e4cb52 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -32,6 +32,7 @@ const ( PerformAccountCreationPath = "/userapi/performAccountCreation" PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" + PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" @@ -119,6 +120,18 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformLastSeenUpdate( + ctx context.Context, + req *api.PerformLastSeenUpdateRequest, + res *api.PerformLastSeenUpdateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLastSeen") + defer span.Finish() + + apiURL := h.apiURL + PerformLastSeenUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate") defer span.Finish() diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 81e936e58..e495e3536 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -65,6 +65,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformLastSeenUpdatePath, + httputil.MakeInternalAPI("performLastSeenUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformLastSeenUpdateRequest{} + response := api.PerformLastSeenUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLastSeenUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceUpdatePath, httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceUpdateRequest{} diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 9953ba062..95fe99f33 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -33,9 +33,9 @@ type Database interface { // Returns the device on success. CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) - UpdateDeviceLastSeen(ctx context.Context, deviceID, ipAddr string) error } diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 379fed794..7de9f5f9e 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -77,7 +77,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -95,7 +95,7 @@ const selectDevicesByIDSQL = "" + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE device_id = $3" + "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" type devicesStatements struct { insertDeviceStmt *sql.Stmt @@ -281,8 +281,9 @@ func (s *devicesStatements) selectDevicesByLocalpart( for rows.Next() { var dev api.Device - var id, displayname sql.NullString - err = rows.Scan(&id, &displayname) + var lastseents sql.NullInt64 + var id, displayname, ip, useragent sql.NullString + err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) if err != nil { return devices, err } @@ -292,6 +293,16 @@ func (s *devicesStatements) selectDevicesByLocalpart( if displayname.Valid { dev.DisplayName = displayname.String } + if lastseents.Valid { + dev.LastSeenTS = lastseents.Int64 + } + if ip.Valid { + dev.LastSeenIP = ip.String + } + if useragent.Valid { + dev.UserAgent = useragent.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } @@ -299,9 +310,9 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, deviceID, ipAddr string) error { +func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) return err } diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index e318b260b..6dd18b092 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -205,8 +205,8 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, deviceID, ipAddr string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, deviceID, ipAddr) + return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index 26c03222a..955d8ac7f 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -62,7 +62,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -80,7 +80,7 @@ const selectDevicesByIDSQL = "" + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" const updateDeviceLastSeen = "" + - "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE device_id = $3" + "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" type devicesStatements struct { db *sql.DB @@ -256,8 +256,9 @@ func (s *devicesStatements) selectDevicesByLocalpart( for rows.Next() { var dev api.Device - var id, displayname sql.NullString - err = rows.Scan(&id, &displayname) + var lastseents sql.NullInt64 + var id, displayname, ip, useragent sql.NullString + err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) if err != nil { return devices, err } @@ -267,6 +268,16 @@ func (s *devicesStatements) selectDevicesByLocalpart( if displayname.Valid { dev.DisplayName = displayname.String } + if lastseents.Valid { + dev.LastSeenTS = lastseents.Int64 + } + if ip.Valid { + dev.LastSeenIP = ip.String + } + if useragent.Valid { + dev.UserAgent = useragent.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } @@ -303,9 +314,9 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, deviceID, ipAddr string) error { +func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) return err } diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 25888eae4..2eefb3f34 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -207,8 +207,8 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, deviceID, ipAddr string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, deviceID, ipAddr) + return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) }