Maybe fix send-to-device tests

This commit is contained in:
Neil Alexander 2022-09-28 14:54:33 +01:00
parent 54dd0b30e7
commit 4cd07b4222
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -60,6 +60,17 @@ func TestWriteEvents(t *testing.T) {
}) })
} }
func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.DatabaseSnapshot)) {
snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil {
t.Fatal(err)
}
f(snapshot)
if err := snapshot.Rollback(); err != nil {
t.Fatal(err)
}
}
// These tests assert basic functionality of RecentEvents for PDUs // These tests assert basic functionality of RecentEvents for PDUs
func TestRecentEventsPDU(t *testing.T) { func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -79,16 +90,13 @@ func TestRecentEventsPDU(t *testing.T) {
// dummy room to make sure SQL queries are filtering on room ID // dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
snapshot, err := db.NewDatabaseSnapshot(ctx) var latest types.StreamPosition
if err != nil { WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
t.Fatal(err) var err error
} if latest, err = snapshot.MaxStreamPositionForPDUs(ctx); err != nil {
defer snapshot.Rollback() // nolint:errcheck t.Fatal("failed to get MaxStreamPositionForPDUs: %w", err)
}
latest, err := snapshot.MaxStreamPositionForPDUs(ctx) })
if err != nil {
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
}
testCases := []struct { testCases := []struct {
Name string Name string
@ -146,14 +154,19 @@ func TestRecentEventsPDU(t *testing.T) {
tc := testCases[i] tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) { t.Run(tc.Name, func(st *testing.T) {
var filter gomatrixserverlib.RoomEventFilter var filter gomatrixserverlib.RoomEventFilter
var gotEvents []types.StreamEvent
var limited bool
filter.Limit = tc.Limit filter.Limit = tc.Limit
gotEvents, limited, err := snapshot.RecentEvents(ctx, r.ID, types.Range{ WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
From: tc.From, var err error
To: tc.To, gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{
}, &filter, !tc.ReverseOrder, true) From: tc.From,
if err != nil { To: tc.To,
st.Fatalf("failed to do sync: %s", err) }, &filter, !tc.ReverseOrder, true)
} if err != nil {
st.Fatalf("failed to do sync: %s", err)
}
})
if limited != tc.WantLimited { if limited != tc.WantLimited {
st.Errorf("got limited=%v want %v", limited, tc.WantLimited) st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
} }
@ -184,28 +197,24 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
events := r.Events() events := r.Events()
_ = MustWriteEvents(t, db, events) _ = MustWriteEvents(t, db, events)
snapshot, err := db.NewDatabaseSnapshot(ctx) WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
if err != nil { from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
t.Fatal(err) if err != nil {
} t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
defer snapshot.Rollback() // nolint:errcheck }
t.Logf("max topo pos = %+v", from)
// head towards the beginning of time
to := types.TopologyToken{}
from, err := snapshot.MaxTopologicalPosition(ctx, r.ID) // backpaginate 5 messages starting at the latest position.
if err != nil { filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
} if err != nil {
t.Logf("max topo pos = %+v", from) t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
// head towards the beginning of time }
to := types.TopologyToken{} gots := snapshot.StreamEventsToEvents(nil, paginatedEvents)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
// backpaginate 5 messages starting at the latest position. })
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
}
gots := snapshot.StreamEventsToEvents(nil, paginatedEvents)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
}) })
} }
@ -426,18 +435,16 @@ func TestSendToDeviceBehaviour(t *testing.T) {
defer closeBase() defer closeBase()
// At this point there should be no messages. We haven't sent anything // At this point there should be no messages. We haven't sent anything
// yet. // yet.
snapshot, err := db.NewDatabaseSnapshot(ctx)
if err != nil { WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
t.Fatal(err) _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
} if err != nil {
defer snapshot.Rollback() // nolint:errcheck t.Fatal(err)
_, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) }
if err != nil { if len(events) != 0 {
t.Fatal(err) t.Fatal("first call should have no updates")
} }
if len(events) != 0 { })
t.Fatal("first call should have no updates")
}
// Try sending a message. // Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{ streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
@ -449,51 +456,58 @@ func TestSendToDeviceBehaviour(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// At this point we should get exactly one message. We're sending the sync position WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
// that we were given from the update and the send-to-device update will be updated // At this point we should get exactly one message. We're sending the sync position
// in the database to reflect that this was the sync position we sent the message at. // that we were given from the update and the send-to-device update will be updated
streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) // in the database to reflect that this was the sync position we sent the message at.
if err != nil { var events []types.SendToDeviceEvent
t.Fatal(err) streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
} if err != nil {
if count := len(events); count != 1 { t.Fatal(err)
t.Fatalf("second call should have one update, got %d", count) }
} if count := len(events); count != 1 {
t.Fatalf("second call should have one update, got %d", count)
}
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil {
t.Fatal(err)
}
if len(events) != 1 {
t.Fatal("third call should have one update still")
}
})
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
if err != nil {
t.Fatal(err)
}
if len(events) != 1 {
t.Fatal("third call should have one update still")
}
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos) err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
if err != nil { if err != nil {
return return
} }
// At this point we should now have no updates, because we've progressed the sync WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
// position. Therefore the update from before will not be sent again. // At this point we should now have no updates, because we've progressed the sync
_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) // position. Therefore the update from before will not be sent again.
if err != nil { var events []types.SendToDeviceEvent
t.Fatal(err) _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
} if err != nil {
if len(events) != 0 { t.Fatal(err)
t.Fatal("fourth call should have no updates") }
} if len(events) != 0 {
t.Fatal("fourth call should have no updates")
}
// At this point we should still have no updates, because no new updates have been // At this point we should still have no updates, because no new updates have been
// sent. // sent.
_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 { if len(events) != 0 {
t.Fatal("fifth call should have no updates") t.Fatal("fifth call should have no updates")
} }
})
// Send some more messages and verify the ordering is correct ("in order of arrival") // Send some more messages and verify the ordering is correct ("in order of arrival")
var lastPos types.StreamPosition = 0 var lastPos types.StreamPosition = 0
@ -509,18 +523,20 @@ func TestSendToDeviceBehaviour(t *testing.T) {
lastPos = streamPos lastPos = streamPos
} }
_, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos) WithSnapshot(t, db, func(snapshot storage.DatabaseSnapshot) {
if err != nil { _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
t.Fatalf("unable to get events: %v", err) if err != nil {
} t.Fatalf("unable to get events: %v", err)
for i := 0; i < 10; i++ {
want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
got := events[i].Content
if !bytes.Equal(got, want) {
t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
} }
}
for i := 0; i < 10; i++ {
want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
got := events[i].Content
if !bytes.Equal(got, want) {
t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
}
}
})
}) })
} }