Merge branch 'master' into neilalexander/rstxn
This commit is contained in:
commit
e8f58acf03
|
@ -34,7 +34,7 @@ import (
|
|||
type OutputRoomEventConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
asDB storage.Database
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
|
@ -66,14 +66,15 @@ func NewOutputRoomEventConsumer(
|
|||
|
||||
// Start consuming from room servers
|
||||
func (s *OutputRoomEventConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
|
||||
return err
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
// onMessage is called when the appservice component receives a new event from
|
||||
// the room server output log.
|
||||
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
// Parse out the event JSON
|
||||
var output api.OutputEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
|
@ -96,7 +97,6 @@ func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
|
|||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// filterRoomserverEvents takes in events and decides whether any of them need
|
||||
|
|
|
@ -281,7 +281,7 @@ func (m *DendriteMonolith) Start() {
|
|||
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
|
||||
cfg.Global.PrivateKey = sk
|
||||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/%s", m.StorageDirectory, prefix))
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix))
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix))
|
||||
|
|
|
@ -86,7 +86,7 @@ func (m *DendriteMonolith) Start() {
|
|||
cfg.Global.ServerName = gomatrixserverlib.ServerName(ygg.DerivedServerName())
|
||||
cfg.Global.PrivateKey = ygg.PrivateKey()
|
||||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/", m.StorageDirectory))
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory))
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
|
||||
|
|
|
@ -12,10 +12,14 @@ COPY . .
|
|||
RUN go build ./cmd/dendrite-monolith-server
|
||||
RUN go build ./cmd/generate-keys
|
||||
RUN go build ./cmd/generate-config
|
||||
RUN ./generate-config --ci > dendrite.yaml
|
||||
RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key
|
||||
RUN ./generate-keys --private-key matrix_key.pem
|
||||
|
||||
ENV SERVER_NAME=localhost
|
||||
EXPOSE 8008 8448
|
||||
|
||||
CMD sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml && ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml
|
||||
# At runtime, generate TLS cert based on the CA now mounted at /ca
|
||||
# At runtime, replace the SERVER_NAME with what we are told
|
||||
CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \
|
||||
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
|
||||
cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
|
||||
./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml
|
||||
|
|
|
@ -83,7 +83,7 @@ func main() {
|
|||
if *defaultsForCI {
|
||||
cfg.AppServiceAPI.DisableTLSValidation = true
|
||||
cfg.ClientAPI.RateLimiting.Enabled = false
|
||||
cfg.FederationAPI.DisableTLSValidation = true
|
||||
cfg.FederationAPI.DisableTLSValidation = false
|
||||
// don't hit matrix.org when running tests!!!
|
||||
cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{}
|
||||
cfg.MSCs.MSCs = []string{"msc2836", "msc2946", "msc2444", "msc2753"}
|
||||
|
|
|
@ -35,6 +35,9 @@ var (
|
|||
tlsCertFile = flag.String("tls-cert", "", "An X509 certificate file to generate for use for TLS")
|
||||
tlsKeyFile = flag.String("tls-key", "", "An RSA private key file to generate for use for TLS")
|
||||
privateKeyFile = flag.String("private-key", "", "An Ed25519 private key to generate for use for object signing")
|
||||
authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
|
||||
authorityKeyFile = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.")
|
||||
serverName = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.")
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -54,9 +57,16 @@ func main() {
|
|||
if *tlsCertFile == "" || *tlsKeyFile == "" {
|
||||
log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied")
|
||||
}
|
||||
if *authorityCertFile == "" && *authorityKeyFile == "" {
|
||||
if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
// generate the TLS cert/key based on the authority given.
|
||||
if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
fmt.Printf("Created TLS cert file: %s\n", *tlsCertFile)
|
||||
fmt.Printf("Created TLS key file: %s\n", *tlsKeyFile)
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ import (
|
|||
type OutputEDUConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
db storage.Database
|
||||
queues *queue.OutgoingQueues
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
|
@ -66,13 +66,22 @@ func NewOutputEDUConsumer(
|
|||
|
||||
// Start consuming from EDU servers
|
||||
func (t *OutputEDUConsumer) Start() error {
|
||||
if _, err := t.jetstream.Subscribe(t.typingTopic, t.onTypingEvent, t.durable); err != nil {
|
||||
if err := jetstream.JetStreamConsumer(
|
||||
t.ctx, t.jetstream, t.typingTopic, t.durable, t.onTypingEvent,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := t.jetstream.Subscribe(t.sendToDeviceTopic, t.onSendToDeviceEvent, t.durable); err != nil {
|
||||
if err := jetstream.JetStreamConsumer(
|
||||
t.ctx, t.jetstream, t.sendToDeviceTopic, t.durable, t.onSendToDeviceEvent,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := t.jetstream.Subscribe(t.receiptTopic, t.onReceiptEvent, t.durable); err != nil {
|
||||
if err := jetstream.JetStreamConsumer(
|
||||
t.ctx, t.jetstream, t.receiptTopic, t.durable, t.onReceiptEvent,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -80,9 +89,8 @@ func (t *OutputEDUConsumer) Start() error {
|
|||
|
||||
// onSendToDeviceEvent is called in response to a message received on the
|
||||
// send-to-device events topic from the EDU server.
|
||||
func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) {
|
||||
func (t *OutputEDUConsumer) onSendToDeviceEvent(ctx context.Context, msg *nats.Msg) bool {
|
||||
// Extract the send-to-device event from msg.
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
var ote api.OutputSendToDeviceEvent
|
||||
if err := json.Unmarshal(msg.Data, &ote); err != nil {
|
||||
log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)")
|
||||
|
@ -133,13 +141,11 @@ func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) {
|
|||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// onTypingEvent is called in response to a message received on the typing
|
||||
// events topic from the EDU server.
|
||||
func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
func (t *OutputEDUConsumer) onTypingEvent(ctx context.Context, msg *nats.Msg) bool {
|
||||
// Extract the typing event from msg.
|
||||
var ote api.OutputTypingEvent
|
||||
if err := json.Unmarshal(msg.Data, &ote); err != nil {
|
||||
|
@ -160,7 +166,7 @@ func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) {
|
|||
return true
|
||||
}
|
||||
|
||||
joined, err := t.db.GetJoinedHosts(t.ctx, ote.Event.RoomID)
|
||||
joined, err := t.db.GetJoinedHosts(ctx, ote.Event.RoomID)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room")
|
||||
return false
|
||||
|
@ -187,13 +193,11 @@ func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) {
|
|||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// onReceiptEvent is called in response to a message received on the receipt
|
||||
// events topic from the EDU server.
|
||||
func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
func (t *OutputEDUConsumer) onReceiptEvent(ctx context.Context, msg *nats.Msg) bool {
|
||||
// Extract the typing event from msg.
|
||||
var receipt api.OutputReceiptEvent
|
||||
if err := json.Unmarshal(msg.Data, &receipt); err != nil {
|
||||
|
@ -212,7 +216,7 @@ func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) {
|
|||
return true
|
||||
}
|
||||
|
||||
joined, err := t.db.GetJoinedHosts(t.ctx, receipt.RoomID)
|
||||
joined, err := t.db.GetJoinedHosts(ctx, receipt.RoomID)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room")
|
||||
return false
|
||||
|
@ -250,5 +254,4 @@ func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) {
|
|||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ type OutputRoomEventConsumer struct {
|
|||
cfg *config.FederationAPI
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
db storage.Database
|
||||
queues *queue.OutgoingQueues
|
||||
topic string
|
||||
|
@ -66,20 +66,17 @@ func NewOutputRoomEventConsumer(
|
|||
|
||||
// Start consuming from room servers
|
||||
func (s *OutputRoomEventConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(
|
||||
s.topic, s.onMessage, s.durable,
|
||||
nats.DeliverAll(),
|
||||
nats.ManualAck(),
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// onMessage is called when the federation server receives a new event from the room server output log.
|
||||
// It is unsafe to call this with messages for the same room in multiple gorountines
|
||||
// because updates it will likely fail with a types.EventIDMismatchError when it
|
||||
// realises that it cannot update the room state using the deltas.
|
||||
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
// Parse out the event JSON
|
||||
var output api.OutputEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
|
@ -133,7 +130,6 @@ func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
|
|||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any)
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
|
@ -158,11 +159,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
|
|||
|
||||
const certificateDuration = time.Hour * 24 * 365 * 10
|
||||
|
||||
// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file.
|
||||
func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
||||
func generateTLSTemplate(dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
|
@ -170,7 +170,7 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
|||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
|
@ -180,20 +180,21 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
|||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: dnsNames,
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return priv, &template, nil
|
||||
}
|
||||
|
||||
func writeCertificate(tlsCertPath string, derBytes []byte) error {
|
||||
certOut, err := os.Create(tlsCertPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer certOut.Close() // nolint: errcheck
|
||||
if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
return err
|
||||
}
|
||||
return pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
}
|
||||
|
||||
func writePrivateKey(tlsKeyPath string, priv *rsa.PrivateKey) error {
|
||||
keyOut, err := os.OpenFile(tlsKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -205,3 +206,73 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
|||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file.
|
||||
func NewTLSKey(tlsKeyPath, tlsCertPath string) error {
|
||||
priv, template, err := generateTLSTemplate(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Self-signed certificate: template == parent
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = writeCertificate(tlsCertPath, derBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return writePrivateKey(tlsKeyPath, priv)
|
||||
}
|
||||
|
||||
func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string) error {
|
||||
priv, template, err := generateTLSTemplate([]string{serverName})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the authority key
|
||||
dat, err := ioutil.ReadFile(authorityKeyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
block, _ := pem.Decode([]byte(dat))
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
return errors.New("authority .key is not a valid pem encoded rsa private key")
|
||||
}
|
||||
authorityPriv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load the authority certificate
|
||||
dat, err = ioutil.ReadFile(authorityCertPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
block, _ = pem.Decode([]byte(dat))
|
||||
if block == nil || block.Type != "CERTIFICATE" {
|
||||
return errors.New("authority .crt is not a valid pem encoded x509 cert")
|
||||
}
|
||||
var caCerts []*x509.Certificate
|
||||
caCerts, err = x509.ParseCertificates(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(caCerts) != 1 {
|
||||
return errors.New("authority .crt contains none or more than one cert")
|
||||
}
|
||||
authorityCert := caCerts[0]
|
||||
|
||||
// Sign the new certificate using the authority's key/cert
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, authorityCert, &priv.PublicKey, authorityPriv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = writeCertificate(tlsCertPath, derBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return writePrivateKey(tlsKeyPath, priv)
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ type RoomserverInternalAPI struct {
|
|||
fsAPI fsAPI.FederationInternalAPI
|
||||
asAPI asAPI.AppServiceQueryAPI
|
||||
JetStream nats.JetStreamContext
|
||||
Durable nats.SubOpt
|
||||
Durable string
|
||||
InputRoomEventTopic string // JetStream topic for new input room events
|
||||
OutputRoomEventTopic string // JetStream topic for new output room events
|
||||
PerspectiveServerNames []gomatrixserverlib.ServerName
|
||||
|
@ -87,7 +87,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
|
|||
InputRoomEventTopic: r.InputRoomEventTopic,
|
||||
OutputRoomEventTopic: r.OutputRoomEventTopic,
|
||||
JetStream: r.JetStream,
|
||||
Durable: r.Durable,
|
||||
Durable: nats.Durable(r.Durable),
|
||||
ServerName: r.Cfg.Matrix.ServerName,
|
||||
FSAPI: fsAPI,
|
||||
KeyRing: keyRing,
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -54,18 +55,23 @@ func (r *Inviter) PerformInvite(
|
|||
return nil, fmt.Errorf("failed to load RoomInfo: %w", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": event.EventID(),
|
||||
"room_id": roomID,
|
||||
"room_version": req.RoomVersion,
|
||||
"target_user_id": targetUserID,
|
||||
"room_info_exists": info != nil,
|
||||
}).Debug("processing invite event")
|
||||
|
||||
_, domain, _ := gomatrixserverlib.SplitID('@', targetUserID)
|
||||
isTargetLocal := domain == r.Cfg.Matrix.ServerName
|
||||
isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName
|
||||
|
||||
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
|
||||
"inviter": event.Sender(),
|
||||
"invitee": *event.StateKey(),
|
||||
"room_id": roomID,
|
||||
"event_id": event.EventID(),
|
||||
})
|
||||
logger.WithFields(log.Fields{
|
||||
"room_version": req.RoomVersion,
|
||||
"room_info_exists": info != nil,
|
||||
"target_local": isTargetLocal,
|
||||
"origin_local": isOriginLocal,
|
||||
}).Debug("processing invite event")
|
||||
|
||||
inviteState := req.InviteRoomState
|
||||
if len(inviteState) == 0 && info != nil {
|
||||
var is []gomatrixserverlib.InviteV2StrippedState
|
||||
|
@ -122,23 +128,49 @@ func (r *Inviter) PerformInvite(
|
|||
Code: api.PerformErrorNotAllowed,
|
||||
Msg: "User is already joined to room",
|
||||
}
|
||||
logger.Debugf("user already joined")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if isOriginLocal {
|
||||
if !isOriginLocal {
|
||||
// The invite originated over federation. Process the membership
|
||||
// update, which will notify the sync API etc about the incoming
|
||||
// invite. We do NOT send an InputRoomEvent for the invite as it
|
||||
// will never pass auth checks due to lacking room state, but we
|
||||
// still need to tell the client about the invite so we can accept
|
||||
// it, hence we return an output event to send to the sync api.
|
||||
updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
|
||||
}
|
||||
|
||||
unwrapped := event.Unwrap()
|
||||
outputUpdates, err := helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
|
||||
}
|
||||
|
||||
if err = updater.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("updater.Commit: %w", err)
|
||||
}
|
||||
logger.Debugf("updated membership to invite and sending invite OutputEvent")
|
||||
return outputUpdates, nil
|
||||
}
|
||||
|
||||
// The invite originated locally. Therefore we have a responsibility to
|
||||
// try and see if the user is allowed to make this invite. We can't do
|
||||
// this for invites coming in over federation - we have to take those on
|
||||
// trust.
|
||||
_, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs())
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
|
||||
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
|
||||
"processInviteEvent.checkAuthEvents failed for event",
|
||||
)
|
||||
res.Error = &api.PerformError{
|
||||
Msg: err.Error(),
|
||||
Code: api.PerformErrorNotAllowed,
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If the invite originated from us and the target isn't local then we
|
||||
|
@ -157,16 +189,18 @@ func (r *Inviter) PerformInvite(
|
|||
Msg: err.Error(),
|
||||
Code: api.PerformErrorNotAllowed,
|
||||
}
|
||||
log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed")
|
||||
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed")
|
||||
return nil, nil
|
||||
}
|
||||
event = fsRes.Event
|
||||
logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID())
|
||||
}
|
||||
|
||||
// Send the invite event to the roomserver input stream. This will
|
||||
// notify existing users in the room about the invite, update the
|
||||
// membership table and ensure that the event is ready and available
|
||||
// to use as an auth event when accepting the invite.
|
||||
// It will NOT notify the invitee of this invite.
|
||||
inputReq := &api.InputRoomEventsRequest{
|
||||
InputRoomEvents: []api.InputRoomEvent{
|
||||
{
|
||||
|
@ -184,31 +218,12 @@ func (r *Inviter) PerformInvite(
|
|||
Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()),
|
||||
Code: api.PerformErrorNotAllowed,
|
||||
}
|
||||
log.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
|
||||
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
|
||||
return nil, nil
|
||||
}
|
||||
} else {
|
||||
// The invite originated over federation. Process the membership
|
||||
// update, which will notify the sync API etc about the incoming
|
||||
// invite.
|
||||
updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
|
||||
}
|
||||
|
||||
unwrapped := event.Unwrap()
|
||||
outputUpdates, err := helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
|
||||
}
|
||||
|
||||
if err = updater.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("updater.Commit: %w", err)
|
||||
}
|
||||
|
||||
return outputUpdates, nil
|
||||
}
|
||||
|
||||
// Don't notify the sync api of this event in the same way as a federated invite so the invitee
|
||||
// gets the invite, as the roomserver will do this when it processes the m.room.member invite.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,8 +2,6 @@ package config
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
type JetStream struct {
|
||||
|
@ -25,8 +23,8 @@ func (c *JetStream) TopicFor(name string) string {
|
|||
return fmt.Sprintf("%s%s", c.TopicPrefix, name)
|
||||
}
|
||||
|
||||
func (c *JetStream) Durable(name string) nats.SubOpt {
|
||||
return nats.Durable(c.TopicFor(name))
|
||||
func (c *JetStream) Durable(name string) string {
|
||||
return c.TopicFor(name)
|
||||
}
|
||||
|
||||
func (c *JetStream) Defaults(generate bool) {
|
||||
|
|
|
@ -1,12 +1,81 @@
|
|||
package jetstream
|
||||
|
||||
import "github.com/nats-io/nats.go"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
func WithJetStreamMessage(msg *nats.Msg, f func(msg *nats.Msg) bool) {
|
||||
_ = msg.InProgress()
|
||||
if f(msg) {
|
||||
_ = msg.Ack()
|
||||
} else {
|
||||
_ = msg.Nak()
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func JetStreamConsumer(
|
||||
ctx context.Context, js nats.JetStreamContext, subj, durable string,
|
||||
f func(ctx context.Context, msg *nats.Msg) bool,
|
||||
opts ...nats.SubOpt,
|
||||
) error {
|
||||
defer func() {
|
||||
// If there are existing consumers from before they were pull
|
||||
// consumers, we need to clean up the old push consumers. However,
|
||||
// in order to not affect the interest-based policies, we need to
|
||||
// do this *after* creating the new pull consumers, which have
|
||||
// "Pull" suffixed to their name.
|
||||
if _, err := js.ConsumerInfo(subj, durable); err == nil {
|
||||
if err := js.DeleteConsumer(subj, durable); err != nil {
|
||||
logrus.WithContext(ctx).Warnf("Failed to clean up old consumer %q", durable)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
name := durable + "Pull"
|
||||
sub, err := js.PullSubscribe(subj, name, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nats.SubscribeSync: %w", err)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
// The context behaviour here is surprising — we supply a context
|
||||
// so that we can interrupt the fetch if we want, but NATS will still
|
||||
// enforce its own deadline (roughly 5 seconds by default). Therefore
|
||||
// it is our responsibility to check whether our context expired or
|
||||
// not when a context error is returned. Footguns. Footguns everywhere.
|
||||
msgs, err := sub.Fetch(1, nats.Context(ctx))
|
||||
if err != nil {
|
||||
if err == context.Canceled || err == context.DeadlineExceeded {
|
||||
// Work out whether it was the JetStream context that expired
|
||||
// or whether it was our supplied context.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// The supplied context expired, so we want to stop the
|
||||
// consumer altogether.
|
||||
return
|
||||
default:
|
||||
// The JetStream context expired, so the fetch probably
|
||||
// just timed out and we should try again.
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// Something else went wrong, so we'll panic.
|
||||
logrus.WithContext(ctx).WithField("subject", subj).Fatal(err)
|
||||
}
|
||||
}
|
||||
if len(msgs) < 1 {
|
||||
continue
|
||||
}
|
||||
msg := msgs[0]
|
||||
if err = msg.InProgress(); err != nil {
|
||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
|
||||
continue
|
||||
}
|
||||
if f(ctx, msg) {
|
||||
if err = msg.Ack(); err != nil {
|
||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Ack: %w", err))
|
||||
}
|
||||
} else {
|
||||
if err = msg.Nak(); err != nil {
|
||||
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ import (
|
|||
type OutputClientDataConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
stream types.StreamProvider
|
||||
|
@ -63,15 +63,16 @@ func NewOutputClientDataConsumer(
|
|||
|
||||
// Start consuming from room servers
|
||||
func (s *OutputClientDataConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
|
||||
return err
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
// onMessage is called when the sync server receives a new event from the client API server output log.
|
||||
// It is not safe for this function to be called from multiple goroutines, or else the
|
||||
// sync stream position may race and be incorrectly calculated.
|
||||
func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
// Parse out the event JSON
|
||||
userID := msg.Header.Get(jetstream.UserID)
|
||||
var output eventutil.AccountData
|
||||
|
@ -103,5 +104,4 @@ func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) {
|
|||
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ import (
|
|||
type OutputReceiptEventConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
stream types.StreamProvider
|
||||
|
@ -64,12 +64,13 @@ func NewOutputReceiptEventConsumer(
|
|||
|
||||
// Start consuming from EDU api
|
||||
func (s *OutputReceiptEventConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
|
||||
return err
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var output api.OutputReceiptEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
|
@ -95,5 +96,4 @@ func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) {
|
|||
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ import (
|
|||
type OutputSendToDeviceEventConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
serverName gomatrixserverlib.ServerName // our server name
|
||||
|
@ -68,12 +68,13 @@ func NewOutputSendToDeviceEventConsumer(
|
|||
|
||||
// Start consuming from EDU api
|
||||
func (s *OutputSendToDeviceEventConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
|
||||
return err
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var output api.OutputSendToDeviceEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
|
@ -115,5 +116,4 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) {
|
|||
)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ import (
|
|||
type OutputTypingEventConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
eduCache *cache.EDUCache
|
||||
stream types.StreamProvider
|
||||
|
@ -66,12 +66,13 @@ func NewOutputTypingEventConsumer(
|
|||
|
||||
// Start consuming from EDU api
|
||||
func (s *OutputTypingEventConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable)
|
||||
return err
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
var output api.OutputTypingEvent
|
||||
if err := json.Unmarshal(msg.Data, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
|
@ -102,5 +103,4 @@ func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) {
|
|||
s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos})
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct {
|
|||
cfg *config.SyncAPI
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
jetstream nats.JetStreamContext
|
||||
durable nats.SubOpt
|
||||
durable string
|
||||
topic string
|
||||
db storage.Database
|
||||
pduStream types.StreamProvider
|
||||
|
@ -73,19 +73,16 @@ func NewOutputRoomEventConsumer(
|
|||
|
||||
// Start consuming from room servers
|
||||
func (s *OutputRoomEventConsumer) Start() error {
|
||||
_, err := s.jetstream.Subscribe(
|
||||
s.topic, s.onMessage, s.durable,
|
||||
nats.DeliverAll(),
|
||||
nats.ManualAck(),
|
||||
return jetstream.JetStreamConsumer(
|
||||
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
|
||||
nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// onMessage is called when the sync server receives a new event from the room server output log.
|
||||
// It is not safe for this function to be called from multiple goroutines, or else the
|
||||
// sync stream position may race and be incorrectly calculated.
|
||||
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
|
||||
jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
|
||||
func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
|
||||
// Parse out the event JSON
|
||||
var err error
|
||||
var output api.OutputEvent
|
||||
|
@ -131,7 +128,6 @@ func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
|
|||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) onRedactEvent(
|
||||
|
|
Loading…
Reference in a new issue