From 926252671b205bc7c7234b748d190be04bae42c0 Mon Sep 17 00:00:00 2001 From: Tommie Gannert Date: Wed, 27 Oct 2021 17:52:12 +0200 Subject: [PATCH] Implement pushserver/storage for notifications. --- pushserver/api/api.go | 9 + pushserver/storage/interface.go | 10 + .../storage/shared/notification_table.go | 237 ++++++++++++++++++ pushserver/storage/shared/storage.go | 46 +++- pushserver/storage/tables/interface.go | 32 +++ 5 files changed, 331 insertions(+), 3 deletions(-) create mode 100644 pushserver/storage/shared/notification_table.go diff --git a/pushserver/api/api.go b/pushserver/api/api.go index 67fa582c0..b3dee52f8 100644 --- a/pushserver/api/api.go +++ b/pushserver/api/api.go @@ -68,3 +68,12 @@ type QueryPushRulesRequest struct { type QueryPushRulesResponse struct { RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` } + +type Notification struct { + Actions []*pushrules.Action `json:"actions"` // Required. + Event gomatrixserverlib.ClientEvent `json:"event"` // Required. + ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional. + Read bool `json:"read"` // Required. + RoomID string `json:"room_id"` // Required. + TS gomatrixserverlib.Timestamp `json:"ts"` // Required. +} diff --git a/pushserver/storage/interface.go b/pushserver/storage/interface.go index dffebf4f1..204d621ea 100644 --- a/pushserver/storage/interface.go +++ b/pushserver/storage/interface.go @@ -5,6 +5,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/pushserver/api" + "github.com/matrix-org/dendrite/pushserver/storage/tables" ) type Database interface { @@ -13,4 +14,13 @@ type Database interface { GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) RemovePusher(ctx context.Context, appId, pushkey, localpart string) error RemovePushers(ctx context.Context, appId, pushkey string) error + + InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, upToEventID string) (affected bool, _ error) + SetNotificationsRead(ctx context.Context, localpart, roomID, upToEventID string, b bool) (affected bool, _ error) + GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, filter NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) } + +type NotificationFilter = tables.NotificationFilter diff --git a/pushserver/storage/shared/notification_table.go b/pushserver/storage/shared/notification_table.go new file mode 100644 index 000000000..b42650c5a --- /dev/null +++ b/pushserver/storage/shared/notification_table.go @@ -0,0 +1,237 @@ +// Copyright 2021 Dan Peleg +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/pushserver/api" + "github.com/matrix-org/dendrite/pushserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt +} + +func prepareNotificationsTable(db *sql.DB) (tables.Notifications, error) { + s := ¬ificationsStatements{} + + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + }.Prepare(db) +} + +const insertNotificationSQL = "INSERT INTO pushserver_notifications (localpart, room_id, event_id, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6)" + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, localpart, eventID string, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = s.insertStmt.ExecContext(ctx, localpart, roomID, eventID, tsMS, highlight, string(bs)) + return err +} + +const deleteNotificationsUpToSQL = `DELETE FROM pushserver_notifications +WHERE + localpart = $1 AND + room_id = $2 AND + id <= ( + SELECT MAX(id) + FROM pushserver_notifications + WHERE + localpart = $1 AND + room_id = $2 AND + event_id = $3 + )` + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, _ error) { + res, err := s.deleteUpToStmt.ExecContext(ctx, localpart, roomID, eventID) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +const updateNotificationReadSQL = `UPDATE pushserver_notifications +SET read = $1 +WHERE + localpart = $2 AND + room_id = $3 AND + id <= ( + SELECT MAX(id) + FROM pushserver_notifications + WHERE + localpart = $2 AND + room_id = $3 AND + event_id = $4 + ) AND + read <> $1` + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, localpart, roomID, eventID string, v bool) (affected bool, _ error) { + res, err := s.updateReadStmt.ExecContext(ctx, v, localpart, roomID, eventID) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +const selectNotificationSQL = `SELECT id, room_id, ts_ms, read, notification_json +FROM pushserver_notifications +WHERE + localpart = $1 AND + id > $2 AND + ( + (($3 & 1) <> 0 AND highlight) OR + (($3 & 2) <> 0 AND NOT highlight) + ) AND + NOT read +ORDER BY localpart, id +LIMIT $4` + +func (s *notificationsStatements) Select(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := s.selectStmt.QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +const selectNotificationCountSQL = `SELECT COUNT(*) +FROM pushserver_notifications +WHERE + localpart = $1 AND + ( + (($2 & 1) <> 0 AND highlight) OR + (($2 & 2) <> 0 AND NOT highlight) + ) AND + NOT read` + +func (s *notificationsStatements) SelectCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := s.selectCountStmt.QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + for rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +const selectRoomNotificationCountsSQL = `SELECT + COUNT(*), + COUNT(*) FILTER (WHERE highlight) +FROM pushserver_notifications +WHERE + localpart = $1 AND + room_id = $2 AND + NOT read` + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := s.selectRoomCountsStmt.QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + for rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/pushserver/storage/shared/storage.go b/pushserver/storage/shared/storage.go index 3d11ed641..5ffcb7f2a 100644 --- a/pushserver/storage/shared/storage.go +++ b/pushserver/storage/shared/storage.go @@ -5,22 +5,62 @@ import ( "database/sql" "encoding/json" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/pushserver/api" "github.com/matrix-org/dendrite/pushserver/storage/tables" ) type Database struct { - DB *sql.DB - Writer sqlutil.Writer - pushers tables.Pusher + DB *sql.DB + Writer sqlutil.Writer + notifications tables.Notifications + pushers tables.Pusher } func (d *Database) Prepare() (err error) { + d.notifications, err = prepareNotificationsTable(d.DB) + if err != nil { + return + } d.pushers, err = preparePushersTable(d.DB) return } +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error { + return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + return d.notifications.Insert(ctx, localpart, eventID, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + }) +} + +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, upToEventID string) (affected bool, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + affected, err = d.notifications.DeleteUpTo(ctx, localpart, roomID, upToEventID) + return err + }) + return +} + +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID, upToEventID string, b bool) (affected bool, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + affected, err = d.notifications.UpdateRead(ctx, localpart, roomID, upToEventID, b) + return err + }) + return +} + +func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + return d.notifications.Select(ctx, localpart, fromID, limit, filter) +} + +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { + return d.notifications.SelectCount(ctx, localpart, filter) +} + +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { + return d.notifications.SelectRoomCounts(ctx, localpart, roomID) +} + func (d *Database) CreatePusher( ctx context.Context, p api.Pusher, localpart string, ) error { diff --git a/pushserver/storage/tables/interface.go b/pushserver/storage/tables/interface.go index 89656abbf..de222102e 100644 --- a/pushserver/storage/tables/interface.go +++ b/pushserver/storage/tables/interface.go @@ -23,3 +23,35 @@ type Pusher interface { ctx context.Context, appid, pushkey string, ) error } + +type Notifications interface { + Insert(ctx context.Context, localpart, eventID string, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, _ error) + UpdateRead(ctx context.Context, localpart, roomID, eventID string, v bool) (affected bool, _ error) + Select(ctx context.Context, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, localpart string, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) +} + +type NotificationFilter uint32 + +const ( + // HighlightNotifications returns notifications that had a + // "highlight" tweak assigned to them from evaluating push rules. + HighlightNotifications NotificationFilter = 1 << iota + + // NonHighlightNotifications returns notifications that don't + // match HighlightNotifications. + NonHighlightNotifications + + // NoNotifications is a filter to exclude all types of + // notifications. It's useful as a zero value, but isn't likely to + // be used in a call to Notifications.Select*. + NoNotifications NotificationFilter = 0 + + // AllNotifications is a filter to include all types of + // notifications in Notifications.Select*. Note that PostgreSQL + // balks if this doesn't fit in INTEGER, even though we use + // uint32. + AllNotifications NotificationFilter = (1 << 31) - 1 +)