diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go index 2734fef3e..fd0c1c56b 100644 --- a/syncapi/storage/postgres/send_to_device_table.go +++ b/syncapi/storage/postgres/send_to_device_table.go @@ -58,7 +58,7 @@ const selectSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = ` DELETE FROM syncapi_send_to_device - WHERE user_id = $1 AND device_id = $2 AND id < $3 + WHERE user_id = $1 AND device_id = $2 AND id <= $3 ` const selectMaxSendToDeviceIDSQL = "" + diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index d05d3fe72..e3aa1b7a1 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -55,7 +55,7 @@ const selectSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = ` DELETE FROM syncapi_send_to_device - WHERE user_id = $1 AND device_id = $2 AND id < $3 + WHERE user_id = $1 AND device_id = $2 AND id <= $3 ` const selectMaxSendToDeviceIDSQL = "" + diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index df03a33c2..eda5ef3e6 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -416,11 +416,6 @@ func TestSendToDeviceBehaviour(t *testing.T) { t.Fatal("first call should have no updates") } - err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, 100) - if err != nil { - return - } - // Try sending a message. streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{ Sender: bob.ID, @@ -441,43 +436,35 @@ func TestSendToDeviceBehaviour(t *testing.T) { if count := len(events); count != 1 { t.Fatalf("second call should have one update, got %d", count) } - err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos) - if err != nil { - return - } // At this point we should still have one message because we haven't progressed the // sync position yet. This is equivalent to the client failing to /sync and retrying // with the same position. - streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) + streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) if err != nil { t.Fatal(err) } if len(events) != 1 { t.Fatal("third call should have one update still") } - err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos+1) + err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos) if err != nil { return } // At this point we should now have no updates, because we've progressed the sync // position. Therefore the update from before will not be sent again. - _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos+1, streamPos+2) + _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) if err != nil { t.Fatal(err) } if len(events) != 0 { t.Fatal("fourth call should have no updates") } - err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos+1) - if err != nil { - return - } // At this point we should still have no updates, because no new updates have been // sent. - _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+2) + _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) if err != nil { t.Fatal(err) } @@ -491,7 +478,7 @@ func TestSendToDeviceBehaviour(t *testing.T) { streamPos, err = db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{ Sender: bob.ID, Type: "m.type", - Content: json.RawMessage(fmt.Sprintf(`{ "count": %d }`, i)), + Content: json.RawMessage(fmt.Sprintf(`{"count":%d}`, i)), }) if err != nil { t.Fatal(err) diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 3ce7c64b7..b10864ff5 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -3,11 +3,14 @@ package syncapi import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "reflect" "testing" "time" + "github.com/matrix-org/dendrite/clientapi/producers" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -311,6 +314,139 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { } +func TestSendToDevice(t *testing.T) { + test.WithAllDatabases(t, testSendToDevice) +} + +func testSendToDevice(t *testing.T, dbType test.DBType) { + user := test.NewUser(t) + alice := userapi.Device{ + ID: "ALICEID", + UserID: user.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + + producer := producers.SyncAPIProducer{ + TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + JetStream: jsctx, + } + + msgCounter := 0 + + testCases := []struct { + name string + since string + want []string + sendMessagesCount int + }{ + { + name: "initial sync, no messages", + want: []string{}, + }, + { + name: "initial sync, one new message", + sendMessagesCount: 1, + want: []string{ + "message 1", + }, + }, + { + name: "initial sync, two new messages", // we didn't advance the since token, so we'll receive two messages + sendMessagesCount: 1, + want: []string{ + "message 1", + "message 2", + }, + }, + { + name: "incremental sync, one message", // this deletes message 1, as we advanced the since token + since: types.StreamingToken{SendToDevicePosition: 1}.String(), + want: []string{ + "message 2", + }, + }, + { + name: "failed incremental sync, one message", // didn't advance since, so still the same message + since: types.StreamingToken{SendToDevicePosition: 1}.String(), + want: []string{ + "message 2", + }, + }, + { + name: "incremental sync, no message", // this should delete message 2 + since: types.StreamingToken{SendToDevicePosition: 2}.String(), // next_batch from previous sync + want: []string{}, + }, + { + name: "incremental sync, three new messages", + since: types.StreamingToken{SendToDevicePosition: 2}.String(), + sendMessagesCount: 3, + want: []string{ + "message 3", // message 2 was deleted in the previous test + "message 4", + "message 5", + }, + }, + { + name: "initial sync, three messages", // we expect three messages, as we didn't go beyond "2" + want: []string{ + "message 3", + "message 4", + "message 5", + }, + }, + { + name: "incremental sync, no messages", // advance the sync token, no new messages + since: types.StreamingToken{SendToDevicePosition: 5}.String(), + want: []string{}, + }, + } + + ctx := context.Background() + for _, tc := range testCases { + // Send to-device messages of type "m.dendrite.test" with content `{"dummy":"message $counter"}` + for i := 0; i < tc.sendMessagesCount; i++ { + msgCounter++ + msg := map[string]string{ + "dummy": fmt.Sprintf("message %d", msgCounter), + } + if err := producer.SendToDevice(ctx, user.ID, user.ID, alice.ID, "m.dendrite.test", msg); err != nil { + t.Fatalf("unable to send to device message: %v", err) + } + } + time.Sleep((time.Millisecond * 15) * time.Duration(tc.sendMessagesCount)) // wait a bit, so the messages can be processed + // Execute a /sync request, recording the response + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + "access_token": alice.AccessToken, + "since": tc.since, + }))) + + // Extract the to_device.events, # gets all values of an array, in this case a string slice with "message $counter" entries + events := gjson.Get(w.Body.String(), "to_device.events.#.content.dummy").Array() + got := make([]string, len(events)) + for i := range events { + got[i] = events[i].String() + } + + // Ensure the messages we received are as we expect them to be + if !reflect.DeepEqual(got, tc.want) { + t.Logf("[%s|since=%s]: Sync: %s", tc.name, tc.since, w.Body.String()) + t.Fatalf("[%s|since=%s]: got: %+v, want: %+v", tc.name, tc.since, got, tc.want) + } + } +} + func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg { result := make([]*nats.Msg, len(input)) for i, ev := range input { diff --git a/test/testrig/base.go b/test/testrig/base.go index facb49f3e..d13c43129 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -32,11 +32,11 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f var cfg config.Dendrite cfg.Defaults(false) cfg.Global.JetStream.InMemory = true - switch dbType { case test.DBTypePostgres: cfg.Global.Defaults(true) // autogen a signing key cfg.MediaAPI.Defaults(true) // autogen a media path + cfg.Global.ServerName = "test" // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) @@ -50,6 +50,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close case test.DBTypeSQLite: cfg.Defaults(true) // sets a sqlite db per component + cfg.Global.ServerName = "test" // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType)