Mirror of https://github.com/matrix-org/dendrite.git
Maybe fix send-to-device tests
commit 4cd07b4222
parent 54dd0b30e7
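This change converts the syncapi storage tests to take a short-lived database snapshot around each read instead of reusing one snapshot for the whole test. A new WithSnapshot test helper (first hunk below) opens a snapshot, passes it to a callback, and rolls it back once the callback returns. The sketch below is illustrative only and is not part of the commit; it reuses ctx, db, alice and deviceID from the test setup, and msg stands in for the gomatrixserverlib.SendToDeviceEvent literal used in the test. It shows the visibility hazard that per-read snapshots presumably avoid: a snapshot opened before a write is not guaranteed to see that write.

	// Illustrative sketch, not part of the commit.
	snapshot, _ := db.NewDatabaseSnapshot(ctx) // snapshot opened before the write
	defer snapshot.Rollback()                  // nolint:errcheck

	streamPos, _ := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, msg)

	// Reading through the earlier snapshot may report no events even though the
	// message was stored, because the snapshot predates the write. Opening a
	// fresh snapshot for the read (what WithSnapshot does) avoids this.
	_, events, _ := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
	_ = events // may be empty here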
@@ -60,6 +60,17 @@ func TestWriteEvents(t *testing.T) {
 	})
 }
 
+func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.DatabaseSnapshot)) {
+	snapshot, err := db.NewDatabaseSnapshot(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	f(snapshot)
+	if err := snapshot.Rollback(); err != nil {
+		t.Fatal(err)
+	}
+}
+
 // These tests assert basic functionality of RecentEvents for PDUs
 func TestRecentEventsPDU(t *testing.T) {
 	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@@ -79,16 +90,13 @@ func TestRecentEventsPDU(t *testing.T) {
 		// dummy room to make sure SQL queries are filtering on room ID
 		MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
 
-		snapshot, err := db.NewDatabaseSnapshot(ctx)
-		if err != nil {
-			t.Fatal(err)
-		}
-		defer snapshot.Rollback() // nolint:errcheck
-
-		latest, err := snapshot.MaxStreamPositionForPDUs(ctx)
-		if err != nil {
-			t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
-		}
+		var latest types.StreamPosition
+		WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
+			var err error
+			if latest, err = snapshot.MaxStreamPositionForPDUs(ctx); err != nil {
+				t.Fatal("failed to get MaxStreamPositionForPDUs: %w", err)
+			}
+		})
 
 		testCases := []struct {
 			Name string
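A note on how the converted tests get results out of the callback: values needed after the snapshot is rolled back (latest, gotEvents, limited) are declared before the WithSnapshot call, only err is declared inside it, and plain assignment is used rather than ":=". A minimal illustration, not part of the diff, reusing t, db, ctx and the snapshot API shown above:

	// ":=" inside the callback would declare a new, callback-local "latest"
	// and the outer variable would still be zero after WithSnapshot returns,
	// so the outer variable is predeclared and assigned with "=".
	var latest types.StreamPosition
	WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
		var err error
		latest, err = snapshot.MaxStreamPositionForPDUs(ctx)
		if err != nil {
			t.Fatal(err)
		}
	})
	// "latest" remains usable here, after the snapshot has been rolled back.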
@@ -146,14 +154,19 @@ func TestRecentEventsPDU(t *testing.T) {
 			tc := testCases[i]
 			t.Run(tc.Name, func(st *testing.T) {
 				var filter gomatrixserverlib.RoomEventFilter
+				var gotEvents []types.StreamEvent
+				var limited bool
 				filter.Limit = tc.Limit
-				gotEvents, limited, err := snapshot.RecentEvents(ctx, r.ID, types.Range{
-					From: tc.From,
-					To:   tc.To,
-				}, &filter, !tc.ReverseOrder, true)
-				if err != nil {
-					st.Fatalf("failed to do sync: %s", err)
-				}
+				WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
+					var err error
+					gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{
+						From: tc.From,
+						To:   tc.To,
+					}, &filter, !tc.ReverseOrder, true)
+					if err != nil {
+						st.Fatalf("failed to do sync: %s", err)
+					}
+				})
 				if limited != tc.WantLimited {
 					st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
 				}
@@ -184,28 +197,24 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
 		events := r.Events()
 		_ = MustWriteEvents(t, db, events)
 
-		snapshot, err := db.NewDatabaseSnapshot(ctx)
-		if err != nil {
-			t.Fatal(err)
-		}
-		defer snapshot.Rollback() // nolint:errcheck
-
-		from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
-		if err != nil {
-			t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
-		}
-		t.Logf("max topo pos = %+v", from)
-		// head towards the beginning of time
-		to := types.TopologyToken{}
-
-		// backpaginate 5 messages starting at the latest position.
-		filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
-		paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
-		if err != nil {
-			t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
-		}
-		gots := snapshot.StreamEventsToEvents(nil, paginatedEvents)
-		test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
+		WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
+			from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
+			if err != nil {
+				t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
+			}
+			t.Logf("max topo pos = %+v", from)
+			// head towards the beginning of time
+			to := types.TopologyToken{}
+
+			// backpaginate 5 messages starting at the latest position.
+			filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
+			paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
+			if err != nil {
+				t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
+			}
+			gots := snapshot.StreamEventsToEvents(nil, paginatedEvents)
+			test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
+		})
 	})
 }
 
@@ -426,18 +435,16 @@ func TestSendToDeviceBehaviour(t *testing.T) {
 		defer closeBase()
 		// At this point there should be no messages. We haven't sent anything
 		// yet.
-		snapshot, err := db.NewDatabaseSnapshot(ctx)
-		if err != nil {
-			t.Fatal(err)
-		}
-		defer snapshot.Rollback() // nolint:errcheck
-		_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if len(events) != 0 {
-			t.Fatal("first call should have no updates")
-		}
+
+		WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
+			_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if len(events) != 0 {
+				t.Fatal("first call should have no updates")
+			}
+		})
 
 		// Try sending a message.
 		streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
@@ -449,51 +456,58 @@ func TestSendToDeviceBehaviour(t *testing.T) {
 			t.Fatal(err)
 		}
 
-		// At this point we should get exactly one message. We're sending the sync position
-		// that we were given from the update and the send-to-device update will be updated
-		// in the database to reflect that this was the sync position we sent the message at.
-		streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if count := len(events); count != 1 {
-			t.Fatalf("second call should have one update, got %d", count)
-		}
-
-		// At this point we should still have one message because we haven't progressed the
-		// sync position yet. This is equivalent to the client failing to /sync and retrying
-		// with the same position.
-		streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if len(events) != 1 {
-			t.Fatal("third call should have one update still")
-		}
+		WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
+			// At this point we should get exactly one message. We're sending the sync position
+			// that we were given from the update and the send-to-device update will be updated
+			// in the database to reflect that this was the sync position we sent the message at.
+			var events []types.SendToDeviceEvent
+			streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if count := len(events); count != 1 {
+				t.Fatalf("second call should have one update, got %d", count)
+			}
+
+			// At this point we should still have one message because we haven't progressed the
+			// sync position yet. This is equivalent to the client failing to /sync and retrying
+			// with the same position.
+			streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if len(events) != 1 {
+				t.Fatal("third call should have one update still")
+			}
+		})
 
 		err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
 		if err != nil {
 			return
 		}
 
-		// At this point we should now have no updates, because we've progressed the sync
-		// position. Therefore the update from before will not be sent again.
-		_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if len(events) != 0 {
-			t.Fatal("fourth call should have no updates")
-		}
-
-		// At this point we should still have no updates, because no new updates have been
-		// sent.
-		_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if len(events) != 0 {
-			t.Fatal("fifth call should have no updates")
-		}
+		WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
+			// At this point we should now have no updates, because we've progressed the sync
+			// position. Therefore the update from before will not be sent again.
+			var events []types.SendToDeviceEvent
+			_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if len(events) != 0 {
+				t.Fatal("fourth call should have no updates")
+			}
+
+			// At this point we should still have no updates, because no new updates have been
+			// sent.
+			_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if len(events) != 0 {
+				t.Fatal("fifth call should have no updates")
+			}
+		})
 
 		// Send some more messages and verify the ordering is correct ("in order of arrival")
 		var lastPos types.StreamPosition = 0
@@ -509,18 +523,20 @@ func TestSendToDeviceBehaviour(t *testing.T) {
 			lastPos = streamPos
 		}
 
-		_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
-		if err != nil {
-			t.Fatalf("unable to get events: %v", err)
-		}
+		WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
+			_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
+			if err != nil {
+				t.Fatalf("unable to get events: %v", err)
+			}
 
-		for i := 0; i < 10; i++ {
-			want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
-			got := events[i].Content
-			if !bytes.Equal(got, want) {
-				t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
-			}
-		}
+			for i := 0; i < 10; i++ {
+				want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
+				got := events[i].Content
+				if !bytes.Equal(got, want) {
+					t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
+				}
+			}
+		})
 	})
 }
 