From ba0b3adab4de7865afd467b61638437b1af39fce Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 31 Aug 2022 10:41:32 +0100 Subject: [PATCH 1/2] Pinecone standalone refactoring (#2685) This refactors the `dendrite-demo-pinecone` executable so that it: 1. Converts the old `.key` file into a standard `.pem` file 2. Allows passing in the `--config` option to supply a normal Dendrite configuration file, so that you can configure PostgreSQL instead of SQLite, appservices and all the other usual stuff --- cmd/dendrite-demo-pinecone/main.go | 110 +++++++++++++++++++---------- setup/config/config.go | 15 ++-- test/keys.go | 7 +- 3 files changed, 86 insertions(+), 46 deletions(-) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 75f29fe27..b16cfec6a 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -24,6 +24,7 @@ import ( "net" "net/http" "os" + "strings" "time" "github.com/gorilla/mux" @@ -42,6 +43,7 @@ import ( "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" @@ -70,31 +72,84 @@ func main() { var pk ed25519.PublicKey var sk ed25519.PrivateKey - keyfile := *instanceName + ".key" - if _, err := os.Stat(keyfile); os.IsNotExist(err) { - if pk, sk, err = ed25519.GenerateKey(nil); err != nil { - panic(err) + // iterate through the cli args and check if the config flag was set + configFlagSet := false + for _, arg := range os.Args { + if arg == "--config" || arg == "-config" { + configFlagSet = true + break } - if err = os.WriteFile(keyfile, sk, 0644); err != nil { - panic(err) - } - } else if err == nil { - if sk, err = os.ReadFile(keyfile); err != nil { - panic(err) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - pk = sk.Public().(ed25519.PublicKey) } + cfg := &config.Dendrite{} + + // use custom config if config flag is set + if configFlagSet { + cfg = setup.ParseFlags(true) + sk = cfg.Global.PrivateKey + } else { + keyfile := *instanceName + ".pem" + if _, err := os.Stat(keyfile); os.IsNotExist(err) { + oldkeyfile := *instanceName + ".key" + if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) { + if err = test.NewMatrixKey(keyfile); err != nil { + panic("failed to generate a new PEM key: " + err.Error()) + } + if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { + panic("failed to load PEM key: " + err.Error()) + } + } else { + if sk, err = os.ReadFile(oldkeyfile); err != nil { + panic("failed to read the old private key: " + err.Error()) + } + if len(sk) != ed25519.PrivateKeySize { + panic("the private key is not long enough") + } + if err := test.SaveMatrixKey(keyfile, sk); err != nil { + panic("failed to convert the private key to PEM format: " + err.Error()) + } + } + } else { + var err error + if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { + panic("failed to load PEM key: " + err.Error()) + } + } + cfg.Defaults(true) + cfg.Global.PrivateKey = sk + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) + cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) + cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName)) + cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) + cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) + cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} + cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true + if err := cfg.Derive(); err != nil { + panic(err) + } + } + + pk = sk.Public().(ed25519.PublicKey) + cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) + + base := base.NewBaseDendrite(cfg, "Monolith") + defer base.Close() // nolint: errcheck + pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk, false) pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) pManager := pineconeConnections.NewConnectionManager(pRouter, nil) pMulticast.Start() if instancePeer != nil && *instancePeer != "" { - pManager.AddPeer(*instancePeer) + for _, peer := range strings.Split(*instancePeer, ",") { + pManager.AddPeer(strings.Trim(peer, " \t\r\n")) + } } go func() { @@ -125,29 +180,6 @@ func main() { } }() - cfg := &config.Dendrite{} - cfg.Defaults(true) - cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - cfg.Global.PrivateKey = sk - cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) - cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) - cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName)) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) - cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) - cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} - cfg.ClientAPI.RegistrationDisabled = false - cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - if err := cfg.Derive(); err != nil { - panic(err) - } - - base := base.NewBaseDendrite(cfg, "Monolith") - defer base.Close() // nolint: errcheck - federation := conn.CreateFederationClient(base, pQUIC) serverKeyAPI := &signing.YggdrasilKeys{} diff --git a/setup/config/config.go b/setup/config/config.go index 924b51f22..cc9c04470 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -224,12 +224,7 @@ func loadConfig( } privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath) - privateKeyData, err := readFile(privateKeyPath) - if err != nil { - return nil, err - } - - if c.Global.KeyID, c.Global.PrivateKey, err = readKeyPEM(privateKeyPath, privateKeyData, true); err != nil { + if c.Global.KeyID, c.Global.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { return nil, err } @@ -265,6 +260,14 @@ func loadConfig( return &c, nil } +func LoadMatrixKey(privateKeyPath string, readFile func(string) ([]byte, error)) (gomatrixserverlib.KeyID, ed25519.PrivateKey, error) { + privateKeyData, err := readFile(privateKeyPath) + if err != nil { + return "", nil, err + } + return readKeyPEM(privateKeyPath, privateKeyData, true) +} + // Derive generates data that is derived from various values provided in // the config file. func (config *Dendrite) Derive() error { diff --git a/test/keys.go b/test/keys.go index fb156ef27..05f7317cf 100644 --- a/test/keys.go +++ b/test/keys.go @@ -15,6 +15,7 @@ package test import ( + "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -44,6 +45,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) { if err != nil { return err } + return SaveMatrixKey(matrixKeyPath, data[3:]) +} + +func SaveMatrixKey(matrixKeyPath string, data ed25519.PrivateKey) error { keyOut, err := os.OpenFile(matrixKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err @@ -62,7 +67,7 @@ func NewMatrixKey(matrixKeyPath string) (err error) { Headers: map[string]string{ "Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]), }, - Bytes: data[3:], + Bytes: data, }) return err } From 175f65407a7f684753334022e66b8209f3db7396 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 31 Aug 2022 12:21:56 +0100 Subject: [PATCH 2/2] Allow batching in `JetStreamConsumer` (#2686) This allows us to receive more than one message from NATS at a time if we want. --- appservice/consumers/roomserver.go | 7 ++++--- federationapi/consumers/keychange.go | 7 ++++--- federationapi/consumers/presence.go | 5 +++-- federationapi/consumers/receipts.go | 5 +++-- federationapi/consumers/roomserver.go | 7 ++++--- federationapi/consumers/sendtodevice.go | 7 ++++--- federationapi/consumers/typing.go | 5 +++-- keyserver/consumers/devicelistupdate.go | 7 ++++--- setup/jetstream/helpers.go | 25 +++++++++++++++++++----- syncapi/consumers/clientapi.go | 7 ++++--- syncapi/consumers/keychange.go | 7 ++++--- syncapi/consumers/presence.go | 5 +++-- syncapi/consumers/receipts.go | 7 ++++--- syncapi/consumers/roomserver.go | 7 ++++--- syncapi/consumers/sendtodevice.go | 7 ++++--- syncapi/consumers/typing.go | 7 ++++--- syncapi/consumers/userapi.go | 7 ++++--- userapi/consumers/syncapi_readupdate.go | 7 ++++--- userapi/consumers/syncapi_streamevent.go | 7 ++++--- 19 files changed, 88 insertions(+), 55 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index d567408be..21b52bc3c 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -68,14 +68,15 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called when the appservice component receives a new event from // the room server output log. -func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON var output api.OutputEvent if err := json.Unmarshal(msg.Data, &output); err != nil { diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 6d3cf0e46..f3314bc98 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -67,14 +67,15 @@ func NewKeyChangeConsumer( // Start consuming from key servers func (t *KeyChangeConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, - nats.DeliverAll(), nats.ManualAck(), + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *KeyChangeConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var m api.DeviceMessage if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index a65d2aa04..e76103cd3 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -69,14 +69,15 @@ func (t *OutputPresenceConsumer) Start() error { return nil } return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } // onMessage is called in response to a message received on the presence // events topic from the client api. -func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // only send presence events which originated from us userID := msg.Header.Get(jetstream.UserID) _, serverName, err := gomatrixserverlib.SplitID('@', userID) diff --git a/federationapi/consumers/receipts.go b/federationapi/consumers/receipts.go index 2c9d79bcb..366cb264e 100644 --- a/federationapi/consumers/receipts.go +++ b/federationapi/consumers/receipts.go @@ -65,14 +65,15 @@ func NewOutputReceiptConsumer( // Start consuming from the clientapi func (t *OutputReceiptConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } // onMessage is called in response to a message received on the receipt // events topic from the client api. -func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputReceiptConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called receipt := syncTypes.OutputReceiptEvent{ UserID: msg.Header.Get(jetstream.UserID), RoomID: msg.Header.Get(jetstream.RoomID), diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 2622ecb3f..349b50b05 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -68,8 +68,8 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } @@ -77,7 +77,8 @@ func (s *OutputRoomEventConsumer) Start() error { // It is unsafe to call this with messages for the same room in multiple gorountines // because updates it will likely fail with a types.EventIDMismatchError when it // realises that it cannot update the room state using the deltas. -func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON var output api.OutputEvent if err := json.Unmarshal(msg.Data, &output); err != nil { diff --git a/federationapi/consumers/sendtodevice.go b/federationapi/consumers/sendtodevice.go index f99a895e0..e44bad723 100644 --- a/federationapi/consumers/sendtodevice.go +++ b/federationapi/consumers/sendtodevice.go @@ -63,14 +63,15 @@ func NewOutputSendToDeviceConsumer( // Start consuming from the client api func (t *OutputSendToDeviceConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, - nats.DeliverAll(), nats.ManualAck(), + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called in response to a message received on the // send-to-device events topic from the client api. -func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputSendToDeviceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // only send send-to-device events which originated from us sender := msg.Header.Get("sender") _, originServerName, err := gomatrixserverlib.SplitID('@', sender) diff --git a/federationapi/consumers/typing.go b/federationapi/consumers/typing.go index 428e1a867..9c7379136 100644 --- a/federationapi/consumers/typing.go +++ b/federationapi/consumers/typing.go @@ -62,14 +62,15 @@ func NewOutputTypingConsumer( // Start consuming from the clientapi func (t *OutputTypingConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + t.ctx, t.jetstream, t.topic, t.durable, 1, t.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } // onMessage is called in response to a message received on the typing // events topic from the client api. -func (t *OutputTypingConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *OutputTypingConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Extract the typing event from msg. roomID := msg.Header.Get(jetstream.RoomID) userID := msg.Header.Get(jetstream.UserID) diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go index f4f246280..d15f94267 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/keyserver/consumers/devicelistupdate.go @@ -55,14 +55,15 @@ func NewDeviceListUpdateConsumer( // Start consuming from key servers func (t *DeviceListUpdateConsumer) Start() error { return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, - nats.DeliverAll(), nats.ManualAck(), + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var m gomatrixserverlib.DeviceListUpdateEvent if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("Failed to read from device list update input topic") diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 1c07583e9..f47637c69 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -9,9 +9,16 @@ import ( "github.com/sirupsen/logrus" ) +// JetStreamConsumer starts a durable consumer on the given subject with the +// given durable name. The function will be called when one or more messages +// is available, up to the maximum batch size specified. If the batch is set to +// 1 then messages will be delivered one at a time. If the function is called, +// the messages array is guaranteed to be at least 1 in size. Any provided NATS +// options will be passed through to the pull subscriber creation. The consumer +// will continue to run until the context expires, at which point it will stop. func JetStreamConsumer( - ctx context.Context, js nats.JetStreamContext, subj, durable string, - f func(ctx context.Context, msg *nats.Msg) bool, + ctx context.Context, js nats.JetStreamContext, subj, durable string, batch int, + f func(ctx context.Context, msgs []*nats.Msg) bool, opts ...nats.SubOpt, ) error { defer func() { @@ -27,6 +34,14 @@ func JetStreamConsumer( } }() + // If the batch size is greater than 1, we will want to acknowledge all + // received messages in the batch. Below we will send an acknowledgement + // for the most recent message in the batch and AckAll will ensure that + // all messages that came before it are also acknowledged implicitly. + if batch > 1 { + opts = append(opts, nats.AckAll()) + } + name := durable + "Pull" sub, err := js.PullSubscribe(subj, name, opts...) if err != nil { @@ -50,7 +65,7 @@ func JetStreamConsumer( // enforce its own deadline (roughly 5 seconds by default). Therefore // it is our responsibility to check whether our context expired or // not when a context error is returned. Footguns. Footguns everywhere. - msgs, err := sub.Fetch(1, nats.Context(ctx)) + msgs, err := sub.Fetch(batch, nats.Context(ctx)) if err != nil { if err == context.Canceled || err == context.DeadlineExceeded { // Work out whether it was the JetStream context that expired @@ -74,13 +89,13 @@ func JetStreamConsumer( if len(msgs) < 1 { continue } - msg := msgs[0] + msg := msgs[len(msgs)-1] // most recent message, in case of AckAll if err = msg.InProgress(nats.Context(ctx)); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) sentry.CaptureException(err) continue } - if f(ctx, msg) { + if f(ctx, msgs) { if err = msg.AckSync(nats.Context(ctx)); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err)) sentry.CaptureException(err) diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 02633b567..f0588cab8 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -75,15 +75,16 @@ func NewOutputClientDataConsumer( // Start consuming from room servers func (s *OutputClientDataConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called when the sync server receives a new event from the client API server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON userID := msg.Header.Get(jetstream.UserID) var output eventutil.AccountData diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index c8d88ddac..c42e71971 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -75,12 +75,13 @@ func NewOutputKeyChangeEventConsumer( // Start consuming from the key server func (s *OutputKeyChangeEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var m api.DeviceMessage if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index db7d67fa6..61bdc13de 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -128,12 +128,13 @@ func (s *PresenceConsumer) Start() error { return nil } return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.presenceTopic, s.durable, s.onMessage, + s.ctx, s.jetstream, s.presenceTopic, s.durable, 1, s.onMessage, nats.DeliverAll(), nats.ManualAck(), nats.HeadersOnly(), ) } -func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *PresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called userID := msg.Header.Get(jetstream.UserID) presence := msg.Header.Get("presence") timestamp := msg.Header.Get("last_active_ts") diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index 83156cf93..a18244c44 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -74,12 +74,13 @@ func NewOutputReceiptEventConsumer( // Start consuming receipts events. func (s *OutputReceiptEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called output := types.OutputReceiptEvent{ UserID: msg.Header.Get(jetstream.UserID), RoomID: msg.Header.Get(jetstream.RoomID), diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index f77b1673b..6979eb484 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -79,15 +79,16 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } // onMessage is called when the sync server receives a new event from the room server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called // Parse out the event JSON var err error var output api.OutputEvent diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 0b9153fcd..89b01d7e5 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -68,12 +68,13 @@ func NewOutputSendToDeviceEventConsumer( // Start consuming send-to-device events. func (s *OutputSendToDeviceEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called userID := msg.Header.Get(jetstream.UserID) _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { diff --git a/syncapi/consumers/typing.go b/syncapi/consumers/typing.go index 48e484ec5..88db80f8c 100644 --- a/syncapi/consumers/typing.go +++ b/syncapi/consumers/typing.go @@ -64,12 +64,13 @@ func NewOutputTypingEventConsumer( // Start consuming typing events. func (s *OutputTypingEventConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } -func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called roomID := msg.Header.Get(jetstream.RoomID) userID := msg.Header.Get(jetstream.UserID) typing, err := strconv.ParseBool(msg.Header.Get("typing")) diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go index 010fa7c8e..227823522 100644 --- a/syncapi/consumers/userapi.go +++ b/syncapi/consumers/userapi.go @@ -67,8 +67,8 @@ func NewOutputNotificationDataConsumer( // Start starts consumption. func (s *OutputNotificationDataConsumer) Start() error { return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ) } @@ -76,7 +76,8 @@ func (s *OutputNotificationDataConsumer) Start() error { // the push server. It is not safe for this function to be called from // multiple goroutines, or else the sync stream position may race and // be incorrectly calculated. -func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called userID := string(msg.Header.Get(jetstream.UserID)) // Parse out the event JSON diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go index 067f93330..54654f757 100644 --- a/userapi/consumers/syncapi_readupdate.go +++ b/userapi/consumers/syncapi_readupdate.go @@ -56,15 +56,16 @@ func NewOutputReadUpdateConsumer( func (s *OutputReadUpdateConsumer) Start() error { if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ); err != nil { return err } return nil } -func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var read types.ReadUpdate if err := json.Unmarshal(msg.Data, &read); err != nil { log.WithError(err).Error("userapi clientapi consumer: message parse failure") diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go index ec351ef7e..3ac6f58d0 100644 --- a/userapi/consumers/syncapi_streamevent.go +++ b/userapi/consumers/syncapi_streamevent.go @@ -65,15 +65,16 @@ func NewOutputStreamEventConsumer( func (s *OutputStreamEventConsumer) Start() error { if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), ); err != nil { return err } return nil } -func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { +func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called var output types.StreamedEvent output.Event = &gomatrixserverlib.HeaderedEvent{} if err := json.Unmarshal(msg.Data, &output); err != nil {