diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index 9d108ac32..dbb5b6ead 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -17,16 +17,23 @@ package internal import ( "context" "encoding/json" + "strconv" "testing" + "time" "github.com/matrix-org/dendrite/federationapi/producers" + keyAPI "github.com/matrix-org/dendrite/keyserver/api" rsAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + "gotest.tools/v3/poll" ) var ( @@ -84,7 +91,7 @@ func (r *FakeRsAPI) InputRoomEvents( } func TestEmptyTransactionRequest(t *testing.T) { - txn := NewTxnReq(&FakeRsAPI{}, nil, "", nil, nil, nil, false, []json.RawMessage{}, []gomatrixserverlib.EDU{}, "", "", "") + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", nil, nil, nil, false, []json.RawMessage{}, []gomatrixserverlib.EDU{}, "", "", "") txnRes, jsonRes := txn.ProcessTransaction(context.Background()) assert.Nil(t, jsonRes) @@ -92,7 +99,7 @@ func TestEmptyTransactionRequest(t *testing.T) { } func TestProcessTransactionRequestPDU(t *testing.T) { - txn := NewTxnReq(&FakeRsAPI{}, nil, "", nil, nil, nil, false, []json.RawMessage{testData[0]}, []gomatrixserverlib.EDU{}, "", "", "") + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", nil, nil, nil, false, []json.RawMessage{testData[0]}, []gomatrixserverlib.EDU{}, "", "", "") txnRes, jsonRes := txn.ProcessTransaction(context.Background()) assert.Nil(t, jsonRes) @@ -101,43 +108,388 @@ func TestProcessTransactionRequestPDU(t *testing.T) { func TestProcessTransactionRequestPDUs(t *testing.T) { keyRing := &test.NopJSONVerifier{} - txn := NewTxnReq(&FakeRsAPI{}, nil, "", keyRing, nil, nil, false, testData, []gomatrixserverlib.EDU{}, "", "", "") + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, testData, []gomatrixserverlib.EDU{}, "", "", "") txnRes, jsonRes := txn.ProcessTransaction(context.Background()) assert.Nil(t, jsonRes) assert.Equal(t, 1, len(txnRes.PDUs)) } -func TestProcessTransactionRequestEDU(t *testing.T) { - var err error - edu := gomatrixserverlib.EDU{Type: "m.typing"} - if edu.Content, err = json.Marshal(map[string]interface{}{ - "room_id": "!roomid:kaer.morhen", - "user_id": "@userid:kaer.morhen", - "typing": true, - }); err != nil { - t.Errorf("failed to marshal EDU JSON") - } - +func createTransactionWithEDU(edu gomatrixserverlib.EDU) (TxnReq, nats.JetStreamContext, *config.Dendrite) { cfg := &config.Dendrite{} cfg.Defaults(config.DefaultOpts{ Generate: true, Monolithic: true, }) - nats := &jetstream.NATSInstance{} - js, _ := nats.Prepare(process.NewProcessContext(), &cfg.Global.JetStream) + cfg.Global.JetStream.InMemory = true + natsInstance := &jetstream.NATSInstance{} + pc := process.NewProcessContext() + js, _ := natsInstance.Prepare(pc, &cfg.Global.JetStream) producer := &producers.SyncAPIProducer{ JetStream: js, - TopicReceiptEvent: "OutputReceiptEvent", - TopicSendToDeviceEvent: "OutputSendToDeviceEvent", - TopicTypingEvent: "OutputTypingEvent", - TopicPresenceEvent: "OutputPresenceEvent", - TopicDeviceListUpdate: "InputDeviceListUpdate", - TopicSigningKeyUpdate: "InputSigningKeyUpdate", + TopicReceiptEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), Config: &cfg.FederationAPI, UserAPI: nil, } - txn := NewTxnReq(&FakeRsAPI{}, nil, "", nil, nil, producer, false, []json.RawMessage{}, []gomatrixserverlib.EDU{edu}, "", "", "") + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, producer, true, []json.RawMessage{}, []gomatrixserverlib.EDU{edu}, "kaer.morhen", "", "ourserver") + return txn, js, cfg +} + +func TestProcessTransactionRequestEDUTyping(t *testing.T) { + var err error + roomID := "!roomid:kaer.morhen" + userID := "@userid:kaer.morhen" + typing := true + edu := gomatrixserverlib.EDU{Type: "m.typing"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": roomID, + "user_id": userID, + "typing": typing, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + txn, js, cfg := createTransactionWithEDU(edu) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + room := msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, room) + user := msg.Header.Get(jetstream.UserID) + assert.Equal(t, userID, user) + typ, parseErr := strconv.ParseBool(msg.Header.Get("typing")) + if parseErr != nil { + return true + } + assert.Equal(t, typing, typ) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx, js, cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + cfg.Global.JetStream.Durable("TestTypingConsumer"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUToDevice(t *testing.T) { + var err error + sender := "@userid:kaer.morhen" + messageID := "$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg" + msgType := "m.dendrite.test" + edu := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "sender": sender, + "type": msgType, + "message_id": messageID, + "messages": map[string]interface{}{ + "@alice:example.org": map[string]interface{}{ + "IWHQUZUIAH": map[string]interface{}{ + "algorithm": "m.megolm.v1.aes-sha2", + "room_id": "!Cuyf34gef24t:localhost", + "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ", + "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY...", + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + txn, js, cfg := createTransactionWithEDU(edu) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputSendToDeviceEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, sender, output.Sender) + assert.Equal(t, msgType, output.Type) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx, js, cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + cfg.Global.JetStream.Durable("TestToDevice"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUDeviceListUpdate(t *testing.T) { + var err error + deviceID := "QBUAZIFURK" + userID := "@john:example.com" + edu := gomatrixserverlib.EDU{Type: "m.device_list_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "device_display_name": "Mobile", + "device_id": deviceID, + "key": "value", + "keys": map[string]interface{}{ + "algorithms": []string{ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", + }, + "device_id": "JLAFKJWSCS", + "keys": map[string]interface{}{ + "curve25519:JLAFKJWSCS": "3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI", + "ed25519:JLAFKJWSCS": "lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI", + }, + "signatures": map[string]interface{}{ + "@alice:example.com": map[string]interface{}{ + "ed25519:JLAFKJWSCS": "dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA", + }, + }, + "user_id": "@alice:example.com", + }, + "prev_id": []int{ + 5, + }, + "stream_id": 6, + "user_id": userID, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + txn, js, cfg := createTransactionWithEDU(edu) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output gomatrixserverlib.DeviceListUpdateEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, userID, output.UserID) + assert.Equal(t, deviceID, output.DeviceID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx, js, cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + cfg.Global.JetStream.Durable("TestDeviceListUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUReceipt(t *testing.T) { + var err error + roomID := "!some_room:example.org" + edu := gomatrixserverlib.EDU{Type: "m.receipt"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:kaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + txn, js, cfg := createTransactionWithEDU(edu) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputReceiptEvent + output.RoomID = msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, output.RoomID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx, js, cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + cfg.Global.JetStream.Durable("TestReceipt"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUSigningKeyUpdate(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + txn, js, cfg := createTransactionWithEDU(edu) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output keyAPI.CrossSigningKeyUpdate + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx, js, cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + cfg.Global.JetStream.Durable("TestSigningKeyUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUPresence(t *testing.T) { + var err error + userID := "@john:kaer.morhen" + presence := "online" + edu := gomatrixserverlib.EDU{Type: "m.presence"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "push": []map[string]interface{}{{ + "currently_active": true, + "last_active_ago": 5000, + "presence": presence, + "status_msg": "Making cupcakes", + "user_id": userID, + }}, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + txn, js, cfg := createTransactionWithEDU(edu) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + userIDRes := msg.Header.Get(jetstream.UserID) + presenceRes := msg.Header.Get("presence") + assert.Equal(t, userID, userIDRes) + assert.Equal(t, presence, presenceRes) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx, js, cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + cfg.Global.JetStream.Durable("TestPresence"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUUnhandled(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.unhandled"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + txn, _, _ := createTransactionWithEDU(edu) txnRes, jsonRes := txn.ProcessTransaction(context.Background()) assert.Nil(t, jsonRes)