diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index c6e354611..86440b2d9 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -57,99 +57,127 @@ type Inputer struct { ACLs *acls.ServerACLs InputRoomEventTopic string OutputRoomEventTopic string - workers sync.Map // room ID -> *phony.Inbox + workers sync.Map // room ID -> *worker Queryer *query.Queryer } -func (r *Inputer) workerForRoom(roomID string) *phony.Inbox { - inbox, _ := r.workers.LoadOrStore(roomID, &phony.Inbox{}) - return inbox.(*phony.Inbox) +type worker struct { + phony.Inbox + r *Inputer + roomID string + subscription *nats.Subscription } -// eventsInProgress is an in-memory map to keep a track of which events we have -// queued up for processing. If we get a redelivery from NATS and we still have -// the queued up item then we won't do anything with the redelivered message. If -// we've restarted Dendrite and now this map is empty then it means that we will -// reload pending work from NATS. -var eventsInProgress sync.Map +func (r *Inputer) workerForRoom(roomID string) *worker { + v, loaded := r.workers.LoadOrStore(roomID, &worker{ + r: r, + roomID: roomID, + }) + w := v.(*worker) + if !loaded { + sub, err := w.r.JetStream.PullSubscribe( + jetstream.InputRoomEventSubj(w.roomID), + "DendriteRoomInputConsumerPull"+jetstream.Tokenise(w.roomID), + nats.ManualAck(), + nats.DeliverAll(), + nats.MaxAckPending(-1), + nats.AckWait(MaximumMissingProcessingTime+(time.Second*10)), + nats.BindStream(r.InputRoomEventTopic), + ) + if err != nil { + logrus.WithError(err).Errorf("Failed to subscribe to stream for room %q", w.roomID) + return nil + } + logrus.Infof("Subscribed to stream for room %q", w.roomID) + w.subscription = sub + w.Act(nil, w.next) + } + return w +} // onMessage is called when a new event arrives in the roomserver input stream. func (r *Inputer) Start() error { _, err := r.JetStream.Subscribe( - r.InputRoomEventTopic, - // We specifically don't use jetstream.WithJetStreamMessage here because we - // queue the task off to a room-specific queue and the ACK needs to be sent - // later, possibly with an error response to the inputter if synchronous. - func(msg *nats.Msg) { - roomID := msg.Header.Get("room_id") - var inputRoomEvent api.InputRoomEvent - if err := json.Unmarshal(msg.Data, &inputRoomEvent); err != nil { - _ = msg.Term() - return - } - - _ = msg.InProgress() - index := roomID + "\000" + inputRoomEvent.Event.EventID() - if _, ok := eventsInProgress.LoadOrStore(index, struct{}{}); ok { - // We're already waiting to deal with this event, so there's no - // point in queuing it up again. We've notified NATS that we're - // working on the message still, so that will have deferred the - // redelivery by a bit. - return - } - - roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Inc() - r.workerForRoom(roomID).Act(nil, func() { - _ = msg.InProgress() // resets the acknowledgement wait timer - defer eventsInProgress.Delete(index) - defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - var errString string - if err := r.processRoomEvent(r.ProcessContext.Context(), &inputRoomEvent); err != nil { - if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - sentry.CaptureException(err) - } - logrus.WithError(err).WithFields(logrus.Fields{ - "room_id": roomID, - "event_id": inputRoomEvent.Event.EventID(), - "type": inputRoomEvent.Event.Type(), - }).Warn("Roomserver failed to process async event") - _ = msg.Term() - errString = err.Error() - } else { - _ = msg.Ack() - } - if replyTo := msg.Header.Get("sync"); replyTo != "" { - if err := r.NATSClient.Publish(replyTo, []byte(errString)); err != nil { - logrus.WithError(err).WithFields(logrus.Fields{ - "room_id": roomID, - "event_id": inputRoomEvent.Event.EventID(), - "type": inputRoomEvent.Event.Type(), - }).Warn("Roomserver failed to respond for sync event") - } - } - }) + "", // don't supply a subject because we're using BindStream + func(m *nats.Msg) { + roomID := m.Header.Get(jetstream.RoomID) + r.workerForRoom(roomID) // start the room worker }, - // NATS wants to acknowledge automatically by default when the message is - // read from the stream, but we want to override that behaviour by making - // sure that we only acknowledge when we're happy we've done everything we - // can. This ensures we retry things when it makes sense to do so. - nats.ManualAck(), - // Use a durable named consumer. - r.Durable, - // If we've missed things in the stream, e.g. we restarted, then replay - // all of the queued messages that were waiting for us. nats.DeliverAll(), - // Ensure that NATS doesn't try to resend us something that wasn't done - // within the period of time that we might still be processing it. - nats.AckWait(MaximumMissingProcessingTime+(time.Second*10)), - // It is recommended to disable this for pull consumers as per the docs: - // https://docs.nats.io/nats-concepts/jetstream/consumers#note-about-push-and-pull-consumers - nats.MaxAckPending(-1), + nats.AckNone(), + nats.BindStream(r.InputRoomEventTopic), ) return err } +func (w *worker) next() { + ctx, cancel := context.WithTimeout(w.r.ProcessContext.Context(), time.Second*30) + defer cancel() + msgs, err := w.subscription.Fetch(1, nats.Context(ctx)) + switch err { + case nil: + if len(msgs) != 1 { + return + } + case context.DeadlineExceeded: + logrus.Infof("Stream for room %q idle, shutting down", w.roomID) + if err = w.subscription.Unsubscribe(); err != nil { + logrus.WithError(err).Errorf("Failed to unsubscribe to stream for room %q", w.roomID) + } + w.r.workers.Delete(w.roomID) + return + case nats.ErrTimeout: + w.Act(nil, w.next) + return + default: + logrus.WithError(err).Errorf("Failed to get next stream message for room %q", w.roomID) + if err = w.subscription.Unsubscribe(); err != nil { + logrus.WithError(err).Errorf("Failed to unsubscribe to stream for room %q", w.roomID) + } + w.r.workers.Delete(w.roomID) + return + } + + msg := msgs[0] + + var inputRoomEvent api.InputRoomEvent + if err = json.Unmarshal(msg.Data, &inputRoomEvent); err != nil { + _ = msg.Term() + return + } + + roomserverInputBackpressure.With(prometheus.Labels{"room_id": w.roomID}).Inc() + defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": w.roomID}).Dec() + + var errString string + if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { + if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + sentry.CaptureException(err) + } + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": w.roomID, + "event_id": inputRoomEvent.Event.EventID(), + "type": inputRoomEvent.Event.Type(), + }).Warn("Roomserver failed to process async event") + _ = msg.Term() + errString = err.Error() + } else { + _ = msg.Ack() + } + if replyTo := msg.Header.Get("sync"); replyTo != "" { + if err = w.r.NATSClient.Publish(replyTo, []byte(errString)); err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": w.roomID, + "event_id": inputRoomEvent.Event.EventID(), + "type": inputRoomEvent.Event.Type(), + }).Warn("Roomserver failed to respond for sync event") + } + } + + w.Act(nil, w.next) +} + // InputRoomEvents implements api.RoomserverInternalAPI func (r *Inputer) InputRoomEvents( ctx context.Context, @@ -170,12 +198,13 @@ func (r *Inputer) InputRoomEvents( var err error for _, e := range request.InputRoomEvents { + roomID := e.Event.RoomID() + subj := jetstream.InputRoomEventSubj(roomID) msg := &nats.Msg{ - Subject: r.InputRoomEventTopic, + Subject: subj, Header: nats.Header{}, Reply: replyTo, } - roomID := e.Event.RoomID() msg.Header.Set("room_id", roomID) if replyTo != "" { msg.Header.Set("sync", replyTo) @@ -189,6 +218,7 @@ func (r *Inputer) InputRoomEvents( logrus.WithError(err).WithFields(logrus.Fields{ "room_id": roomID, "event_id": e.Event.EventID(), + "subj": subj, }).Error("Roomserver failed to queue async event") return } diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 43cc0331d..f42055ec7 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -25,8 +25,8 @@ func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient if natsServer == nil { var err error natsServer, err = natsserver.NewServer(&natsserver.Options{ - ServerName: "monolith", - DontListen: true, + ServerName: "monolith", + //DontListen: true, JetStream: true, StoreDir: string(cfg.StoragePath), NoSystemAccount: true, @@ -81,7 +81,9 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (natsclient.JetStream logrus.WithError(err).Fatal("Unable to get stream info") } if info == nil { - stream.Subjects = []string{name} + if len(stream.Subjects) == 0 { + stream.Subjects = []string{name} + } // If we're trying to keep everything in memory (e.g. unit tests) // then overwrite the storage policy. @@ -95,6 +97,8 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (natsclient.JetStream namespaced.Name = name if _, err = s.AddStream(&namespaced); err != nil { logrus.WithError(err).WithField("stream", name).Fatal("Unable to add stream") + } else { + logrus.WithField("stream", name).Infof("Added stream") } } } diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go index aa3e95cb8..fa1ac2127 100644 --- a/setup/jetstream/streams.go +++ b/setup/jetstream/streams.go @@ -1,6 +1,8 @@ package jetstream import ( + "fmt" + "regexp" "time" "github.com/nats-io/nats.go" @@ -24,11 +26,22 @@ var ( OutputReadUpdate = "OutputReadUpdate" ) +var safeCharacters = regexp.MustCompile("[^A-Za-z0-9$]+") + +func Tokenise(str string) string { + return safeCharacters.ReplaceAllString(str, "_") +} + +func InputRoomEventSubj(roomID string) string { + return fmt.Sprintf("%s.%s", InputRoomEvent, Tokenise(roomID)) +} + var streams = []*nats.StreamConfig{ { Name: InputRoomEvent, - Retention: nats.WorkQueuePolicy, + Retention: nats.InterestPolicy, Storage: nats.FileStorage, + Subjects: []string{InputRoomEvent + ".>"}, // room ID }, { Name: OutputRoomEvent,