Merge branch 'master' into remove-sarama-dep

This commit is contained in:
Neil Alexander 2022-02-02 13:45:50 +00:00
commit c8e2f65299
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
16 changed files with 688 additions and 611 deletions

View file

@ -34,7 +34,7 @@ import (
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable nats.SubOpt durable string
topic string topic string
asDB storage.Database asDB storage.Database
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
@ -66,14 +66,15 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) return jetstream.JetStreamConsumer(
return err s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
)
} }
// onMessage is called when the appservice component receives a new event from // onMessage is called when the appservice component receives a new event from
// the room server output log. // the room server output log.
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
// Parse out the event JSON // Parse out the event JSON
var output api.OutputEvent var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
@ -96,7 +97,6 @@ func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
} }
return true return true
})
} }
// filterRoomserverEvents takes in events and decides whether any of them need // filterRoomserverEvents takes in events and decides whether any of them need

View file

@ -34,7 +34,7 @@ import (
type OutputEDUConsumer struct { type OutputEDUConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable nats.SubOpt durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
@ -66,13 +66,22 @@ func NewOutputEDUConsumer(
// Start consuming from EDU servers // Start consuming from EDU servers
func (t *OutputEDUConsumer) Start() error { func (t *OutputEDUConsumer) Start() error {
if _, err := t.jetstream.Subscribe(t.typingTopic, t.onTypingEvent, t.durable); err != nil { if err := jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.typingTopic, t.durable, t.onTypingEvent,
nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err return err
} }
if _, err := t.jetstream.Subscribe(t.sendToDeviceTopic, t.onSendToDeviceEvent, t.durable); err != nil { if err := jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.sendToDeviceTopic, t.durable, t.onSendToDeviceEvent,
nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err return err
} }
if _, err := t.jetstream.Subscribe(t.receiptTopic, t.onReceiptEvent, t.durable); err != nil { if err := jetstream.JetStreamConsumer(
t.ctx, t.jetstream, t.receiptTopic, t.durable, t.onReceiptEvent,
nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err return err
} }
return nil return nil
@ -80,9 +89,8 @@ func (t *OutputEDUConsumer) Start() error {
// onSendToDeviceEvent is called in response to a message received on the // onSendToDeviceEvent is called in response to a message received on the
// send-to-device events topic from the EDU server. // send-to-device events topic from the EDU server.
func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) { func (t *OutputEDUConsumer) onSendToDeviceEvent(ctx context.Context, msg *nats.Msg) bool {
// Extract the send-to-device event from msg. // Extract the send-to-device event from msg.
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var ote api.OutputSendToDeviceEvent var ote api.OutputSendToDeviceEvent
if err := json.Unmarshal(msg.Data, &ote); err != nil { if err := json.Unmarshal(msg.Data, &ote); err != nil {
log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)")
@ -133,13 +141,11 @@ func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) {
} }
return true return true
})
} }
// onTypingEvent is called in response to a message received on the typing // onTypingEvent is called in response to a message received on the typing
// events topic from the EDU server. // events topic from the EDU server.
func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) { func (t *OutputEDUConsumer) onTypingEvent(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
// Extract the typing event from msg. // Extract the typing event from msg.
var ote api.OutputTypingEvent var ote api.OutputTypingEvent
if err := json.Unmarshal(msg.Data, &ote); err != nil { if err := json.Unmarshal(msg.Data, &ote); err != nil {
@ -160,7 +166,7 @@ func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) {
return true return true
} }
joined, err := t.db.GetJoinedHosts(t.ctx, ote.Event.RoomID) joined, err := t.db.GetJoinedHosts(ctx, ote.Event.RoomID)
if err != nil { if err != nil {
log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room")
return false return false
@ -187,13 +193,11 @@ func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) {
} }
return true return true
})
} }
// onReceiptEvent is called in response to a message received on the receipt // onReceiptEvent is called in response to a message received on the receipt
// events topic from the EDU server. // events topic from the EDU server.
func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) { func (t *OutputEDUConsumer) onReceiptEvent(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
// Extract the typing event from msg. // Extract the typing event from msg.
var receipt api.OutputReceiptEvent var receipt api.OutputReceiptEvent
if err := json.Unmarshal(msg.Data, &receipt); err != nil { if err := json.Unmarshal(msg.Data, &receipt); err != nil {
@ -212,7 +216,7 @@ func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) {
return true return true
} }
joined, err := t.db.GetJoinedHosts(t.ctx, receipt.RoomID) joined, err := t.db.GetJoinedHosts(ctx, receipt.RoomID)
if err != nil { if err != nil {
log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room")
return false return false
@ -250,5 +254,4 @@ func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) {
} }
return true return true
})
} }

View file

@ -35,6 +35,7 @@ import (
type KeyChangeConsumer struct { type KeyChangeConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
@ -54,6 +55,7 @@ func NewKeyChangeConsumer(
return &KeyChangeConsumer{ return &KeyChangeConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
durable: cfg.Matrix.JetStream.TopicFor("FederationAPIKeyChangeConsumer"),
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
queues: queues, queues: queues,
db: store, db: store,
@ -64,17 +66,15 @@ func NewKeyChangeConsumer(
// Start consuming from key servers // Start consuming from key servers
func (t *KeyChangeConsumer) Start() error { func (t *KeyChangeConsumer) Start() error {
_, err := t.jetstream.Subscribe( return jetstream.JetStreamConsumer(
t.topic, t.onMessage, t.ctx, t.jetstream, t.topic, t.durable, t.onMessage,
nats.DeliverAll(), nats.DeliverAll(), nats.ManualAck(),
) )
return err
} }
// onMessage is called in response to a message received on the // onMessage is called in response to a message received on the
// key change events topic from the key server. // key change events topic from the key server.
func (t *KeyChangeConsumer) onMessage(msg *nats.Msg) { func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var m api.DeviceMessage var m api.DeviceMessage
if err := json.Unmarshal(msg.Data, &m); err != nil { if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic") logrus.WithError(err).Errorf("failed to read device message from key change topic")
@ -93,8 +93,6 @@ func (t *KeyChangeConsumer) onMessage(msg *nats.Msg) {
default: default:
return t.onDeviceKeyMessage(m) return t.onDeviceKeyMessage(m)
} }
})
} }
func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {

View file

@ -37,7 +37,7 @@ type OutputRoomEventConsumer struct {
cfg *config.FederationAPI cfg *config.FederationAPI
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable nats.SubOpt durable string
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
topic string topic string
@ -66,20 +66,17 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
_, err := s.jetstream.Subscribe( return jetstream.JetStreamConsumer(
s.topic, s.onMessage, s.durable, s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.DeliverAll(), nats.ManualAck(),
nats.ManualAck(),
) )
return err
} }
// onMessage is called when the federation server receives a new event from the room server output log. // onMessage is called when the federation server receives a new event from the room server output log.
// It is unsafe to call this with messages for the same room in multiple gorountines // It is unsafe to call this with messages for the same room in multiple gorountines
// because updates it will likely fail with a types.EventIDMismatchError when it // because updates it will likely fail with a types.EventIDMismatchError when it
// realises that it cannot update the room state using the deltas. // realises that it cannot update the room state using the deltas.
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
// Parse out the event JSON // Parse out the event JSON
var output api.OutputEvent var output api.OutputEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
@ -117,6 +114,11 @@ func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
} }
} }
case api.OutputTypeNewInviteEvent:
log.WithField("type", output.Type).Debug(
"received new invite, send device keys",
)
case api.OutputTypeNewInboundPeek: case api.OutputTypeNewInboundPeek:
if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { if err := s.processInboundPeek(*output.NewInboundPeek); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -125,10 +127,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
}).Panicf("roomserver output log: remote peek event failure") }).Panicf("roomserver output log: remote peek event failure")
return false return false
} }
case api.OutputTypeNewInviteEvent:
log.WithField("type", output.Type).Debug(
"received new invite, send device keys",
)
default: default:
log.WithField("type", output.Type).Debug( log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type", "roomserver output log: ignoring unknown output type",
@ -136,7 +135,6 @@ func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
} }
return true return true
})
} }
// processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any) // processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any)

View file

@ -34,6 +34,7 @@ type OutputCrossSigningKeyUpdateConsumer struct {
keyAPI api.KeyInternalAPI keyAPI api.KeyInternalAPI
serverName string serverName string
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string
topic string topic string
} }
@ -52,6 +53,7 @@ func NewOutputCrossSigningKeyUpdateConsumer(
ctx: process.Context(), ctx: process.Context(),
keyDB: keyDB, keyDB: keyDB,
jetstream: js, jetstream: js,
durable: cfg.Global.JetStream.Durable("KeyServerCrossSigningConsumer"),
topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
keyAPI: keyAPI, keyAPI: keyAPI,
serverName: string(cfg.Global.ServerName), serverName: string(cfg.Global.ServerName),
@ -61,16 +63,15 @@ func NewOutputCrossSigningKeyUpdateConsumer(
} }
func (s *OutputCrossSigningKeyUpdateConsumer) Start() error { func (s *OutputCrossSigningKeyUpdateConsumer) Start() error {
_, err := s.jetstream.Subscribe( return jetstream.JetStreamConsumer(
s.topic, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
) )
return err
} }
// onMessage is called in response to a message received on the // onMessage is called in response to a message received on the
// key change events topic from the key server. // key change events topic from the key server.
func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(msg *nats.Msg) { func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var m api.DeviceMessage var m api.DeviceMessage
if err := json.Unmarshal(msg.Data, &m); err != nil { if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic") logrus.WithError(err).Errorf("failed to read device message from key change topic")
@ -87,7 +88,6 @@ func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(msg *nats.Msg) {
default: default:
return true return true
} }
})
} }
func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {

View file

@ -41,7 +41,7 @@ type RoomserverInternalAPI struct {
fsAPI fsAPI.FederationInternalAPI fsAPI fsAPI.FederationInternalAPI
asAPI asAPI.AppServiceQueryAPI asAPI asAPI.AppServiceQueryAPI
JetStream nats.JetStreamContext JetStream nats.JetStreamContext
Durable nats.SubOpt Durable string
InputRoomEventTopic string // JetStream topic for new input room events InputRoomEventTopic string // JetStream topic for new input room events
OutputRoomEventTopic string // JetStream topic for new output room events OutputRoomEventTopic string // JetStream topic for new output room events
PerspectiveServerNames []gomatrixserverlib.ServerName PerspectiveServerNames []gomatrixserverlib.ServerName
@ -87,7 +87,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
InputRoomEventTopic: r.InputRoomEventTopic, InputRoomEventTopic: r.InputRoomEventTopic,
OutputRoomEventTopic: r.OutputRoomEventTopic, OutputRoomEventTopic: r.OutputRoomEventTopic,
JetStream: r.JetStream, JetStream: r.JetStream,
Durable: r.Durable, Durable: nats.Durable(r.Durable),
ServerName: r.Cfg.Matrix.ServerName, ServerName: r.Cfg.Matrix.ServerName,
FSAPI: fsAPI, FSAPI: fsAPI,
KeyRing: keyRing, KeyRing: keyRing,

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -54,18 +55,23 @@ func (r *Inviter) PerformInvite(
return nil, fmt.Errorf("failed to load RoomInfo: %w", err) return nil, fmt.Errorf("failed to load RoomInfo: %w", err)
} }
log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": roomID,
"room_version": req.RoomVersion,
"target_user_id": targetUserID,
"room_info_exists": info != nil,
}).Debug("processing invite event")
_, domain, _ := gomatrixserverlib.SplitID('@', targetUserID) _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID)
isTargetLocal := domain == r.Cfg.Matrix.ServerName isTargetLocal := domain == r.Cfg.Matrix.ServerName
isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
"inviter": event.Sender(),
"invitee": *event.StateKey(),
"room_id": roomID,
"event_id": event.EventID(),
})
logger.WithFields(log.Fields{
"room_version": req.RoomVersion,
"room_info_exists": info != nil,
"target_local": isTargetLocal,
"origin_local": isOriginLocal,
}).Debug("processing invite event")
inviteState := req.InviteRoomState inviteState := req.InviteRoomState
if len(inviteState) == 0 && info != nil { if len(inviteState) == 0 && info != nil {
var is []gomatrixserverlib.InviteV2StrippedState var is []gomatrixserverlib.InviteV2StrippedState
@ -122,23 +128,49 @@ func (r *Inviter) PerformInvite(
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
Msg: "User is already joined to room", Msg: "User is already joined to room",
} }
logger.Debugf("user already joined")
return nil, nil return nil, nil
} }
if isOriginLocal { if !isOriginLocal {
// The invite originated over federation. Process the membership
// update, which will notify the sync API etc about the incoming
// invite. We do NOT send an InputRoomEvent for the invite as it
// will never pass auth checks due to lacking room state, but we
// still need to tell the client about the invite so we can accept
// it, hence we return an output event to send to the sync api.
updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
if err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
unwrapped := event.Unwrap()
outputUpdates, err := helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion)
if err != nil {
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
}
if err = updater.Commit(); err != nil {
return nil, fmt.Errorf("updater.Commit: %w", err)
}
logger.Debugf("updated membership to invite and sending invite OutputEvent")
return outputUpdates, nil
}
// The invite originated locally. Therefore we have a responsibility to // The invite originated locally. Therefore we have a responsibility to
// try and see if the user is allowed to make this invite. We can't do // try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on // this for invites coming in over federation - we have to take those on
// trust. // trust.
_, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs())
if err != nil { if err != nil {
log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event", "processInviteEvent.checkAuthEvents failed for event",
) )
res.Error = &api.PerformError{ res.Error = &api.PerformError{
Msg: err.Error(), Msg: err.Error(),
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
} }
return nil, nil
} }
// If the invite originated from us and the target isn't local then we // If the invite originated from us and the target isn't local then we
@ -157,16 +189,18 @@ func (r *Inviter) PerformInvite(
Msg: err.Error(), Msg: err.Error(),
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
} }
log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed")
return nil, nil return nil, nil
} }
event = fsRes.Event event = fsRes.Event
logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID())
} }
// Send the invite event to the roomserver input stream. This will // Send the invite event to the roomserver input stream. This will
// notify existing users in the room about the invite, update the // notify existing users in the room about the invite, update the
// membership table and ensure that the event is ready and available // membership table and ensure that the event is ready and available
// to use as an auth event when accepting the invite. // to use as an auth event when accepting the invite.
// It will NOT notify the invitee of this invite.
inputReq := &api.InputRoomEventsRequest{ inputReq := &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{ InputRoomEvents: []api.InputRoomEvent{
{ {
@ -184,31 +218,12 @@ func (r *Inviter) PerformInvite(
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
} }
log.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
return nil, nil return nil, nil
} }
} else {
// The invite originated over federation. Process the membership
// update, which will notify the sync API etc about the incoming
// invite.
updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
if err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
unwrapped := event.Unwrap()
outputUpdates, err := helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion)
if err != nil {
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
}
if err = updater.Commit(); err != nil {
return nil, fmt.Errorf("updater.Commit: %w", err)
}
return outputUpdates, nil
}
// Don't notify the sync api of this event in the same way as a federated invite so the invitee
// gets the invite, as the roomserver will do this when it processes the m.room.member invite.
return nil, nil return nil, nil
} }

View file

@ -2,8 +2,6 @@ package config
import ( import (
"fmt" "fmt"
"github.com/nats-io/nats.go"
) )
type JetStream struct { type JetStream struct {
@ -25,8 +23,8 @@ func (c *JetStream) TopicFor(name string) string {
return fmt.Sprintf("%s%s", c.TopicPrefix, name) return fmt.Sprintf("%s%s", c.TopicPrefix, name)
} }
func (c *JetStream) Durable(name string) nats.SubOpt { func (c *JetStream) Durable(name string) string {
return nats.Durable(c.TopicFor(name)) return c.TopicFor(name)
} }
func (c *JetStream) Defaults(generate bool) { func (c *JetStream) Defaults(generate bool) {

View file

@ -1,12 +1,81 @@
package jetstream package jetstream
import "github.com/nats-io/nats.go" import (
"context"
"fmt"
func WithJetStreamMessage(msg *nats.Msg, f func(msg *nats.Msg) bool) { "github.com/nats-io/nats.go"
_ = msg.InProgress() "github.com/sirupsen/logrus"
if f(msg) { )
_ = msg.Ack()
} else { func JetStreamConsumer(
_ = msg.Nak() ctx context.Context, js nats.JetStreamContext, subj, durable string,
f func(ctx context.Context, msg *nats.Msg) bool,
opts ...nats.SubOpt,
) error {
defer func() {
// If there are existing consumers from before they were pull
// consumers, we need to clean up the old push consumers. However,
// in order to not affect the interest-based policies, we need to
// do this *after* creating the new pull consumers, which have
// "Pull" suffixed to their name.
if _, err := js.ConsumerInfo(subj, durable); err == nil {
if err := js.DeleteConsumer(subj, durable); err != nil {
logrus.WithContext(ctx).Warnf("Failed to clean up old consumer %q", durable)
} }
}
}()
name := durable + "Pull"
sub, err := js.PullSubscribe(subj, name, opts...)
if err != nil {
return fmt.Errorf("nats.SubscribeSync: %w", err)
}
go func() {
for {
// The context behaviour here is surprising — we supply a context
// so that we can interrupt the fetch if we want, but NATS will still
// enforce its own deadline (roughly 5 seconds by default). Therefore
// it is our responsibility to check whether our context expired or
// not when a context error is returned. Footguns. Footguns everywhere.
msgs, err := sub.Fetch(1, nats.Context(ctx))
if err != nil {
if err == context.Canceled || err == context.DeadlineExceeded {
// Work out whether it was the JetStream context that expired
// or whether it was our supplied context.
select {
case <-ctx.Done():
// The supplied context expired, so we want to stop the
// consumer altogether.
return
default:
// The JetStream context expired, so the fetch probably
// just timed out and we should try again.
continue
}
} else {
// Something else went wrong, so we'll panic.
logrus.WithContext(ctx).WithField("subject", subj).Fatal(err)
}
}
if len(msgs) < 1 {
continue
}
msg := msgs[0]
if err = msg.InProgress(); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
continue
}
if f(ctx, msg) {
if err = msg.Ack(); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Ack: %w", err))
}
} else {
if err = msg.Nak(); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
}
}
}
}()
return nil
} }

View file

@ -34,7 +34,7 @@ import (
type OutputClientDataConsumer struct { type OutputClientDataConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable nats.SubOpt durable string
topic string topic string
db storage.Database db storage.Database
stream types.StreamProvider stream types.StreamProvider
@ -63,15 +63,16 @@ func NewOutputClientDataConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputClientDataConsumer) Start() error { func (s *OutputClientDataConsumer) Start() error {
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) return jetstream.JetStreamConsumer(
return err s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
)
} }
// onMessage is called when the sync server receives a new event from the client API server output log. // onMessage is called when the sync server receives a new event from the client API server output log.
// It is not safe for this function to be called from multiple goroutines, or else the // It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated. // sync stream position may race and be incorrectly calculated.
func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) { func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
// Parse out the event JSON // Parse out the event JSON
userID := msg.Header.Get(jetstream.UserID) userID := msg.Header.Get(jetstream.UserID)
var output eventutil.AccountData var output eventutil.AccountData
@ -103,5 +104,4 @@ func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) {
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
return true return true
})
} }

View file

@ -34,7 +34,7 @@ import (
type OutputReceiptEventConsumer struct { type OutputReceiptEventConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable nats.SubOpt durable string
topic string topic string
db storage.Database db storage.Database
stream types.StreamProvider stream types.StreamProvider
@ -64,12 +64,13 @@ func NewOutputReceiptEventConsumer(
// Start consuming from EDU api // Start consuming from EDU api
func (s *OutputReceiptEventConsumer) Start() error { func (s *OutputReceiptEventConsumer) Start() error {
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) return jetstream.JetStreamConsumer(
return err s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
)
} }
func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var output api.OutputReceiptEvent var output api.OutputReceiptEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream // If the message was invalid, log it and move on to the next message in the stream
@ -95,5 +96,4 @@ func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) {
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return true return true
})
} }

View file

@ -36,7 +36,7 @@ import (
type OutputSendToDeviceEventConsumer struct { type OutputSendToDeviceEventConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable nats.SubOpt durable string
topic string topic string
db storage.Database db storage.Database
serverName gomatrixserverlib.ServerName // our server name serverName gomatrixserverlib.ServerName // our server name
@ -68,12 +68,13 @@ func NewOutputSendToDeviceEventConsumer(
// Start consuming from EDU api // Start consuming from EDU api
func (s *OutputSendToDeviceEventConsumer) Start() error { func (s *OutputSendToDeviceEventConsumer) Start() error {
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) return jetstream.JetStreamConsumer(
return err s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
)
} }
func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var output api.OutputSendToDeviceEvent var output api.OutputSendToDeviceEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream // If the message was invalid, log it and move on to the next message in the stream
@ -115,5 +116,4 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) {
) )
return true return true
})
} }

View file

@ -35,7 +35,7 @@ import (
type OutputTypingEventConsumer struct { type OutputTypingEventConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable nats.SubOpt durable string
topic string topic string
eduCache *cache.EDUCache eduCache *cache.EDUCache
stream types.StreamProvider stream types.StreamProvider
@ -66,12 +66,13 @@ func NewOutputTypingEventConsumer(
// Start consuming from EDU api // Start consuming from EDU api
func (s *OutputTypingEventConsumer) Start() error { func (s *OutputTypingEventConsumer) Start() error {
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) return jetstream.JetStreamConsumer(
return err s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
)
} }
func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var output api.OutputTypingEvent var output api.OutputTypingEvent
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream // If the message was invalid, log it and move on to the next message in the stream
@ -102,5 +103,4 @@ func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) {
s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos})
return true return true
})
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
@ -35,6 +36,7 @@ import (
type OutputKeyChangeEventConsumer struct { type OutputKeyChangeEventConsumer struct {
ctx context.Context ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable string
topic string topic string
db storage.Database db storage.Database
notifier *notifier.Notifier notifier *notifier.Notifier
@ -48,7 +50,7 @@ type OutputKeyChangeEventConsumer struct {
// Call Start() to begin consuming from the key server. // Call Start() to begin consuming from the key server.
func NewOutputKeyChangeEventConsumer( func NewOutputKeyChangeEventConsumer(
process *process.ProcessContext, process *process.ProcessContext,
serverName gomatrixserverlib.ServerName, cfg *config.SyncAPI,
topic string, topic string,
js nats.JetStreamContext, js nats.JetStreamContext,
keyAPI api.KeyInternalAPI, keyAPI api.KeyInternalAPI,
@ -60,9 +62,10 @@ func NewOutputKeyChangeEventConsumer(
s := &OutputKeyChangeEventConsumer{ s := &OutputKeyChangeEventConsumer{
ctx: process.Context(), ctx: process.Context(),
jetstream: js, jetstream: js,
durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"),
topic: topic, topic: topic,
db: store, db: store,
serverName: serverName, serverName: cfg.Matrix.ServerName,
keyAPI: keyAPI, keyAPI: keyAPI,
rsAPI: rsAPI, rsAPI: rsAPI,
notifier: notifier, notifier: notifier,
@ -74,14 +77,13 @@ func NewOutputKeyChangeEventConsumer(
// Start consuming from the key server // Start consuming from the key server
func (s *OutputKeyChangeEventConsumer) Start() error { func (s *OutputKeyChangeEventConsumer) Start() error {
_, err := s.jetstream.Subscribe( return jetstream.JetStreamConsumer(
s.topic, s.onMessage, s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
) )
return err
} }
func (s *OutputKeyChangeEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var m api.DeviceMessage var m api.DeviceMessage
if err := json.Unmarshal(msg.Data, &m); err != nil { if err := json.Unmarshal(msg.Data, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic") logrus.WithError(err).Errorf("failed to read device message from key change topic")
@ -100,8 +102,6 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *nats.Msg) {
default: default:
return s.onDeviceKeyMessage(m, m.DeviceChangeID) return s.onDeviceKeyMessage(m, m.DeviceChangeID)
} }
})
} }
func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) bool { func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) bool {

View file

@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct {
cfg *config.SyncAPI cfg *config.SyncAPI
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
durable nats.SubOpt durable string
topic string topic string
db storage.Database db storage.Database
pduStream types.StreamProvider pduStream types.StreamProvider
@ -73,19 +73,16 @@ func NewOutputRoomEventConsumer(
// Start consuming from room servers // Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
_, err := s.jetstream.Subscribe( return jetstream.JetStreamConsumer(
s.topic, s.onMessage, s.durable, s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.DeliverAll(), nats.ManualAck(),
nats.ManualAck(),
) )
return err
} }
// onMessage is called when the sync server receives a new event from the room server output log. // onMessage is called when the sync server receives a new event from the room server output log.
// It is not safe for this function to be called from multiple goroutines, or else the // It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated. // sync stream position may race and be incorrectly calculated.
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
// Parse out the event JSON // Parse out the event JSON
var err error var err error
var output api.OutputEvent var output api.OutputEvent
@ -131,7 +128,6 @@ func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
} }
return true return true
})
} }
func (s *OutputRoomEventConsumer) onRedactEvent( func (s *OutputRoomEventConsumer) onRedactEvent(

View file

@ -65,7 +65,7 @@ func AddPublicRoutes(
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier) requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier)
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
process, cfg.Matrix.ServerName, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
js, keyAPI, rsAPI, syncDB, notifier, js, keyAPI, rsAPI, syncDB, notifier,
streams.DeviceListStreamProvider, streams.DeviceListStreamProvider,
) )