diff --git a/pushserver/consumers/roomserver_test.go b/pushserver/consumers/roomserver_test.go index a7a565421..904f7b9aa 100644 --- a/pushserver/consumers/roomserver_test.go +++ b/pushserver/consumers/roomserver_test.go @@ -33,7 +33,7 @@ func TestOutputRoomEventConsumer(t *testing.T) { if err != nil { t.Fatalf("NewDatabase failed: %v", err) } - err = db.CreatePusher(ctx, + err = db.UpsertPusher(ctx, api.Pusher{ PushKey: "apushkey", Kind: api.HTTPKind, @@ -45,7 +45,7 @@ func TestOutputRoomEventConsumer(t *testing.T) { }, "alice") if err != nil { - t.Fatalf("CreatePusher failed: %v", err) + t.Fatalf("UpsertPusher failed: %v", err) } var rsAPI fakeRoomServerInternalAPI diff --git a/pushserver/internal/api.go b/pushserver/internal/api.go index 2c595c33e..f56d9ec7e 100644 --- a/pushserver/internal/api.go +++ b/pushserver/internal/api.go @@ -89,7 +89,7 @@ func (a *PushserverInternalAPI) PerformPusherSet(ctx context.Context, req *api.P if req.Pusher.PushKeyTS == 0 { req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now()) } - return a.DB.CreatePusher(ctx, req.Pusher, req.Localpart) + return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) } func (a *PushserverInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { diff --git a/pushserver/storage/interface.go b/pushserver/storage/interface.go index 204d621ea..6263db435 100644 --- a/pushserver/storage/interface.go +++ b/pushserver/storage/interface.go @@ -10,7 +10,7 @@ import ( type Database interface { internal.PartitionStorer - CreatePusher(ctx context.Context, pusher api.Pusher, localpart string) error + UpsertPusher(ctx context.Context, pusher api.Pusher, localpart string) error GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) RemovePusher(ctx context.Context, appId, pushkey, localpart string) error RemovePushers(ctx context.Context, appId, pushkey string) error diff --git a/pushserver/storage/shared/pusher_table.go b/pushserver/storage/shared/pusher_table.go index 747e03bac..825cdbd18 100644 --- a/pushserver/storage/shared/pusher_table.go +++ b/pushserver/storage/shared/pusher_table.go @@ -56,7 +56,9 @@ CREATE UNIQUE INDEX IF NOT EXISTS pusher_app_id_pushkey_localpart_idx ON pushser ` const insertPusherSQL = "" + - "INSERT INTO pushserver_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + "INSERT INTO pushserver_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" const selectPushersSQL = "" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM pushserver_pushers WHERE localpart = $1" diff --git a/pushserver/storage/shared/storage.go b/pushserver/storage/shared/storage.go index 5ffcb7f2a..e436027e7 100644 --- a/pushserver/storage/shared/storage.go +++ b/pushserver/storage/shared/storage.go @@ -61,7 +61,7 @@ func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roo return d.notifications.SelectRoomCounts(ctx, localpart, roomID) } -func (d *Database) CreatePusher( +func (d *Database) UpsertPusher( ctx context.Context, p api.Pusher, localpart string, ) error { data, err := json.Marshal(p.Data) diff --git a/pushserver/storage/storage_test.go b/pushserver/storage/storage_test.go index cf3a1b441..c42ba6ba5 100644 --- a/pushserver/storage/storage_test.go +++ b/pushserver/storage/storage_test.go @@ -14,72 +14,95 @@ import ( var testCtx = context.Background() -var testPushers = []api.Pusher{ - { - SessionID: 42984798792, - PushKey: "dc_GxbDa8El0pWKkDIM-rQ:APA91bHflmL6ycJMbLKX8VYLD-Ebft3t-SLQwIap-pDWP-evu1AWxsXxzyl1pgSZxDMn6OeznZsjXhTU0m5xz05dyJ4syX86S89uwxBwtbK-k0PHQt9wF8CgOcibm-OYZodpY5TtmknZ", - Kind: "http", - AppID: "com.example.app.ios", - AppDisplayName: "Mat Rix", - DeviceDisplayName: "iPhone 9", - ProfileTag: "xxyyzz", - Language: "pl", - Data: map[string]interface{}{ - "format": "event_id_only", - "url": "https://push-gateway.location.there/_matrix/push/v1/notify", +var ( + testPushers = []api.Pusher{ + { + SessionID: 42984798792, + PushKey: "dc_GxbDa8El0pWKkDIM-rQ:APA91bHflmL6ycJMbLKX8VYLD-Ebft3t-SLQwIap-pDWP-evu1AWxsXxzyl1pgSZxDMn6OeznZsjXhTU0m5xz05dyJ4syX86S89uwxBwtbK-k0PHQt9wF8CgOcibm-OYZodpY5TtmknZ", + Kind: "http", + AppID: "com.example.app.ios", + AppDisplayName: "Mat Rix", + DeviceDisplayName: "iPhone 9", + ProfileTag: "xxyyzz", + Language: "pl", + Data: map[string]interface{}{ + "format": "event_id_only", + "url": "https://push-gateway.location.there/_matrix/push/v1/notify", + }, }, - }, - { - SessionID: 4298479873432, - PushKey: "dnjekDa8El0pWKkDIM-rQ:APA91bHflmL6ycJMbLKX8VYLD-Ebft3t-SLQwIap-pDWP-evu1AWxsXxzyl1pgSZxDMn6OeznZsjXhTU0m5xz05dyJ4syX86S89uwxBwtbK-k0PHQt9wF8CgOcibm-OYZodpY5TtmknZ", - Kind: "http", + { + SessionID: 4298479873432, + PushKey: "dnjekDa8El0pWKkDIM-rQ:APA91bHflmL6ycJMbLKX8VYLD-Ebft3t-SLQwIap-pDWP-evu1AWxsXxzyl1pgSZxDMn6OeznZsjXhTU0m5xz05dyJ4syX86S89uwxBwtbK-k0PHQt9wF8CgOcibm-OYZodpY5TtmknZ", + Kind: "http", + AppID: "com.example.app.ios", + AppDisplayName: "Riot", + DeviceDisplayName: "Android 11", + ProfileTag: "aabbcc", + Language: "en", + Data: map[string]interface{}{ + "format": "event_id_only", + "url": "https://push-gateway.location.there/_matrix/push/v1/notify", + }, + }, + { + SessionID: 4298479873432, + PushKey: "dc_GxbDa8El0pWKkDIM-rQ:APA91bHflmL6ycJMbLKX8VYLD-Ebft3t-SLQwIap-pDWP-evu1AWxsXxzyl1pgSZxDMn6OeznZsjXhTU0m5xz05dyJ4syX86S89uwxBwtbK-k0PHQt9wF8CgOcibm-OYZodpY5TtmknZ", + Kind: "http", + AppID: "com.example.app.ios", + AppDisplayName: "Riot", + DeviceDisplayName: "Android 11", + ProfileTag: "aabbcc", + Language: "en", + Data: map[string]interface{}{ + "format": "event_id_only", + "url": "https://push-gateway.location.there/_matrix/push/v1/notify", + }, + }, + } + + updatePusher = api.Pusher{ AppID: "com.example.app.ios", - AppDisplayName: "Riot", - DeviceDisplayName: "Android 11", - ProfileTag: "aabbcc", + PushKey: "dc_GxbDa8El0pWKkDIM-rQ:APA91bHflmL6ycJMbLKX8VYLD-Ebft3t-SLQwIap-pDWP-evu1AWxsXxzyl1pgSZxDMn6OeznZsjXhTU0m5xz05dyJ4syX86S89uwxBwtbK-k0PHQt9wF8CgOcibm-OYZodpY5TtmknZ", + SessionID: 429847987, + Kind: "http", + AppDisplayName: "Mat Rix 2", + DeviceDisplayName: "iPhone 9a", + ProfileTag: "xxyyzzaa", Language: "en", Data: map[string]interface{}{ "format": "event_id_only", - "url": "https://push-gateway.location.there/_matrix/push/v1/notify", + "url": "https://push-gateway.location.here/_matrix/push/v1/notify", }, - }, - { - SessionID: 4298479873432, - PushKey: "dc_GxbDa8El0pWKkDIM-rQ:APA91bHflmL6ycJMbLKX8VYLD-Ebft3t-SLQwIap-pDWP-evu1AWxsXxzyl1pgSZxDMn6OeznZsjXhTU0m5xz05dyJ4syX86S89uwxBwtbK-k0PHQt9wF8CgOcibm-OYZodpY5TtmknZ", - Kind: "http", - AppID: "com.example.app.ios", - AppDisplayName: "Riot", - DeviceDisplayName: "Android 11", - ProfileTag: "aabbcc", - Language: "en", - Data: map[string]interface{}{ - "format": "event_id_only", - "url": "https://push-gateway.location.there/_matrix/push/v1/notify", - }, - }, -} + } +) var testUsers = []string{ "admin", "admin", "admin0", + "admin", } func mustNewDatabaseWithTestPushers(is *is.I) Database { + dut := mustNewDatabase(is) + for i, testPusher := range testPushers { + err := dut.UpsertPusher(testCtx, testPusher, testUsers[i]) + is.NoErr(err) + } + return dut +} + +func mustNewDatabase(is *is.I) Database { randPostfix := strconv.Itoa(rand.Int()) dbPath := os.TempDir() + "/dendrite-" + randPostfix dut, err := Open(&config.DatabaseOptions{ ConnectionString: config.DataSource("file:" + dbPath), }) is.NoErr(err) - for i, testPusher := range testPushers { - err = dut.CreatePusher(testCtx, testPusher, testUsers[i]) - is.NoErr(err) - } return dut } -func TestCreatePusher(t *testing.T) { +func TestInsertPusher(t *testing.T) { is := is.New(t) mustNewDatabaseWithTestPushers(is) } @@ -131,3 +154,18 @@ func TestDeletePushers(t *testing.T) { is.NoErr(err) is.Equal(len(pushers), 0) } + +func TestUpdatePusher(t *testing.T) { + is := is.New(t) + dut := mustNewDatabase(is) + err := dut.UpsertPusher(testCtx, testPushers[0], "admin") + is.NoErr(err) + err = dut.UpsertPusher(testCtx, updatePusher, "admin") + is.NoErr(err) + pushers, err := dut.GetPushers(testCtx, "admin") + is.NoErr(err) + is.Equal(len(pushers), 1) + t.Log(pushers[0]) + t.Log(updatePusher) + is.Equal(pushers[0], updatePusher) +}