diff --git a/README.md b/README.md index 6c84cffba..dc87d1b47 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,6 @@ test rig with around 900 tests. The script works out how many of these tests are updates with CI. As of October 2020 we're at around 57% CS API coverage and 81% Federation coverage, though check CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse servers such as matrix.org reasonably well. There's a long list of features that are not implemented, notably: - - Receipts - Push - Search and Context - User Directory @@ -100,6 +99,7 @@ This means Dendrite supports amongst others: - Redaction - Tagging - E2E keys and device lists + - Receipts ## Contributing diff --git a/build-dendritejs.sh b/build-dendritejs.sh index cd42a6bee..83ec3699c 100755 --- a/build-dendritejs.sh +++ b/build-dendritejs.sh @@ -1,4 +1,4 @@ -#!/bin/bash -eu +#!/bin/sh -eu export GIT_COMMIT=$(git rev-list -1 HEAD) && \ -GOOS=js GOARCH=wasm go build -ldflags "-X main.GitCommit=$GIT_COMMIT" -o main.wasm ./cmd/dendritejs \ No newline at end of file +GOOS=js GOARCH=wasm go build -ldflags "-X main.GitCommit=$GIT_COMMIT" -o main.wasm ./cmd/dendritejs diff --git a/build.sh b/build.sh index 31e0519f5..494d97eda 100755 --- a/build.sh +++ b/build.sh @@ -1,4 +1,4 @@ -#!/bin/bash -eu +#!/bin/sh -eu # Put installed packages into ./bin export GOBIN=$PWD/`dirname $0`/bin @@ -7,7 +7,7 @@ if [ -d ".git" ] then export BUILD=`git rev-parse --short HEAD || ""` export BRANCH=`(git symbolic-ref --short HEAD | tr -d \/ ) || ""` - if [[ $BRANCH == "master" ]] + if [ "$BRANCH" = master ] then export BRANCH="" fi diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 48303c97f..22e635139 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" + eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" @@ -148,7 +149,8 @@ type fullyReadEvent struct { // SaveReadMarker implements POST /rooms/{roomId}/read_markers func SaveReadMarker( - req *http.Request, userAPI api.UserInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, + req *http.Request, + userAPI api.UserInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, eduAPI eduserverAPI.EDUServerInputAPI, syncProducer *producers.SyncAPIProducer, device *api.Device, roomID string, ) util.JSONResponse { // Verify that the user is a member of this room @@ -192,8 +194,10 @@ func SaveReadMarker( return jsonerror.InternalServerError() } - // TODO handle the read receipt that may be included in the read marker - // See https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-rooms-roomid-read-markers + // Handle the read receipt that may be included in the read marker + if r.Read != "" { + return SetReceipt(req, eduAPI, device, roomID, "m.read", r.Read) + } return util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index fe0795577..e471e2128 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -407,3 +407,47 @@ func checkMemberInRoom(ctx context.Context, rsAPI api.RoomserverInternalAPI, use } return nil } + +func SendForget( + req *http.Request, device *userapi.Device, + roomID string, rsAPI roomserverAPI.RoomserverInternalAPI, +) util.JSONResponse { + ctx := req.Context() + logger := util.GetLogger(ctx).WithField("roomID", roomID).WithField("userID", device.UserID) + var membershipRes api.QueryMembershipForUserResponse + membershipReq := api.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: device.UserID, + } + err := rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes) + if err != nil { + logger.WithError(err).Error("QueryMembershipForUser: could not query membership for user") + return jsonerror.InternalServerError() + } + if membershipRes.IsInRoom { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Forbidden("user is still a member of the room"), + } + } + if !membershipRes.HasBeenInRoom { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Forbidden("user did not belong to room"), + } + } + + request := api.PerformForgetRequest{ + RoomID: roomID, + UserID: device.UserID, + } + response := api.PerformForgetResponse{} + if err := rsAPI.PerformForget(ctx, &request, &response); err != nil { + logger.WithError(err).Error("PerformForget: unable to forget room") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/receipt.go b/clientapi/routing/receipt.go new file mode 100644 index 000000000..fe8fe765d --- /dev/null +++ b/clientapi/routing/receipt.go @@ -0,0 +1,54 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "fmt" + "net/http" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/eduserver/api" + + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +func SetReceipt(req *http.Request, eduAPI api.EDUServerInputAPI, device *userapi.Device, roomId, receiptType, eventId string) util.JSONResponse { + timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + logrus.WithFields(logrus.Fields{ + "roomId": roomId, + "receiptType": receiptType, + "eventId": eventId, + "userId": device.UserID, + "timestamp": timestamp, + }).Debug("Setting receipt") + + // currently only m.read is accepted + if receiptType != "m.read" { + return util.MessageResponse(400, fmt.Sprintf("receipt type must be m.read not '%s'", receiptType)) + } + + if err := api.SendReceipt(req.Context(), eduAPI, device.UserID, roomId, eventId, receiptType, timestamp); err != nil { + return util.ErrorResponse(err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 756eafe2f..90e9eed38 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -113,7 +113,7 @@ var ( // TODO: Remove old sessions. Need to do so on a session-specific timeout. // sessions stores the completed flow stages for all sessions. Referenced using their sessionID. sessions = newSessionsDict() - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-./]+$`) + validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) ) // registerRequest represents the submitted registration request. diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 4f99237f5..99d1bd099 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -705,7 +705,20 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveReadMarker(req, userAPI, rsAPI, syncProducer, device, vars["roomID"]) + return SaveReadMarker(req, userAPI, rsAPI, eduAPI, syncProducer, device, vars["roomID"]) + }), + ).Methods(http.MethodPost, http.MethodOptions) + + r0mux.Handle("/rooms/{roomID}/forget", + httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return SendForget(req, device, vars["roomID"], rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -830,4 +843,17 @@ func Setup( return ClaimKeys(req, keyAPI) }), ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", + httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + + return SetReceipt(req, eduAPI, device, vars["roomId"], vars["receiptType"], vars["eventId"]) + }), + ).Methods(http.MethodPost, http.MethodOptions) } diff --git a/eduserver/api/input.go b/eduserver/api/input.go index 0d0d21f33..f8599e1cc 100644 --- a/eduserver/api/input.go +++ b/eduserver/api/input.go @@ -59,6 +59,22 @@ type InputSendToDeviceEventRequest struct { // InputSendToDeviceEventResponse is a response to InputSendToDeviceEventRequest type InputSendToDeviceEventResponse struct{} +type InputReceiptEvent struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + Type string `json:"type"` + Timestamp gomatrixserverlib.Timestamp `json:"timestamp"` +} + +// InputReceiptEventRequest is a request to EDUServerInputAPI +type InputReceiptEventRequest struct { + InputReceiptEvent InputReceiptEvent `json:"input_receipt_event"` +} + +// InputReceiptEventResponse is a response to InputReceiptEventRequest +type InputReceiptEventResponse struct{} + // EDUServerInputAPI is used to write events to the typing server. type EDUServerInputAPI interface { InputTypingEvent( @@ -72,4 +88,10 @@ type EDUServerInputAPI interface { request *InputSendToDeviceEventRequest, response *InputSendToDeviceEventResponse, ) error + + InputReceiptEvent( + ctx context.Context, + request *InputReceiptEventRequest, + response *InputReceiptEventResponse, + ) error } diff --git a/eduserver/api/output.go b/eduserver/api/output.go index e6ded8413..650458a29 100644 --- a/eduserver/api/output.go +++ b/eduserver/api/output.go @@ -49,3 +49,39 @@ type OutputSendToDeviceEvent struct { DeviceID string `json:"device_id"` gomatrixserverlib.SendToDeviceEvent } + +type ReceiptEvent struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + Type string `json:"type"` + Timestamp gomatrixserverlib.Timestamp `json:"timestamp"` +} + +// OutputReceiptEvent is an entry in the receipt output kafka log +type OutputReceiptEvent struct { + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + Type string `json:"type"` + Timestamp gomatrixserverlib.Timestamp `json:"timestamp"` +} + +// Helper structs for receipts json creation +type ReceiptMRead struct { + User map[string]ReceiptTS `json:"m.read"` +} + +type ReceiptTS struct { + TS gomatrixserverlib.Timestamp `json:"ts"` +} + +// FederationSender output +type FederationReceiptMRead struct { + User map[string]FederationReceiptData `json:"m.read"` +} + +type FederationReceiptData struct { + Data ReceiptTS `json:"data"` + EventIDs []string `json:"event_ids"` +} diff --git a/eduserver/api/wrapper.go b/eduserver/api/wrapper.go index c2c4596de..7907f4d39 100644 --- a/eduserver/api/wrapper.go +++ b/eduserver/api/wrapper.go @@ -67,3 +67,22 @@ func SendToDevice( response := InputSendToDeviceEventResponse{} return eduAPI.InputSendToDeviceEvent(ctx, &request, &response) } + +// SendReceipt sends a receipt event to EDU Server +func SendReceipt( + ctx context.Context, + eduAPI EDUServerInputAPI, userID, roomID, eventID, receiptType string, + timestamp gomatrixserverlib.Timestamp, +) error { + request := InputReceiptEventRequest{ + InputReceiptEvent: InputReceiptEvent{ + UserID: userID, + RoomID: roomID, + EventID: eventID, + Type: receiptType, + Timestamp: timestamp, + }, + } + response := InputReceiptEventResponse{} + return eduAPI.InputReceiptEvent(ctx, &request, &response) +} diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index 098ac0248..d5ab36818 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -49,8 +49,9 @@ func NewInternalAPI( Cache: eduCache, UserAPI: userAPI, Producer: producer, - OutputTypingEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent)), - OutputSendToDeviceEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent)), + OutputTypingEventTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent), + OutputSendToDeviceEventTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent), + OutputReceiptEventTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputReceiptEvent), ServerName: cfg.Matrix.ServerName, } } diff --git a/eduserver/input/input.go b/eduserver/input/input.go index e3d2c55e3..c54fb9de8 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -37,6 +37,8 @@ type EDUServerInputAPI struct { OutputTypingEventTopic string // The kafka topic to output new send to device events to. OutputSendToDeviceEventTopic string + // The kafka topic to output new receipt events to + OutputReceiptEventTopic string // kafka producer Producer sarama.SyncProducer // Internal user query API @@ -173,3 +175,31 @@ func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) e return nil } + +// InputReceiptEvent implements api.EDUServerInputAPI +// TODO: Intelligently batch requests sent by the same user (e.g wait a few milliseconds before emitting output events) +func (t *EDUServerInputAPI) InputReceiptEvent( + ctx context.Context, + request *api.InputReceiptEventRequest, + response *api.InputReceiptEventResponse, +) error { + logrus.WithFields(logrus.Fields{}).Infof("Producing to topic '%s'", t.OutputReceiptEventTopic) + output := &api.OutputReceiptEvent{ + UserID: request.InputReceiptEvent.UserID, + RoomID: request.InputReceiptEvent.RoomID, + EventID: request.InputReceiptEvent.EventID, + Type: request.InputReceiptEvent.Type, + Timestamp: request.InputReceiptEvent.Timestamp, + } + js, err := json.Marshal(output) + if err != nil { + return err + } + m := &sarama.ProducerMessage{ + Topic: t.OutputReceiptEventTopic, + Key: sarama.StringEncoder(request.InputReceiptEvent.RoomID + ":" + request.InputReceiptEvent.UserID), + Value: sarama.ByteEncoder(js), + } + _, _, err = t.Producer.SendMessage(m) + return err +} diff --git a/eduserver/inthttp/client.go b/eduserver/inthttp/client.go index 7d0bc1603..0690ed827 100644 --- a/eduserver/inthttp/client.go +++ b/eduserver/inthttp/client.go @@ -14,6 +14,7 @@ import ( const ( EDUServerInputTypingEventPath = "/eduserver/input" EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice" + EDUServerInputReceiptEventPath = "/eduserver/receipt" ) // NewEDUServerClient creates a EDUServerInputAPI implemented by talking to a HTTP POST API. @@ -54,3 +55,16 @@ func (h *httpEDUServerInputAPI) InputSendToDeviceEvent( apiURL := h.eduServerURL + EDUServerInputSendToDeviceEventPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +// InputSendToDeviceEvent implements EDUServerInputAPI +func (h *httpEDUServerInputAPI) InputReceiptEvent( + ctx context.Context, + request *api.InputReceiptEventRequest, + response *api.InputReceiptEventResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputReceiptEventPath") + defer span.Finish() + + apiURL := h.eduServerURL + EDUServerInputReceiptEventPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/eduserver/inthttp/server.go b/eduserver/inthttp/server.go index e374513a3..a34943750 100644 --- a/eduserver/inthttp/server.go +++ b/eduserver/inthttp/server.go @@ -38,4 +38,17 @@ func AddRoutes(t api.EDUServerInputAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(EDUServerInputReceiptEventPath, + httputil.MakeInternalAPI("inputReceiptEvent", func(req *http.Request) util.JSONResponse { + var request api.InputReceiptEventRequest + var response api.InputReceiptEventResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := t.InputReceiptEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 76dc3a2ee..79fbcb3d4 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -322,12 +322,69 @@ func (t *txnReq) processEDUs(ctx context.Context) { } case gomatrixserverlib.MDeviceListUpdate: t.processDeviceListUpdate(ctx, e) + case gomatrixserverlib.MReceipt: + // https://matrix.org/docs/spec/server_server/r0.1.4#receipts + payload := map[string]eduserverAPI.FederationReceiptMRead{} + + if err := json.Unmarshal(e.Content, &payload); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal receipt event") + continue + } + + for roomID, receipt := range payload { + for userID, mread := range receipt.User { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to split domain from receipt event sender") + continue + } + if t.Origin != domain { + util.GetLogger(ctx).Warnf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + continue + } + if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": t.Origin, + "user_id": userID, + "room_id": roomID, + "events": mread.EventIDs, + }).Error("Failed to send receipt event to edu server") + continue + } + } + } default: util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") } } } +// processReceiptEvent sends receipt events to the edu server +func (t *txnReq) processReceiptEvent(ctx context.Context, + userID, roomID, receiptType string, + timestamp gomatrixserverlib.Timestamp, + eventIDs []string, +) error { + // store every event + for _, eventID := range eventIDs { + req := eduserverAPI.InputReceiptEventRequest{ + InputReceiptEvent: eduserverAPI.InputReceiptEvent{ + UserID: userID, + RoomID: roomID, + EventID: eventID, + Type: receiptType, + Timestamp: timestamp, + }, + } + resp := eduserverAPI.InputReceiptEventResponse{} + if err := t.eduAPI.InputReceiptEvent(ctx, &req, &resp); err != nil { + return fmt.Errorf("unable to set receipt event: %w", err) + } + } + + return nil +} + func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverlib.EDU) { var payload gomatrixserverlib.DeviceListUpdateEvent if err := json.Unmarshal(e.Content, &payload); err != nil { diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 0a462433c..9398fef70 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -76,6 +76,14 @@ func (p *testEDUProducer) InputSendToDeviceEvent( return nil } +func (o *testEDUProducer) InputReceiptEvent( + ctx context.Context, + request *eduAPI.InputReceiptEventRequest, + response *eduAPI.InputReceiptEventResponse, +) error { + return nil +} + type testRoomserverAPI struct { inputRoomEvents []api.InputRoomEvent queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse @@ -84,6 +92,10 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse } +func (t *testRoomserverAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, resp *api.PerformForgetResponse) error { + return nil +} + func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} func (t *testRoomserverAPI) InputRoomEvents( diff --git a/federationsender/consumers/eduserver.go b/federationsender/consumers/eduserver.go index d9ac41b3b..9d7574e68 100644 --- a/federationsender/consumers/eduserver.go +++ b/federationsender/consumers/eduserver.go @@ -34,6 +34,7 @@ import ( type OutputEDUConsumer struct { typingConsumer *internal.ContinualConsumer sendToDeviceConsumer *internal.ContinualConsumer + receiptConsumer *internal.ContinualConsumer db storage.Database queues *queue.OutgoingQueues ServerName gomatrixserverlib.ServerName @@ -51,24 +52,31 @@ func NewOutputEDUConsumer( c := &OutputEDUConsumer{ typingConsumer: &internal.ContinualConsumer{ ComponentName: "eduserver/typing", - Topic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent)), + Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent), Consumer: kafkaConsumer, PartitionStore: store, }, sendToDeviceConsumer: &internal.ContinualConsumer{ ComponentName: "eduserver/sendtodevice", - Topic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent)), + Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + }, + receiptConsumer: &internal.ContinualConsumer{ + ComponentName: "eduserver/receipt", + Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputReceiptEvent), Consumer: kafkaConsumer, PartitionStore: store, }, queues: queues, db: store, ServerName: cfg.Matrix.ServerName, - TypingTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent)), - SendToDeviceTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent)), + TypingTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent), + SendToDeviceTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent), } c.typingConsumer.ProcessMessage = c.onTypingEvent c.sendToDeviceConsumer.ProcessMessage = c.onSendToDeviceEvent + c.receiptConsumer.ProcessMessage = c.onReceiptEvent return c } @@ -81,6 +89,9 @@ func (t *OutputEDUConsumer) Start() error { if err := t.sendToDeviceConsumer.Start(); err != nil { return fmt.Errorf("t.sendToDeviceConsumer.Start: %w", err) } + if err := t.receiptConsumer.Start(); err != nil { + return fmt.Errorf("t.receiptConsumer.Start: %w", err) + } return nil } @@ -177,3 +188,58 @@ func (t *OutputEDUConsumer) onTypingEvent(msg *sarama.ConsumerMessage) error { return t.queues.SendEDU(edu, t.ServerName, names) } + +// onReceiptEvent is called in response to a message received on the receipt +// events topic from the EDU server. +func (t *OutputEDUConsumer) onReceiptEvent(msg *sarama.ConsumerMessage) error { + // Extract the typing event from msg. + var receipt api.OutputReceiptEvent + if err := json.Unmarshal(msg.Value, &receipt); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") + return nil + } + + // only send receipt events which originated from us + _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) + if err != nil { + log.WithError(err).WithField("user_id", receipt.UserID).Error("Failed to extract domain from receipt sender") + return nil + } + if receiptServerName != t.ServerName { + log.WithField("other_server", receiptServerName).Info("Suppressing receipt notif: originated elsewhere") + return nil + } + + joined, err := t.db.GetJoinedHosts(context.TODO(), receipt.RoomID) + if err != nil { + return err + } + + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } + + content := map[string]api.FederationReceiptMRead{} + content[receipt.RoomID] = api.FederationReceiptMRead{ + User: map[string]api.FederationReceiptData{ + receipt.UserID: { + Data: api.ReceiptTS{ + TS: receipt.Timestamp, + }, + EventIDs: []string{receipt.EventID}, + }, + }, + } + + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MReceipt, + Origin: string(t.ServerName), + } + if edu.Content, err = json.Marshal(content); err != nil { + return err + } + + return t.queues.SendEDU(edu, t.ServerName, names) +} diff --git a/go.mod b/go.mod index f785dd391..d2f0ae260 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/pressly/goose v2.7.0-rc5+incompatible github.com/prometheus/client_golang v1.7.1 github.com/sirupsen/logrus v1.6.0 - github.com/tidwall/gjson v1.6.1 + github.com/tidwall/gjson v1.6.3 github.com/tidwall/sjson v1.1.1 github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-lib v2.2.0+incompatible diff --git a/go.sum b/go.sum index 7c24516d2..0300e368e 100644 --- a/go.sum +++ b/go.sum @@ -812,8 +812,8 @@ github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpP github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc= github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= -github.com/tidwall/gjson v1.6.1 h1:LRbvNuNuvAiISWg6gxLEFuCe72UKy5hDqhxW/8183ws= -github.com/tidwall/gjson v1.6.1/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0= +github.com/tidwall/gjson v1.6.3 h1:aHoiiem0dr7GHkW001T1SMTJ7X5PvyekH5WX0whWGnI= +github.com/tidwall/gjson v1.6.3/go.mod h1:BaHyNc5bjzYkPqgLq7mdVzeiRtULKULXLgZFKsxEHI0= github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc= github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index 7cb312c95..cac595494 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -1,6 +1,8 @@ package caching import ( + "strconv" + "github.com/matrix-org/dendrite/roomserver/types" ) @@ -83,11 +85,11 @@ func (c Caches) GetRoomServerRoomNID(roomID string) (types.RoomNID, bool) { func (c Caches) StoreRoomServerRoomNID(roomID string, roomNID types.RoomNID) { c.RoomServerRoomNIDs.Set(roomID, roomNID) - c.RoomServerRoomIDs.Set(string(roomNID), roomID) + c.RoomServerRoomIDs.Set(strconv.Itoa(int(roomNID)), roomID) } func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { - val, found := c.RoomServerRoomIDs.Get(string(roomNID)) + val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID))) if found && val != nil { if roomID, ok := val.(string); ok { return roomID, true diff --git a/internal/config/config_kafka.go b/internal/config/config_kafka.go index 707c92a71..aa91e5589 100644 --- a/internal/config/config_kafka.go +++ b/internal/config/config_kafka.go @@ -9,6 +9,7 @@ const ( TopicOutputKeyChangeEvent = "OutputKeyChangeEvent" TopicOutputRoomEvent = "OutputRoomEvent" TopicOutputClientData = "OutputClientData" + TopicOutputReceiptEvent = "OutputReceiptEvent" ) type Kafka struct { diff --git a/internal/setup/flags.go b/internal/setup/flags.go index e4fc58d60..c6ecb5cd1 100644 --- a/internal/setup/flags.go +++ b/internal/setup/flags.go @@ -16,18 +16,28 @@ package setup import ( "flag" + "fmt" + "os" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" - "github.com/sirupsen/logrus" ) -var configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.") +var ( + configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.") + version = flag.Bool("version", false, "Shows the current version and exits immediately.") +) // ParseFlags parses the commandline flags and uses them to create a config. func ParseFlags(monolith bool) *config.Dendrite { flag.Parse() + if *version { + fmt.Println(internal.VersionString()) + os.Exit(0) + } + if *configPath == "" { logrus.Fatal("--config must be supplied") } diff --git a/internal/transactions/transactions_test.go b/internal/transactions/transactions_test.go index f565e4846..aa837f76c 100644 --- a/internal/transactions/transactions_test.go +++ b/internal/transactions/transactions_test.go @@ -14,6 +14,7 @@ package transactions import ( "net/http" + "strconv" "testing" "github.com/matrix-org/util" @@ -44,8 +45,8 @@ func TestCache(t *testing.T) { for i := 1; i <= 100; i++ { fakeTxnCache.AddTransaction( fakeAccessToken, - fakeTxnID+string(i), - &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: string(i)}}, + fakeTxnID+strconv.Itoa(i), + &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}}, ) } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 043f72221..2683918a8 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -147,6 +147,9 @@ type RoomserverInternalAPI interface { response *PerformBackfillResponse, ) error + // PerformForget forgets a rooms history for a specific user + PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error + // Asks for the default room version as preferred by the server. QueryRoomVersionCapabilities( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index f4eaddc1e..e625fb04a 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -194,6 +194,16 @@ func (t *RoomserverInternalAPITrace) PerformBackfill( return err } +func (t *RoomserverInternalAPITrace) PerformForget( + ctx context.Context, + req *PerformForgetRequest, + res *PerformForgetResponse, +) error { + err := t.Impl.PerformForget(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformForget req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) QueryRoomVersionCapabilities( ctx context.Context, req *QueryRoomVersionCapabilitiesRequest, diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 0c2d96a7d..eda53c3e4 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -159,3 +159,11 @@ type PerformPublishResponse struct { // If non-nil, the publish request failed. Contains more information why it failed. Error *PerformError } + +// PerformForgetRequest is a request to PerformForget +type PerformForgetRequest struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` +} + +type PerformForgetResponse struct{} diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 3afca7e81..bdfbf6fbc 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -140,7 +140,9 @@ type QueryMembershipForUserResponse struct { // True if the user is in room. IsInRoom bool `json:"is_in_room"` // The current membership - Membership string + Membership string `json:"membership"` + // True if the user asked to forget this room. + IsRoomForgotten bool `json:"is_room_forgotten"` } // QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom @@ -160,6 +162,8 @@ type QueryMembershipsForRoomResponse struct { // True if the user has been in room before and has either stayed in it or // left it. HasBeenInRoom bool `json:"has_been_in_room"` + // True if the user asked to forget this room. + IsRoomForgotten bool `json:"is_room_forgotten"` } // QueryServerJoinedToRoomRequest is a request to QueryServerJoinedToRoom diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index ee4e4ec96..443cc6b38 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -26,6 +26,7 @@ type RoomserverInternalAPI struct { *perform.Leaver *perform.Publisher *perform.Backfiller + *perform.Forgetter DB storage.Database Cfg *config.RoomServer Producer sarama.SyncProducer @@ -112,6 +113,9 @@ func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSen // than trying random servers PreferServers: r.PerspectiveServerNames, } + r.Forgetter = &perform.Forgetter{ + DB: r.DB, + } } func (r *RoomserverInternalAPI) PerformInvite( @@ -143,3 +147,11 @@ func (r *RoomserverInternalAPI) PerformLeave( } return r.WriteOutputEvents(req.RoomID, outputEvents) } + +func (r *RoomserverInternalAPI) PerformForget( + ctx context.Context, + req *api.PerformForgetRequest, + resp *api.PerformForgetResponse, +) error { + return r.Forgetter.PerformForget(ctx, req, resp) +} diff --git a/roomserver/internal/perform/perform_forget.go b/roomserver/internal/perform/perform_forget.go new file mode 100644 index 000000000..e970d9a88 --- /dev/null +++ b/roomserver/internal/perform/perform_forget.go @@ -0,0 +1,35 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform + +import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" +) + +type Forgetter struct { + DB storage.Database +} + +// PerformForget implements api.RoomServerQueryAPI +func (f *Forgetter) PerformForget( + ctx context.Context, + request *api.PerformForgetRequest, + response *api.PerformForgetResponse, +) error { + return f.DB.ForgetRoom(ctx, request.UserID, request.RoomID, true) +} diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 734e73d43..0630ed455 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -86,7 +86,7 @@ func (r *Inviter) PerformInvite( var isAlreadyJoined bool if info != nil { - _, isAlreadyJoined, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) + _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) if err != nil { return nil, fmt.Errorf("r.DB.GetMembership: %w", err) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index ecfb580f2..64ece4eb8 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -204,11 +204,13 @@ func (r *Queryer) QueryMembershipForUser( return fmt.Errorf("QueryMembershipForUser: unknown room %s", request.RoomID) } - membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) if err != nil { return err } + response.IsRoomForgotten = isRoomforgotten + if membershipEventNID == 0 { response.HasBeenInRoom = false return nil @@ -241,11 +243,13 @@ func (r *Queryer) QueryMembershipsForRoom( return err } - membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) + membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) if err != nil { return err } + response.IsRoomForgotten = isRoomforgotten + if membershipEventNID == 0 { response.HasBeenInRoom = false response.JoinEvents = nil diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 24a82adf8..f5b66ca6a 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -31,6 +31,7 @@ const ( RoomserverPerformLeavePath = "/roomserver/performLeave" RoomserverPerformBackfillPath = "/roomserver/performBackfill" RoomserverPerformPublishPath = "/roomserver/performPublish" + RoomserverPerformForgetPath = "/roomserver/performForget" // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" @@ -492,3 +493,12 @@ func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } + +func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformForgetPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 9c9d4d4ae..2bc8f82df 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -251,6 +251,20 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle( + RoomserverPerformForgetPath, + httputil.MakeInternalAPI("PerformForget", func(req *http.Request) util.JSONResponse { + var request api.PerformForgetRequest + var response api.PerformForgetResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.PerformForget(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle( RoomserverQueryRoomVersionCapabilitiesPath, httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 10a380e85..c6f5c8082 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -126,7 +126,7 @@ type Database interface { // in this room, along a boolean set to true if the user is still in this room, // false if not. // Returns an error if there was a problem talking to the database. - GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) + GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomForgotten bool, err error) // Lookup the membership event numeric IDs for all user that are or have // been members of a given room. Only lookup events of "join" membership if // joinOnly is set to true. @@ -158,4 +158,6 @@ type Database interface { GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) + // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room + ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error } diff --git a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go new file mode 100644 index 000000000..733f0fa14 --- /dev/null +++ b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go @@ -0,0 +1,47 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func LoadAddForgottenColumn(m *sqlutil.Migrations) { + m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func UpAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE roomserver_membership ADD COLUMN IF NOT EXISTS forgotten BOOLEAN NOT NULL DEFAULT false;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE roomserver_membership DROP COLUMN IF EXISTS forgotten;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 5164f654f..e392a4fbb 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -60,13 +60,15 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( -- a federated one. This is an optimisation for resetting state on federated -- room joins. target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT FALSE, UNIQUE (room_nid, target_nid) ); ` var selectJoinedUsersSetForRoomsSQL = "" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + - " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + " GROUP BY target_nid" // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE @@ -76,37 +78,41 @@ const insertMembershipSQL = "" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + - "SELECT membership_nid, event_nid FROM roomserver_membership" + + "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1 AND membership_nid = $2" + " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" const selectLocalMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" + - " AND target_local = true" + " AND target_local = true and forgotten = false" const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1" + " WHERE room_nid = $1 and forgotten = false" const selectLocalMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" + - " AND target_local = true" + " AND target_local = true and forgotten = false" const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" const updateMembershipSQL = "" + - "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5, forgotten = $6" + + " WHERE room_nid = $1 AND target_nid = $2" + +const updateMembershipForgetRoom = "" + + "UPDATE roomserver_membership SET forgotten = $3" + " WHERE room_nid = $1 AND target_nid = $2" const selectRoomsWithMembershipSQL = "" + - "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // joined to. Since this information is used to populate the user directory, we will @@ -130,6 +136,7 @@ type membershipStatements struct { selectRoomsWithMembershipStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt + updateMembershipForgetRoomStmt *sql.Stmt } func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -151,9 +158,15 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, + {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, }.Prepare(db) } +func (s *membershipStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(membershipSchema) + return err +} + func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, @@ -177,10 +190,10 @@ func (s *membershipStatements) SelectMembershipForUpdate( func (s *membershipStatements) SelectMembershipFromRoomAndTarget( ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership tables.MembershipState, err error) { +) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, - ).Scan(&membership, &eventNID) + ).Scan(&membership, &eventNID, &forgotten) return } @@ -238,12 +251,11 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) UpdateMembership( ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership tables.MembershipState, - eventNID types.EventNID, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, + eventNID types.EventNID, forgotten bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( - ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, + ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, forgotten, ) return err } @@ -305,3 +317,14 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } return result, rows.Err() } + +func (s *membershipStatements) UpdateForgetMembership( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + forget bool, +) error { + _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( + ctx, roomNID, targetUserNID, forget, + ) + return err +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 02ff072d7..37aca647c 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -18,12 +18,13 @@ package postgres import ( "database/sql" + // Import the postgres database driver. + _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" - - // Import the postgres database driver. - _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/shared" ) @@ -33,7 +34,6 @@ type Database struct { } // Open a postgres database. -// nolint: gocyclo func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database var db *sql.DB @@ -41,61 +41,82 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) if db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } + + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + ms := membershipStatements{} + if err := ms.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() + deltas.LoadAddForgottenColumn(m) + if err := m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + if err := d.prepare(db, cache); err != nil { + return nil, err + } + + return &d, nil +} + +// nolint: gocyclo +func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err error) { eventStateKeys, err := NewPostgresEventStateKeysTable(db) if err != nil { - return nil, err + return err } eventTypes, err := NewPostgresEventTypesTable(db) if err != nil { - return nil, err + return err } eventJSON, err := NewPostgresEventJSONTable(db) if err != nil { - return nil, err + return err } events, err := NewPostgresEventsTable(db) if err != nil { - return nil, err + return err } rooms, err := NewPostgresRoomsTable(db) if err != nil { - return nil, err + return err } transactions, err := NewPostgresTransactionsTable(db) if err != nil { - return nil, err + return err } stateBlock, err := NewPostgresStateBlockTable(db) if err != nil { - return nil, err + return err } stateSnapshot, err := NewPostgresStateSnapshotTable(db) if err != nil { - return nil, err + return err } roomAliases, err := NewPostgresRoomAliasesTable(db) if err != nil { - return nil, err + return err } prevEvents, err := NewPostgresPreviousEventsTable(db) if err != nil { - return nil, err + return err } invites, err := NewPostgresInvitesTable(db) if err != nil { - return nil, err + return err } membership, err := NewPostgresMembershipTable(db) if err != nil { - return nil, err + return err } published, err := NewPostgresPublishedTable(db) if err != nil { - return nil, err + return err } redactions, err := NewPostgresRedactionsTable(db) if err != nil { - return nil, err + return err } d.Database = shared.Database{ DB: db, @@ -116,5 +137,5 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) PublishedTable: published, RedactionsTable: redactions, } - return &d, nil + return nil } diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 7abddd018..57f3a520a 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -101,9 +101,7 @@ func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) } if u.membership != tables.MembershipStateInvite { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, - ); err != nil { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -139,10 +137,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } if u.membership != tables.MembershipStateJoin || isUpdate { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - tables.MembershipStateJoin, nIDs[eventID], - ); err != nil { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -176,10 +171,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } if u.membership != tables.MembershipStateLeaveOrBan { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - tables.MembershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { + if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index aec15ab22..5361bd213 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -258,30 +258,28 @@ func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { }) } -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { +func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom, isRoomforgotten bool, err error) { var requestSenderUserNID types.EventStateKeyNID err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { requestSenderUserNID, err = d.assignStateKeyNID(ctx, txn, requestSenderUserID) return err }) if err != nil { - return 0, false, fmt.Errorf("d.assignStateKeyNID: %w", err) + return 0, false, false, fmt.Errorf("d.assignStateKeyNID: %w", err) } - senderMembershipEventNID, senderMembership, err := + senderMembershipEventNID, senderMembership, isRoomforgotten, err := d.MembershipTable.SelectMembershipFromRoomAndTarget( ctx, roomNID, requestSenderUserNID, ) if err == sql.ErrNoRows { // The user has never been a member of that room - return 0, false, nil + return 0, false, false, nil } else if err != nil { return } - return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, nil + return senderMembershipEventNID, senderMembership == tables.MembershipStateJoin, isRoomforgotten, nil } func (d *Database) GetMembershipEventNIDsForRoom( @@ -992,6 +990,25 @@ func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDs(ctx) } +// ForgetRoom sets a users room to forgotten +func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID}) + if err != nil { + return err + } + if len(roomNIDs) > 1 { + return fmt.Errorf("expected one room, got %d", len(roomNIDs)) + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + if err != nil { + return err + } + + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.MembershipTable.UpdateForgetMembership(ctx, nil, roomNIDs[0], stateKeyNID, forget) + }) +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go new file mode 100644 index 000000000..33fe9e2a9 --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go @@ -0,0 +1,82 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func LoadAddForgottenColumn(m *sqlutil.Migrations) { + m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) +} + +func UpAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; +CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT false, + UNIQUE (room_nid, target_nid) + ); +INSERT + INTO roomserver_membership ( + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + ) SELECT + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + FROM roomserver_membership_tmp +; +DROP TABLE roomserver_membership_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddForgottenColumn(tx *sql.Tx) error { + _, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; +CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + target_local BOOLEAN NOT NULL DEFAULT false, + UNIQUE (room_nid, target_nid) + ); +INSERT + INTO roomserver_membership ( + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + ) SELECT + room_nid, target_nid, sender_nid, membership_nid, event_nid, target_local + FROM roomserver_membership_tmp +; +DROP TABLE roomserver_membership_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index bb1ab39aa..d716ced04 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -36,13 +36,15 @@ const membershipSchema = ` membership_nid INTEGER NOT NULL DEFAULT 1, event_nid INTEGER NOT NULL DEFAULT 0, target_local BOOLEAN NOT NULL DEFAULT false, + forgotten BOOLEAN NOT NULL DEFAULT false, UNIQUE (room_nid, target_nid) ); ` var selectJoinedUsersSetForRoomsSQL = "" + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + - " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" + + " GROUP BY target_nid" // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE @@ -52,37 +54,41 @@ const insertMembershipSQL = "" + " ON CONFLICT DO NOTHING" const selectMembershipFromRoomAndTargetSQL = "" + - "SELECT membership_nid, event_nid FROM roomserver_membership" + + "SELECT membership_nid, event_nid, forgotten FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" const selectMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1 AND membership_nid = $2" + " WHERE room_nid = $1 AND membership_nid = $2 and forgotten = false" const selectLocalMembershipsFromRoomAndMembershipSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND membership_nid = $2" + - " AND target_local = true" + " AND target_local = true and forgotten = false" const selectMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + - " WHERE room_nid = $1" + " WHERE room_nid = $1 and forgotten = false" const selectLocalMembershipsFromRoomSQL = "" + "SELECT event_nid FROM roomserver_membership" + " WHERE room_nid = $1" + - " AND target_local = true" + " AND target_local = true and forgotten = false" const selectMembershipForUpdateSQL = "" + "SELECT membership_nid FROM roomserver_membership" + " WHERE room_nid = $1 AND target_nid = $2" const updateMembershipSQL = "" + - "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" + - " WHERE room_nid = $4 AND target_nid = $5" + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3, forgotten = $4" + + " WHERE room_nid = $5 AND target_nid = $6" + +const updateMembershipForgetRoom = "" + + "UPDATE roomserver_membership SET forgotten = $1" + + " WHERE room_nid = $2 AND target_nid = $3" const selectRoomsWithMembershipSQL = "" + - "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // joined to. Since this information is used to populate the user directory, we will @@ -106,16 +112,13 @@ type membershipStatements struct { selectRoomsWithMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt + updateMembershipForgetRoomStmt *sql.Stmt } func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { s := &membershipStatements{ db: db, } - _, err := db.Exec(membershipSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, @@ -128,9 +131,15 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, {&s.selectKnownUsersStmt, selectKnownUsersSQL}, + {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, }.Prepare(db) } +func (s *membershipStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(membershipSchema) + return err +} + func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, @@ -155,10 +164,10 @@ func (s *membershipStatements) SelectMembershipForUpdate( func (s *membershipStatements) SelectMembershipFromRoomAndTarget( ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, -) (eventNID types.EventNID, membership tables.MembershipState, err error) { +) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( ctx, roomNID, targetUserNID, - ).Scan(&membership, &eventNID) + ).Scan(&membership, &eventNID, &forgotten) return } @@ -216,13 +225,12 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, - roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - senderUserNID types.EventStateKeyNID, membership tables.MembershipState, - eventNID types.EventNID, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, + eventNID types.EventNID, forgotten bool, ) error { stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) _, err := stmt.ExecContext( - ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, + ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID, ) return err } @@ -285,3 +293,14 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } return result, rows.Err() } + +func (s *membershipStatements) UpdateForgetMembership( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + forget bool, +) error { + _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( + ctx, forget, roomNID, targetUserNID, + ) + return err +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 6d9b860f5..b36930206 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -19,127 +19,138 @@ import ( "context" "database/sql" + _ "github.com/mattn/go-sqlite3" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" - "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" - _ "github.com/mattn/go-sqlite3" ) // A Database is used to store room events and stream offsets. type Database struct { shared.Database - events tables.Events - eventJSON tables.EventJSON - eventTypes tables.EventTypes - eventStateKeys tables.EventStateKeys - rooms tables.Rooms - transactions tables.Transactions - prevEvents tables.PreviousEvents - invites tables.Invites - membership tables.Membership - db *sql.DB - writer sqlutil.Writer } // Open a sqlite database. -// nolint: gocyclo func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database + var db *sql.DB var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { + if db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - d.writer = sqlutil.NewExclusiveWriter() - //d.db.Exec("PRAGMA journal_mode=WAL;") - //d.db.Exec("PRAGMA read_uncommitted = true;") + + //db.Exec("PRAGMA journal_mode=WAL;") + //db.Exec("PRAGMA read_uncommitted = true;") // FIXME: We are leaking connections somewhere. Setting this to 2 will eventually // cause the roomserver to be unresponsive to new events because something will // acquire the global mutex and never unlock it because it is waiting for a connection // which it will never obtain. - d.db.SetMaxOpenConns(20) + db.SetMaxOpenConns(20) - d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) - if err != nil { + // Create tables before executing migrations so we don't fail if the table is missing, + // and THEN prepare statements so we don't fail due to referencing new columns + ms := membershipStatements{} + if err := ms.execSchema(db); err != nil { return nil, err } - d.eventTypes, err = NewSqliteEventTypesTable(d.db) - if err != nil { + m := sqlutil.NewMigrations() + deltas.LoadAddForgottenColumn(m) + if err := m.RunDeltas(db, dbProperties); err != nil { return nil, err } - d.eventJSON, err = NewSqliteEventJSONTable(d.db) - if err != nil { + if err := d.prepare(db, cache); err != nil { return nil, err } - d.events, err = NewSqliteEventsTable(d.db) + + return &d, nil +} + +// nolint: gocyclo +func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { + var err error + eventStateKeys, err := NewSqliteEventStateKeysTable(db) if err != nil { - return nil, err + return err } - d.rooms, err = NewSqliteRoomsTable(d.db) + eventTypes, err := NewSqliteEventTypesTable(db) if err != nil { - return nil, err + return err } - d.transactions, err = NewSqliteTransactionsTable(d.db) + eventJSON, err := NewSqliteEventJSONTable(db) if err != nil { - return nil, err + return err } - stateBlock, err := NewSqliteStateBlockTable(d.db) + events, err := NewSqliteEventsTable(db) if err != nil { - return nil, err + return err } - stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) + rooms, err := NewSqliteRoomsTable(db) if err != nil { - return nil, err + return err } - d.prevEvents, err = NewSqlitePrevEventsTable(d.db) + transactions, err := NewSqliteTransactionsTable(db) if err != nil { - return nil, err + return err } - roomAliases, err := NewSqliteRoomAliasesTable(d.db) + stateBlock, err := NewSqliteStateBlockTable(db) if err != nil { - return nil, err + return err } - d.invites, err = NewSqliteInvitesTable(d.db) + stateSnapshot, err := NewSqliteStateSnapshotTable(db) if err != nil { - return nil, err + return err } - d.membership, err = NewSqliteMembershipTable(d.db) + prevEvents, err := NewSqlitePrevEventsTable(db) if err != nil { - return nil, err + return err } - published, err := NewSqlitePublishedTable(d.db) + roomAliases, err := NewSqliteRoomAliasesTable(db) if err != nil { - return nil, err + return err } - redactions, err := NewSqliteRedactionsTable(d.db) + invites, err := NewSqliteInvitesTable(db) if err != nil { - return nil, err + return err + } + membership, err := NewSqliteMembershipTable(db) + if err != nil { + return err + } + published, err := NewSqlitePublishedTable(db) + if err != nil { + return err + } + redactions, err := NewSqliteRedactionsTable(db) + if err != nil { + return err } d.Database = shared.Database{ - DB: d.db, + DB: db, Cache: cache, - Writer: d.writer, - EventsTable: d.events, - EventTypesTable: d.eventTypes, - EventStateKeysTable: d.eventStateKeys, - EventJSONTable: d.eventJSON, - RoomsTable: d.rooms, - TransactionsTable: d.transactions, + Writer: sqlutil.NewExclusiveWriter(), + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + RoomsTable: rooms, + TransactionsTable: transactions, StateBlockTable: stateBlock, StateSnapshotTable: stateSnapshot, - PrevEventsTable: d.prevEvents, + PrevEventsTable: prevEvents, RoomAliasesTable: roomAliases, - InvitesTable: d.invites, - MembershipTable: d.membership, + InvitesTable: invites, + MembershipTable: membership, PublishedTable: published, RedactionsTable: redactions, GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, } - return &d, nil + return nil } func (d *Database) SupportsConcurrentRoomInputs() bool { diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index eba878ba5..d73445846 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -123,15 +123,16 @@ const ( type Membership interface { InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) - SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, error) + SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) - UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error + UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) + UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error } type Published interface { diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go new file mode 100644 index 000000000..c5d17414a --- /dev/null +++ b/syncapi/consumers/eduserver_receipts.go @@ -0,0 +1,94 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/syncapi/types" + + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/sync" + log "github.com/sirupsen/logrus" +) + +// OutputReceiptEventConsumer consumes events that originated in the EDU server. +type OutputReceiptEventConsumer struct { + receiptConsumer *internal.ContinualConsumer + db storage.Database + notifier *sync.Notifier +} + +// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. +// Call Start() to begin consuming from the EDU server. +func NewOutputReceiptEventConsumer( + cfg *config.SyncAPI, + kafkaConsumer sarama.Consumer, + n *sync.Notifier, + store storage.Database, +) *OutputReceiptEventConsumer { + + consumer := internal.ContinualConsumer{ + ComponentName: "syncapi/eduserver/receipt", + Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputReceiptEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + } + + s := &OutputReceiptEventConsumer{ + receiptConsumer: &consumer, + db: store, + notifier: n, + } + + consumer.ProcessMessage = s.onMessage + + return s +} + +// Start consuming from EDU api +func (s *OutputReceiptEventConsumer) Start() error { + return s.receiptConsumer.Start() +} + +func (s *OutputReceiptEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { + var output api.OutputReceiptEvent + if err := json.Unmarshal(msg.Value, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + return nil + } + + streamPos, err := s.db.StoreReceipt( + context.TODO(), + output.RoomID, + output.Type, + output.UserID, + output.EventID, + output.Timestamp, + ) + if err != nil { + return err + } + // update stream position + s.notifier.OnNewReceipt(types.NewStreamToken(0, streamPos, nil)) + + return nil +} diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index e5299f200..2f79ed5cc 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -59,6 +59,7 @@ const defaultMessagesLimit = 10 // OnIncomingMessagesRequest implements the /messages endpoint from the // client-server API. // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages +// nolint:gocyclo func OnIncomingMessagesRequest( req *http.Request, db storage.Database, roomID string, device *userapi.Device, federation *gomatrixserverlib.FederationClient, @@ -67,6 +68,19 @@ func OnIncomingMessagesRequest( ) util.JSONResponse { var err error + // check if the user has already forgotten about this room + isForgotten, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI) + if err != nil { + return jsonerror.InternalServerError() + } + + if isForgotten { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("user already forgot about this room"), + } + } + // Extract parameters from the request's URL. // Pagination tokens. var fromStream *types.StreamingToken @@ -182,6 +196,19 @@ func OnIncomingMessagesRequest( } } +func checkIsRoomForgotten(ctx context.Context, roomID, userID string, rsAPI api.RoomserverInternalAPI) (bool, error) { + req := api.QueryMembershipForUserRequest{ + RoomID: roomID, + UserID: userID, + } + resp := api.QueryMembershipForUserResponse{} + if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil { + return false, err + } + + return resp.IsRoomForgotten, nil +} + // retrieveEvents retrieves events from the local database for a request on // /messages. If there's not enough events to retrieve, it asks another // homeserver in the room for older events. diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index e12a1166e..727cc0484 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -18,6 +18,8 @@ import ( "context" "time" + eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" @@ -147,4 +149,8 @@ type Database interface { PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) // RedactEvent wipes an event in the database and sets the unsigned.redacted_because key to the redaction event RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error + // StoreReceipt stores new receipt events + StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) + // GetRoomReceipts gets all receipts for a given roomID + GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) } diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go new file mode 100644 index 000000000..c5ec6cbc6 --- /dev/null +++ b/syncapi/storage/postgres/receipt_table.go @@ -0,0 +1,106 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const receiptsSchema = ` +CREATE SEQUENCE IF NOT EXISTS syncapi_stream_id; +-- Stores data about receipts +CREATE TABLE IF NOT EXISTS syncapi_receipts ( + -- The ID + id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + receipt_ts BIGINT NOT NULL, + CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) +); +CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id ON syncapi_receipts(room_id); +` + +const upsertReceipt = "" + + "INSERT INTO syncapi_receipts" + + " (room_id, receipt_type, user_id, event_id, receipt_ts)" + + " VALUES ($1, $2, $3, $4, $5)" + + " ON CONFLICT (room_id, receipt_type, user_id)" + + " DO UPDATE SET id = nextval('syncapi_stream_id'), event_id = $4, receipt_ts = $5" + + " RETURNING id" + +const selectRoomReceipts = "" + + "SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + + " FROM syncapi_receipts" + + " WHERE room_id = ANY($1) AND id > $2" + +type receiptStatements struct { + db *sql.DB + upsertReceipt *sql.Stmt + selectRoomReceipts *sql.Stmt +} + +func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { + _, err := db.Exec(receiptsSchema) + if err != nil { + return nil, err + } + r := &receiptStatements{ + db: db, + } + if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { + return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) + } + if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { + return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) + } + return r, nil +} + +func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { + stmt := sqlutil.TxStmt(txn, r.upsertReceipt) + err = stmt.QueryRowContext(ctx, roomId, receiptType, userId, eventId, timestamp).Scan(&pos) + return +} + +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { + rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) + if err != nil { + return nil, fmt.Errorf("unable to query room receipts: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") + var res []api.OutputReceiptEvent + for rows.Next() { + r := api.OutputReceiptEvent{} + err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) + if err != nil { + return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) + } + res = append(res, r) + } + return res, rows.Err() +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 7f19722ae..979e19a0b 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -82,6 +82,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if err != nil { return nil, err } + receipts, err := NewPostgresReceiptsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, Writer: d.writer, @@ -94,6 +98,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e BackwardExtremities: backwardExtremities, Filter: filter, SendToDevice: sendToDevice, + Receipts: receipts, EDUCache: cache.New(), } return &d, nil diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index a7c07f943..2b82ee33c 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -21,6 +21,7 @@ import ( "fmt" "time" + eduAPI "github.com/matrix-org/dendrite/eduserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/eduserver/cache" @@ -47,6 +48,7 @@ type Database struct { BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice Filter tables.Filter + Receipts tables.Receipts EDUCache *cache.EDUCache } @@ -527,10 +529,10 @@ func (d *Database) addTypingDeltaToResponse( joinedRoomIDs []string, res *types.Response, ) error { - var jr types.JoinResponse var ok bool var err error for _, roomID := range joinedRoomIDs { + var jr types.JoinResponse if typingUsers, updated := d.EDUCache.GetTypingUsersIfUpdatedAfter( roomID, int64(since.EDUPosition()), ); updated { @@ -554,21 +556,84 @@ func (d *Database) addTypingDeltaToResponse( return nil } +// addReceiptDeltaToResponse adds all receipt information to a sync response +// since the specified position +func (d *Database) addReceiptDeltaToResponse( + since types.StreamingToken, + joinedRoomIDs []string, + res *types.Response, +) error { + receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.EDUPosition()) + if err != nil { + return fmt.Errorf("unable to select receipts for rooms: %w", err) + } + + // Group receipts by room, so we can create one ClientEvent for every room + receiptsByRoom := make(map[string][]eduAPI.OutputReceiptEvent) + for _, receipt := range receipts { + receiptsByRoom[receipt.RoomID] = append(receiptsByRoom[receipt.RoomID], receipt) + } + + for roomID, receipts := range receiptsByRoom { + var jr types.JoinResponse + var ok bool + + // Make sure we use an existing JoinResponse if there is one. + // If not, we'll create a new one + if jr, ok = res.Rooms.Join[roomID]; !ok { + jr = types.JoinResponse{} + } + + ev := gomatrixserverlib.ClientEvent{ + Type: gomatrixserverlib.MReceipt, + RoomID: roomID, + } + content := make(map[string]eduAPI.ReceiptMRead) + for _, receipt := range receipts { + var read eduAPI.ReceiptMRead + if read, ok = content[receipt.EventID]; !ok { + read = eduAPI.ReceiptMRead{ + User: make(map[string]eduAPI.ReceiptTS), + } + } + read.User[receipt.UserID] = eduAPI.ReceiptTS{TS: receipt.Timestamp} + content[receipt.EventID] = read + } + ev.Content, err = json.Marshal(content) + if err != nil { + return err + } + + jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) + res.Rooms.Join[roomID] = jr + } + + return nil +} + // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // the positions of that type are not equal in fromPos and toPos. func (d *Database) addEDUDeltaToResponse( fromPos, toPos types.StreamingToken, joinedRoomIDs []string, res *types.Response, -) (err error) { - +) error { if fromPos.EDUPosition() != toPos.EDUPosition() { - err = d.addTypingDeltaToResponse( - fromPos, joinedRoomIDs, res, - ) + // add typing deltas + if err := d.addTypingDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil { + return fmt.Errorf("unable to apply typing delta to response: %w", err) + } } - return + // Check on initial sync and if EDUPositions differ + if (fromPos.EDUPosition() == 0 && toPos.EDUPosition() == 0) || + fromPos.EDUPosition() != toPos.EDUPosition() { + if err := d.addReceiptDeltaToResponse(fromPos, joinedRoomIDs, res); err != nil { + return fmt.Errorf("unable to apply receipts to response: %w", err) + } + } + + return nil } func (d *Database) GetFilter( @@ -1404,3 +1469,16 @@ type stateDelta struct { // Can be 0 if there is no membership event in this delta. membershipPos types.StreamPosition } + +// StoreReceipt stores user receipts +func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp) + return err + }) + return +} + +func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) { + return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) +} diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go new file mode 100644 index 000000000..b1770e801 --- /dev/null +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const receiptsSchema = ` +-- Stores data about receipts +CREATE TABLE IF NOT EXISTS syncapi_receipts ( + -- The ID + id BIGINT, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + receipt_ts BIGINT NOT NULL, + CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) +); +CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id_idx ON syncapi_receipts(room_id); +` + +const upsertReceipt = "" + + "INSERT INTO syncapi_receipts" + + " (id, room_id, receipt_type, user_id, event_id, receipt_ts)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (room_id, receipt_type, user_id)" + + " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" + +const selectRoomReceipts = "" + + "SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + + " FROM syncapi_receipts" + + " WHERE id > $1 and room_id in ($2)" + +type receiptStatements struct { + db *sql.DB + streamIDStatements *streamIDStatements + upsertReceipt *sql.Stmt + selectRoomReceipts *sql.Stmt +} + +func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { + _, err := db.Exec(receiptsSchema) + if err != nil { + return nil, err + } + r := &receiptStatements{ + db: db, + streamIDStatements: streamID, + } + if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { + return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) + } + if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { + return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) + } + return r, nil +} + +// UpsertReceipt creates new user receipts +func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { + pos, err = r.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + stmt := sqlutil.TxStmt(txn, r.upsertReceipt) + _, err = stmt.ExecContext(ctx, pos, roomId, receiptType, userId, eventId, timestamp, pos, eventId, timestamp) + return +} + +// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { + selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) + + params := make([]interface{}, len(roomIDs)+1) + params[0] = streamPos + for k, v := range roomIDs { + params[k+1] = v + } + rows, err := r.db.QueryContext(ctx, selectSQL, params...) + if err != nil { + return nil, fmt.Errorf("unable to query room receipts: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") + var res []api.OutputReceiptEvent + for rows.Next() { + r := api.OutputReceiptEvent{} + err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) + if err != nil { + return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) + } + res = append(res, r) + } + return res, rows.Err() +} diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 86d83ec98..036e2b2e5 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -95,6 +95,10 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } + receipts, err := NewSqliteReceiptsTable(d.db, &d.streamID) + if err != nil { + return err + } d.Database = shared.Database{ DB: d.db, Writer: d.writer, @@ -107,6 +111,7 @@ func (d *SyncServerDatasource) prepare() (err error) { Topology: topology, Filter: filter, SendToDevice: sendToDevice, + Receipts: receipts, EDUCache: cache.New(), } return nil diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index da095be53..f8e7a224a 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" + eduAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -156,3 +157,8 @@ type Filter interface { SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) } + +type Receipts interface { + UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) + SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) +} diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index fcac3f16c..daa3a1d8c 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -149,6 +149,16 @@ func (n *Notifier) OnNewSendToDevice( n.wakeupUserDevice(userID, deviceIDs, latestPos) } +// OnNewReceipt updates the current position +func (n *Notifier) OnNewReceipt( + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + latestPos := n.currPos.WithUpdates(posUpdate) + n.currPos = latestPos +} + func (n *Notifier) OnNewKeyChange( posUpdate types.StreamingToken, wakeUserID, keyChangeUserID string, ) { diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index de0bb434b..393a7aa55 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -99,5 +99,12 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start send-to-device consumer") } + receiptConsumer := consumers.NewOutputReceiptEventConsumer( + cfg, consumer, notifier, syncDB, + ) + if err = receiptConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start receipts consumer") + } + routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg) } diff --git a/sytest-whitelist b/sytest-whitelist index 1a12b591b..49011b5a3 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -483,6 +483,16 @@ POST rejects invalid utf-8 in JSON Users cannot kick users who have already left a room Event with an invalid signature in the send_join response should not cause room join to fail Inbound federation rejects typing notifications from wrong remote +POST /rooms/:room_id/receipt can create receipts +Receipts must be m.read +Read receipts appear in initial v2 /sync +New read receipts appear in incremental v2 /sync +Outbound federation sends receipts +Inbound federation rejects receipts from wrong remote Should not be able to take over the room by pretending there is no PL event Can get rooms/{roomId}/state for a departed room (SPEC-216) Users cannot set notifications powerlevel higher than their own +Forgotten room messages cannot be paginated +Forgetting room does not show up in v2 /sync +Can forget room you've been kicked from +Can re-join room if re-invited