diff --git a/appservice/appservice.go b/appservice/appservice.go index 4ff42360b..0967797e5 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -58,7 +58,7 @@ func NewInternalAPI( }, }, } - consumer, _ := jetstream.SetupConsumerProducer(&base.Cfg.Global.JetStream) + _, consumer, _ := jetstream.SetupConsumerProducer(&base.Cfg.Global.JetStream) // Create a connection to the appservice postgres DB appserviceDB, err := storage.NewDatabase(&base.Cfg.AppServiceAPI.Database) diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index ffab1337d..01d5bd90b 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -49,7 +49,7 @@ func AddPublicRoutes( extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, ) { - _, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) + _, _, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) syncProducer := &producers.SyncAPIProducer{ Producer: producer, diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index 9c3f7ddf6..e57e8bd7a 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -42,7 +42,7 @@ func NewInternalAPI( ) api.EDUServerInputAPI { cfg := &base.Cfg.EDUServer - _, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) + _, _, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) return &input.EDUServerInputAPI{ Cache: eduCache, diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go index ee89d7822..dbc42346a 100644 --- a/federationsender/federationsender.go +++ b/federationsender/federationsender.go @@ -61,7 +61,7 @@ func NewInternalAPI( FailuresUntilBlacklist: cfg.FederationMaxRetries, } - consumer, _ := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) + _, consumer, _ := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationSenderDB, base.ProcessContext, diff --git a/go.mod b/go.mod index 37cedbfba..8fc91cbef 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0 require ( github.com/Arceliar/ironwood v0.0.0-20210619124114-6ad55cae5031 + github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/HdrHistogram/hdrhistogram-go v1.0.1 // indirect github.com/Masterminds/semver/v3 v3.1.1 diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index dc7f79019..1a9ff3f16 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -40,7 +40,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { func NewInternalAPI( base *setup.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, ) api.KeyInternalAPI { - consumer, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) + _, consumer, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) db, err := storage.NewDatabase(&cfg.Database) if err != nil { diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index f39b26eaf..63b783328 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -16,6 +16,8 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" ) // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI @@ -39,13 +41,15 @@ type RoomserverInternalAPI struct { KeyRing gomatrixserverlib.JSONVerifier fsAPI fsAPI.FederationSenderInternalAPI asAPI asAPI.AppServiceQueryAPI - OutputRoomEventTopic string // Kafka topic for new output room events + InputRoomEventTopic string // JetStream topic for new input room events + OutputRoomEventTopic string // JetStream topic for new output room events PerspectiveServerNames []gomatrixserverlib.ServerName } func NewRoomserverAPI( - cfg *config.RoomServer, roomserverDB storage.Database, producer sarama.SyncProducer, - outputRoomEventTopic string, caches caching.RoomServerCaches, + cfg *config.RoomServer, roomserverDB storage.Database, + consumer nats.JetStreamContext, producer sarama.SyncProducer, + inputRoomEventTopic, outputRoomEventTopic string, caches caching.RoomServerCaches, keyRing gomatrixserverlib.JSONVerifier, perspectiveServerNames []gomatrixserverlib.ServerName, ) *RoomserverInternalAPI { serverACLs := acls.NewServerACLs(roomserverDB) @@ -64,13 +68,18 @@ func NewRoomserverAPI( }, Inputer: &input.Inputer{ DB: roomserverDB, + InputRoomEventTopic: inputRoomEventTopic, OutputRoomEventTopic: outputRoomEventTopic, + Consumer: consumer, Producer: producer, ServerName: cfg.Matrix.ServerName, ACLs: serverACLs, }, // perform-er structs get initialised when we have a federation sender to use } + if err := a.Inputer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start roomserver input API") + } return a } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index de40e133d..80d6d2c54 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,8 +19,8 @@ import ( "context" "encoding/json" "sync" - "time" + "github.com/Arceliar/phony" "github.com/Shopify/sarama" "github.com/getsentry/sentry-go" "github.com/matrix-org/dendrite/internal/hooks" @@ -28,10 +28,10 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - "go.uber.org/atomic" ) var keyContentFields = map[string]string{ @@ -42,48 +42,62 @@ var keyContentFields = map[string]string{ type Inputer struct { DB storage.Database + Consumer nats.JetStreamContext Producer sarama.SyncProducer ServerName gomatrixserverlib.ServerName ACLs *acls.ServerACLs + InputRoomEventTopic string OutputRoomEventTopic string - workers sync.Map // room ID -> *inputWorker + workers sync.Map // room ID -> *phony.Inbox } -type inputTask struct { - ctx context.Context - event *api.InputRoomEvent - wg *sync.WaitGroup - err error // written back by worker, only safe to read when all tasks are done -} - -type inputWorker struct { - r *Inputer - running atomic.Bool - input *fifoQueue -} - -// Guarded by a CAS on w.running -func (w *inputWorker) start() { - defer w.running.Store(false) - for { - select { - case <-w.input.wait(): - task, ok := w.input.pop() - if !ok { - continue +// onMessage is called when a new event arrives in the roomserver input stream. +func (r *Inputer) Start() error { + _, err := r.Consumer.Subscribe( + r.InputRoomEventTopic, + func(msg *nats.Msg) { + _ = msg.InProgress() + var inputRoomEvent api.InputRoomEvent + if err := json.Unmarshal(msg.Data, &inputRoomEvent); err != nil { + _ = msg.Nak() + return } - roomserverInputBackpressure.With(prometheus.Labels{ - "room_id": task.event.Event.RoomID(), - }).Dec() - hooks.Run(hooks.KindNewEventReceived, task.event.Event) - _, task.err = w.r.processRoomEvent(task.ctx, task.event) - if task.err == nil { - hooks.Run(hooks.KindNewEventPersisted, task.event.Event) - } else { - sentry.CaptureException(task.err) - } - task.wg.Done() - case <-time.After(time.Second * 5): + inbox, _ := r.workers.LoadOrStore(msg.Header.Get("room_id"), &phony.Inbox{}) + inbox.(*phony.Inbox).Act(nil, func() { + if _, err := r.processRoomEvent(context.TODO(), &inputRoomEvent); err != nil { + sentry.CaptureException(err) + _ = msg.Nak() + } else { + hooks.Run(hooks.KindNewEventPersisted, inputRoomEvent.Event) + _ = msg.Ack() + } + }) + }, + nats.ManualAck(), + ) + return err +} + +// InputRoomEvents implements api.RoomserverInternalAPI +func (r *Inputer) InputRoomEvents( + _ context.Context, + request *api.InputRoomEventsRequest, + response *api.InputRoomEventsResponse, +) { + var err error + for _, e := range request.InputRoomEvents { + msg := &nats.Msg{ + Subject: r.InputRoomEventTopic, + Header: nats.Header{}, + } + msg.Header.Set("room_id", e.Event.RoomID()) + msg.Data, err = json.Marshal(e) + if err != nil { + response.ErrMsg = err.Error() + return + } + if _, err = r.Consumer.PublishMsg(msg); err != nil { + response.ErrMsg = err.Error() return } } @@ -156,67 +170,3 @@ var roomserverInputBackpressure = prometheus.NewGaugeVec( }, []string{"room_id"}, ) - -// InputRoomEvents implements api.RoomserverInternalAPI -func (r *Inputer) InputRoomEvents( - _ context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) { - // Create a wait group. Each task that we dispatch will call Done on - // this wait group so that we know when all of our events have been - // processed. - wg := &sync.WaitGroup{} - wg.Add(len(request.InputRoomEvents)) - tasks := make([]*inputTask, len(request.InputRoomEvents)) - - for i, e := range request.InputRoomEvents { - // Work out if we are running per-room workers or if we're just doing - // it on a global basis (e.g. SQLite). - roomID := "global" - if r.DB.SupportsConcurrentRoomInputs() { - roomID = e.Event.RoomID() - } - - // Look up the worker, or create it if it doesn't exist. This channel - // is buffered to reduce the chance that we'll be blocked by another - // room - the channel will be quite small as it's just pointer types. - w, _ := r.workers.LoadOrStore(roomID, &inputWorker{ - r: r, - input: newFIFOQueue(), - }) - worker := w.(*inputWorker) - - // Create a task. This contains the input event and a reference to - // the wait group, so that the worker can notify us when this specific - // task has been finished. - tasks[i] = &inputTask{ - ctx: context.Background(), - event: &request.InputRoomEvents[i], - wg: wg, - } - - // Send the task to the worker. - if worker.running.CAS(false, true) { - go worker.start() - } - worker.input.push(tasks[i]) - roomserverInputBackpressure.With(prometheus.Labels{ - "room_id": roomID, - }).Inc() - } - - // Wait for all of the workers to return results about our tasks. - wg.Wait() - - // If any of the tasks returned an error, we should probably report - // that back to the caller. - for _, task := range tasks { - if task.err != nil { - response.ErrMsg = task.err.Error() - _, rejected := task.err.(*gomatrixserverlib.NotAllowed) - response.NotAllowed = rejected - return - } - } -} diff --git a/roomserver/internal/input/input_fifo.go b/roomserver/internal/input/input_fifo.go deleted file mode 100644 index 694b17245..000000000 --- a/roomserver/internal/input/input_fifo.go +++ /dev/null @@ -1,64 +0,0 @@ -package input - -import ( - "sync" -) - -type fifoQueue struct { - tasks []*inputTask - count int - mutex sync.Mutex - notifs chan struct{} -} - -func newFIFOQueue() *fifoQueue { - q := &fifoQueue{ - notifs: make(chan struct{}, 1), - } - return q -} - -func (q *fifoQueue) push(frame *inputTask) { - q.mutex.Lock() - defer q.mutex.Unlock() - q.tasks = append(q.tasks, frame) - q.count++ - select { - case q.notifs <- struct{}{}: - default: - } -} - -// pop returns the first item of the queue, if there is one. -// The second return value will indicate if a task was returned. -// You must check this value, even after calling wait(). -func (q *fifoQueue) pop() (*inputTask, bool) { - q.mutex.Lock() - defer q.mutex.Unlock() - if q.count == 0 { - return nil, false - } - frame := q.tasks[0] - q.tasks[0] = nil - q.tasks = q.tasks[1:] - q.count-- - if q.count == 0 { - // Force a GC of the underlying array, since it might have - // grown significantly if the queue was hammered for some reason - q.tasks = nil - } - return frame, true -} - -// wait returns a channel which can be used to detect when an -// item is waiting in the queue. -func (q *fifoQueue) wait() <-chan struct{} { - q.mutex.Lock() - defer q.mutex.Unlock() - if q.count > 0 && len(q.notifs) == 0 { - ch := make(chan struct{}) - close(ch) - return ch - } - return q.notifs -} diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 192a056b8..fe55b1d3d 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -41,7 +41,7 @@ func NewInternalAPI( ) api.RoomserverInternalAPI { cfg := &base.Cfg.RoomServer - _, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) + js, _, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) var perspectiveServerNames []gomatrixserverlib.ServerName for _, kp := range base.Cfg.SigningKeyServer.KeyPerspectives { @@ -54,7 +54,9 @@ func NewInternalAPI( } return internal.NewRoomserverAPI( - cfg, roomserverDB, producer, string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent)), + cfg, roomserverDB, js, producer, + cfg.Matrix.JetStream.TopicFor(jetstream.InputRoomEvent), + cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent), base.Caches, keyRing, perspectiveServerNames, ) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 335d2a4ca..373702199 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -181,7 +181,9 @@ func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyPro logrus.WithError(err).Panicf("failed to connect to room server db") } return internal.NewRoomserverAPI( - &cfg.RoomServer, roomserverDB, dp, string(cfg.Global.JetStream.TopicFor(jetstream.OutputRoomEvent)), + &cfg.RoomServer, roomserverDB, dp, + cfg.Global.JetStream.TopicFor(jetstream.InputRoomEvent), + cfg.Global.JetStream.TopicFor(jetstream.OutputRoomEvent), base.Caches, &test.NopJSONVerifier{}, nil, ), dp } diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index f074f5f20..0d309629a 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -18,7 +18,7 @@ import ( var natsServer *natsserver.Server var natsServerMutex sync.Mutex -func SetupConsumerProducer(cfg *config.JetStream) (sarama.Consumer, sarama.SyncProducer) { +func SetupConsumerProducer(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) { // check if we need an in-process NATS Server if len(cfg.Addresses) != 0 { return setupNATS(cfg, nil) @@ -51,20 +51,20 @@ func SetupConsumerProducer(cfg *config.JetStream) (sarama.Consumer, sarama.SyncP return setupNATS(cfg, nc) } -func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (sarama.Consumer, sarama.SyncProducer) { +func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) { if nc == nil { var err error nc, err = nats.Connect(strings.Join(cfg.Addresses, ",")) if err != nil { logrus.WithError(err).Panic("Unable to connect to NATS") - return nil, nil + return nil, nil, nil } } s, err := nc.JetStream() if err != nil { logrus.WithError(err).Panic("Unable to get JetStream context") - return nil, nil + return nil, nil, nil } for _, stream := range streams { @@ -89,5 +89,5 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (sarama.Consumer, sar consumer := saramajs.NewJetStreamConsumer(nc, s, "") producer := saramajs.NewJetStreamProducer(nc, s, "") - return consumer, producer + return s, consumer, producer } diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go index 326e62a93..b43776f25 100644 --- a/setup/jetstream/streams.go +++ b/setup/jetstream/streams.go @@ -7,6 +7,7 @@ import ( ) var ( + InputRoomEvent = "InputRoomEvent" OutputRoomEvent = "OutputRoomEvent" OutputSendToDeviceEvent = "OutputSendToDeviceEvent" OutputKeyChangeEvent = "OutputKeyChangeEvent" @@ -16,6 +17,11 @@ var ( ) var streams = []*nats.StreamConfig{ + { + Name: InputRoomEvent, + Retention: nats.InterestPolicy, + Storage: nats.FileStorage, + }, { Name: OutputRoomEvent, Retention: nats.InterestPolicy, diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 16e222cb0..32fe033e6 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -48,7 +48,7 @@ func AddPublicRoutes( federation *gomatrixserverlib.FederationClient, cfg *config.SyncAPI, ) { - consumer, _ := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) + _, consumer, _ := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream) syncDB, err := storage.NewSyncServerDatasource(&cfg.Database) if err != nil {