Implement pushserver/storage for notifications.

This commit is contained in:
Tommie Gannert 2021-10-27 17:52:12 +02:00
parent 54ece78a12
commit 926252671b
5 changed files with 331 additions and 3 deletions

View file

@ -68,3 +68,12 @@ type QueryPushRulesRequest struct {
type QueryPushRulesResponse struct { type QueryPushRulesResponse struct {
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
} }
type Notification struct {
Actions []*pushrules.Action `json:"actions"` // Required.
Event gomatrixserverlib.ClientEvent `json:"event"` // Required.
ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional.
Read bool `json:"read"` // Required.
RoomID string `json:"room_id"` // Required.
TS gomatrixserverlib.Timestamp `json:"ts"` // Required.
}

View file

@ -5,6 +5,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/pushserver/api" "github.com/matrix-org/dendrite/pushserver/api"
"github.com/matrix-org/dendrite/pushserver/storage/tables"
) )
type Database interface { type Database interface {
@ -13,4 +14,13 @@ type Database interface {
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appId, pushkey, localpart string) error RemovePusher(ctx context.Context, appId, pushkey, localpart string) error
RemovePushers(ctx context.Context, appId, pushkey string) error RemovePushers(ctx context.Context, appId, pushkey string) error
InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, upToEventID string) (affected bool, _ error)
SetNotificationsRead(ctx context.Context, localpart, roomID, upToEventID string, b bool) (affected bool, _ error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
} }
type NotificationFilter = tables.NotificationFilter

View file

@ -0,0 +1,237 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/pushserver/api"
"github.com/matrix-org/dendrite/pushserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
type notificationsStatements struct {
insertStmt *sql.Stmt
deleteUpToStmt *sql.Stmt
updateReadStmt *sql.Stmt
selectStmt *sql.Stmt
selectCountStmt *sql.Stmt
selectRoomCountsStmt *sql.Stmt
}
func prepareNotificationsTable(db *sql.DB) (tables.Notifications, error) {
s := &notificationsStatements{}
return s, sqlutil.StatementList{
{&s.insertStmt, insertNotificationSQL},
{&s.deleteUpToStmt, deleteNotificationsUpToSQL},
{&s.updateReadStmt, updateNotificationReadSQL},
{&s.selectStmt, selectNotificationSQL},
{&s.selectCountStmt, selectNotificationCountSQL},
{&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
}.Prepare(db)
}
const insertNotificationSQL = "INSERT INTO pushserver_notifications (localpart, room_id, event_id, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6)"
// Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, localpart, eventID string, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
// data and (2) avoid difficult-to-debug inconsistency bugs.
nn.RoomID = ""
nn.TS, nn.Read = 0, false
bs, err := json.Marshal(nn)
if err != nil {
return err
}
_, err = s.insertStmt.ExecContext(ctx, localpart, roomID, eventID, tsMS, highlight, string(bs))
return err
}
const deleteNotificationsUpToSQL = `DELETE FROM pushserver_notifications
WHERE
localpart = $1 AND
room_id = $2 AND
id <= (
SELECT MAX(id)
FROM pushserver_notifications
WHERE
localpart = $1 AND
room_id = $2 AND
event_id = $3
)`
// DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, _ error) {
res, err := s.deleteUpToStmt.ExecContext(ctx, localpart, roomID, eventID)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("DeleteUpTo: %d rows affected", nrows)
return nrows > 0, nil
}
const updateNotificationReadSQL = `UPDATE pushserver_notifications
SET read = $1
WHERE
localpart = $2 AND
room_id = $3 AND
id <= (
SELECT MAX(id)
FROM pushserver_notifications
WHERE
localpart = $2 AND
room_id = $3 AND
event_id = $4
) AND
read <> $1`
// UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, localpart, roomID, eventID string, v bool) (affected bool, _ error) {
res, err := s.updateReadStmt.ExecContext(ctx, v, localpart, roomID, eventID)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("UpdateRead: %d rows affected", nrows)
return nrows > 0, nil
}
const selectNotificationSQL = `SELECT id, room_id, ts_ms, read, notification_json
FROM pushserver_notifications
WHERE
localpart = $1 AND
id > $2 AND
(
(($3 & 1) <> 0 AND highlight) OR
(($3 & 2) <> 0 AND NOT highlight)
) AND
NOT read
ORDER BY localpart, id
LIMIT $4`
func (s *notificationsStatements) Select(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := s.selectStmt.QueryContext(ctx, localpart, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
var maxID int64 = -1
var notifs []*api.Notification
for rows.Next() {
var id int64
var roomID string
var ts gomatrixserverlib.Timestamp
var read bool
var jsonStr string
err = rows.Scan(
&id,
&roomID,
&ts,
&read,
&jsonStr)
if err != nil {
return nil, 0, err
}
var n api.Notification
err := json.Unmarshal([]byte(jsonStr), &n)
if err != nil {
return nil, 0, err
}
n.RoomID = roomID
n.TS = ts
n.Read = read
notifs = append(notifs, &n)
if maxID < id {
maxID = id
}
}
return notifs, maxID, rows.Err()
}
const selectNotificationCountSQL = `SELECT COUNT(*)
FROM pushserver_notifications
WHERE
localpart = $1 AND
(
(($2 & 1) <> 0 AND highlight) OR
(($2 & 2) <> 0 AND NOT highlight)
) AND
NOT read`
func (s *notificationsStatements) SelectCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
rows, err := s.selectCountStmt.QueryContext(ctx, localpart, uint32(filter))
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
for rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
return 0, rows.Err()
}
const selectRoomNotificationCountsSQL = `SELECT
COUNT(*),
COUNT(*) FILTER (WHERE highlight)
FROM pushserver_notifications
WHERE
localpart = $1 AND
room_id = $2 AND
NOT read`
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
rows, err := s.selectRoomCountsStmt.QueryContext(ctx, localpart, roomID)
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
for rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
}

View file

@ -5,22 +5,62 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/pushserver/api" "github.com/matrix-org/dendrite/pushserver/api"
"github.com/matrix-org/dendrite/pushserver/storage/tables" "github.com/matrix-org/dendrite/pushserver/storage/tables"
) )
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer Writer sqlutil.Writer
pushers tables.Pusher notifications tables.Notifications
pushers tables.Pusher
} }
func (d *Database) Prepare() (err error) { func (d *Database) Prepare() (err error) {
d.notifications, err = prepareNotificationsTable(d.DB)
if err != nil {
return
}
d.pushers, err = preparePushersTable(d.DB) d.pushers, err = preparePushersTable(d.DB)
return return
} }
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
return d.notifications.Insert(ctx, localpart, eventID, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
})
}
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, upToEventID string) (affected bool, err error) {
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
affected, err = d.notifications.DeleteUpTo(ctx, localpart, roomID, upToEventID)
return err
})
return
}
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID, upToEventID string, b bool) (affected bool, err error) {
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
affected, err = d.notifications.UpdateRead(ctx, localpart, roomID, upToEventID, b)
return err
})
return
}
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
return d.notifications.Select(ctx, localpart, fromID, limit, filter)
}
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
return d.notifications.SelectCount(ctx, localpart, filter)
}
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
return d.notifications.SelectRoomCounts(ctx, localpart, roomID)
}
func (d *Database) CreatePusher( func (d *Database) CreatePusher(
ctx context.Context, p api.Pusher, localpart string, ctx context.Context, p api.Pusher, localpart string,
) error { ) error {

View file

@ -23,3 +23,35 @@ type Pusher interface {
ctx context.Context, appid, pushkey string, ctx context.Context, appid, pushkey string,
) error ) error
} }
type Notifications interface {
Insert(ctx context.Context, localpart, eventID string, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, _ error)
UpdateRead(ctx context.Context, localpart, roomID, eventID string, v bool) (affected bool, _ error)
Select(ctx context.Context, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, localpart string, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
}
type NotificationFilter uint32
const (
// HighlightNotifications returns notifications that had a
// "highlight" tweak assigned to them from evaluating push rules.
HighlightNotifications NotificationFilter = 1 << iota
// NonHighlightNotifications returns notifications that don't
// match HighlightNotifications.
NonHighlightNotifications
// NoNotifications is a filter to exclude all types of
// notifications. It's useful as a zero value, but isn't likely to
// be used in a call to Notifications.Select*.
NoNotifications NotificationFilter = 0
// AllNotifications is a filter to include all types of
// notifications in Notifications.Select*. Note that PostgreSQL
// balks if this doesn't fit in INTEGER, even though we use
// uint32.
AllNotifications NotificationFilter = (1 << 31) - 1
)