Allow batching in JetStreamConsumer (#2686)

This allows us to receive more than one message from NATS at a time if we want.
This commit is contained in:
Neil Alexander 2022-08-31 12:21:56 +01:00 committed by GitHub
parent ba0b3adab4
commit 175f65407a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 88 additions and 55 deletions

View file

@ -68,14 +68,15 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
// onMessage is called when the appservice component receives a new event from
// the room server output log.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON
var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil {

View file

@ -67,14 +67,15 @@ func NewKeyChangeConsumer(
// Start consuming from key servers
func (t *KeyChangeConsumer) Start() error {
return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
nats.DeliverAll(), nats.ManualAck(),
t.ctx, t.jetstream, t.topic, t.durable, 1,
t.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
// onMessage is called in response to a message received on the
// key change events topic from the key server.
func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (t *KeyChangeConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var m api.DeviceMessage
if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic")

View file

@ -69,14 +69,15 @@ func (t *OutputPresenceConsumer) Start() error {
return nil
}
return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
)
}
// onMessage is called in response to a message received on the presence
// events topic from the client api.
func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// only send presence events which originated from us
userID := msg.Header.Get(jetstream.UserID)
_, serverName, err := gomatrixserverlib.SplitID('@', userID)

View file

@ -65,14 +65,15 @@ func NewOutputReceiptConsumer(
// Start consuming from the clientapi
func (t *OutputReceiptConsumer) Start() error {
return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
)
}
// onMessage is called in response to a message received on the receipt
// events topic from the client api.
func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
receipt := syncTypes.OutputReceiptEvent{
UserID: msg.Header.Get(jetstream.UserID),
RoomID: msg.Header.Get(jetstream.RoomID),

View file

@ -68,8 +68,8 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
@ -77,7 +77,8 @@ func (s *OutputRoomEventConsumer) Start() error {
// It is unsafe to call this with messages for the same room in multiple gorountines
// because updates it will likely fail with a types.EventIDMismatchError when it
// realises that it cannot update the room state using the deltas.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON
var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil {

View file

@ -63,14 +63,15 @@ func NewOutputSendToDeviceConsumer(
// Start consuming from the client api
func (t *OutputSendToDeviceConsumer) Start() error {
return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
nats.DeliverAll(), nats.ManualAck(),
t.ctx, t.jetstream, t.topic, t.durable, 1,
t.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
// onMessage is called in response to a message received on the
// send-to-device events topic from the client api.
func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// only send send-to-device events which originated from us
sender := msg.Header.Get("sender")
_, originServerName, err := gomatrixserverlib.SplitID('@', sender)

View file

@ -62,14 +62,15 @@ func NewOutputTypingConsumer(
// Start consuming from the clientapi
func (t *OutputTypingConsumer) Start() error {
return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
)
}
// onMessage is called in response to a message received on the typing
// events topic from the client api.
func (t *OutputTypingConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Extract the typing event from msg.
roomID := msg.Header.Get(jetstream.RoomID)
userID := msg.Header.Get(jetstream.UserID)

View file

@ -55,14 +55,15 @@ func NewDeviceListUpdateConsumer(
// Start consuming from key servers
func (t *DeviceListUpdateConsumer) Start() error {
return jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
nats.DeliverAll(), nats.ManualAck(),
t.ctx, t.jetstream, t.topic, t.durable, 1,
t.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
// onMessage is called in response to a message received on the
// key change events topic from the key server.
func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var m gomatrixserverlib.DeviceListUpdateEvent
if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("Failed to read from device list update input topic")

View file

@ -9,9 +9,16 @@ import (
"github.com/sirupsen/logrus"
)
// JetStreamConsumer starts a durable consumer on the given subject with the
// given durable name. The function will be called when one or more messages
// is available, up to the maximum batch size specified. If the batch is set to
// 1 then messages will be delivered one at a time. If the function is called,
// the messages array is guaranteed to be at least 1 in size. Any provided NATS
// options will be passed through to the pull subscriber creation. The consumer
// will continue to run until the context expires, at which point it will stop.
func JetStreamConsumer(
ctx context.Context, js nats.JetStreamContext, subj, durable string,
f func(ctx context.Context, msg *nats.Msg) bool,
ctx context.Context, js nats.JetStreamContext, subj, durable string, batch int,
f func(ctx context.Context, msgs []*nats.Msg) bool,
opts ...nats.SubOpt,
) error {
defer func() {
@ -27,6 +34,14 @@ func JetStreamConsumer(
}
}()
// If the batch size is greater than 1, we will want to acknowledge all
// received messages in the batch. Below we will send an acknowledgement
// for the most recent message in the batch and AckAll will ensure that
// all messages that came before it are also acknowledged implicitly.
if batch > 1 {
opts = append(opts, nats.AckAll())
}
name := durable + "Pull"
sub, err := js.PullSubscribe(subj, name, opts...)
if err != nil {
@ -50,7 +65,7 @@ func JetStreamConsumer(
// enforce its own deadline (roughly 5 seconds by default). Therefore
// it is our responsibility to check whether our context expired or
// not when a context error is returned. Footguns. Footguns everywhere.
msgs, err := sub.Fetch(1, nats.Context(ctx))
msgs, err := sub.Fetch(batch, nats.Context(ctx))
if err != nil {
if err == context.Canceled || err == context.DeadlineExceeded {
// Work out whether it was the JetStream context that expired
@ -74,13 +89,13 @@ func JetStreamConsumer(
if len(msgs) < 1 {
continue
}
msg := msgs[0]
msg := msgs[len(msgs)-1] // most recent message, in case of AckAll
if err = msg.InProgress(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
sentry.CaptureException(err)
continue
}
if f(ctx, msg) {
if f(ctx, msgs) {
if err = msg.AckSync(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
sentry.CaptureException(err)

View file

@ -75,15 +75,16 @@ func NewOutputClientDataConsumer(
// Start consuming from room servers
func (s *OutputClientDataConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
// onMessage is called when the sync server receives a new event from the client API server output log.
// It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated.
func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON
userID := msg.Header.Get(jetstream.UserID)
var output eventutil.AccountData

View file

@ -75,12 +75,13 @@ func NewOutputKeyChangeEventConsumer(
// Start consuming from the key server
func (s *OutputKeyChangeEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var m api.DeviceMessage
if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic")

View file

@ -128,12 +128,13 @@ func (s *PresenceConsumer) Start() error {
return nil
}
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.presenceTopic, s.durable, s.onMessage,
s.ctx, s.jetstream, s.presenceTopic, s.durable, 1, s.onMessage,
nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(),
)
}
func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *PresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID)
presence := msg.Header.Get("presence")
timestamp := msg.Header.Get("last_active_ts")

View file

@ -74,12 +74,13 @@ func NewOutputReceiptEventConsumer(
// Start consuming receipts events.
func (s *OutputReceiptEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
output := types.OutputReceiptEvent{
UserID: msg.Header.Get(jetstream.UserID),
RoomID: msg.Header.Get(jetstream.RoomID),

View file

@ -79,15 +79,16 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
// onMessage is called when the sync server receives a new event from the room server output log.
// It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated.
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
// Parse out the event JSON
var err error
var output api.OutputEvent

View file

@ -68,12 +68,13 @@ func NewOutputSendToDeviceEventConsumer(
// Start consuming send-to-device events.
func (s *OutputSendToDeviceEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := msg.Header.Get(jetstream.UserID)
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {

View file

@ -64,12 +64,13 @@ func NewOutputTypingEventConsumer(
// Start consuming typing events.
func (s *OutputTypingEventConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
roomID := msg.Header.Get(jetstream.RoomID)
userID := msg.Header.Get(jetstream.UserID)
typing, err := strconv.ParseBool(msg.Header.Get("typing"))

View file

@ -67,8 +67,8 @@ func NewOutputNotificationDataConsumer(
// Start starts consumption.
func (s *OutputNotificationDataConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
)
}
@ -76,7 +76,8 @@ func (s *OutputNotificationDataConsumer) Start() error {
// the push server. It is not safe for this function to be called from
// multiple goroutines, or else the sync stream position may race and
// be incorrectly calculated.
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
userID := string(msg.Header.Get(jetstream.UserID))
// Parse out the event JSON

View file

@ -56,15 +56,16 @@ func NewOutputReadUpdateConsumer(
func (s *OutputReadUpdateConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var read types.ReadUpdate
if err := json.Unmarshal(msg.Data, &read); err != nil {
log.WithError(err).Error("userapi clientapi consumer: message parse failure")

View file

@ -65,15 +65,16 @@ func NewOutputStreamEventConsumer(
func (s *OutputStreamEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.StreamedEvent
output.Event = &gomatrixserverlib.HeaderedEvent{}
if err := json.Unmarshal(msg.Data, &output); err != nil {