From e293d8215195fa5cc3a427a97ae9fc448ea7043f Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Wed, 7 Dec 2022 08:06:45 +0100 Subject: [PATCH] Verify request for notifications --- userapi/util/notify_test.go | 39 +++++++++++++++++++++++++++++----- userapi/util/phonehomestats.go | 4 +--- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 8a1755b85..f1d20259c 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -2,6 +2,7 @@ package util_test import ( "context" + "encoding/json" "net/http" "net/http/httptest" "testing" @@ -32,15 +33,44 @@ func TestNotifyUserCountsAsync(t *testing.T) { room := test.NewRoom(t, alice) dummyEvent := room.Events()[len(room.Events())-1] + appID := util.RandomString(8) + pushKey := util.RandomString(8) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { receivedRequest := make(chan bool, 1) // create a test server which responds to our /notify call srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data pushgateway.NotifyRequest + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Error(err) + } + notification := data.Notification + // Validate the request + if notification.Counts == nil { + t.Fatal("no unread notification counts in request") + } + if unread := notification.Counts.Unread; unread != 1 { + t.Errorf("expected one unread notification, got %d", unread) + } + + if len(notification.Devices) == 0 { + t.Fatal("expected devices in request") + } + + // We only created one push device, so access it directly + device := notification.Devices[0] + if device.AppID != appID { + t.Errorf("unexpected app_id: %s, want %s", device.AppID, appID) + } + if device.PushKey != pushKey { + t.Errorf("unexpected push_key: %s, want %s", device.PushKey, pushKey) + } + // Return empty result, otherwise the call is handled as failed if _, err := w.Write([]byte("{}")); err != nil { t.Error(err) } - receivedRequest <- true + close(receivedRequest) })) defer srv.Close() @@ -59,11 +89,10 @@ func TestNotifyUserCountsAsync(t *testing.T) { // Prepare pusher with our test server URL if err := db.UpsertPusher(ctx, api.Pusher{ Kind: api.HTTPKind, - AppID: util.RandomString(8), - PushKey: util.RandomString(8), + AppID: appID, + PushKey: pushKey, Data: map[string]interface{}{ - "url": srv.URL, - "event_id": dummyEvent.EventID(), + "url": srv.URL, }, }, aliceLocalpart, serverName); err != nil { t.Error(err) diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go index 2e20f15ac..42c8f5d7c 100644 --- a/userapi/util/phonehomestats.go +++ b/userapi/util/phonehomestats.go @@ -63,9 +63,7 @@ func StartPhoneHomeCollector(startTime time.Time, cfg *config.Dendrite, statsDB } // start initial run after 5min - time.AfterFunc(time.Minute*5, func() { - p.collect() - }) + time.AfterFunc(time.Minute*5, p.collect) // run every 3 hours ticker := time.NewTicker(time.Hour * 3)