Notifications
This commit is contained in:
parent
6ea580d9c9
commit
4998e0757d
|
@ -40,16 +40,17 @@ func GetNotifications(
|
||||||
}
|
}
|
||||||
|
|
||||||
var queryRes userapi.QueryNotificationsResponse
|
var queryRes userapi.QueryNotificationsResponse
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
|
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
From: req.URL.Query().Get("from"),
|
ServerName: domain,
|
||||||
Limit: int(limit),
|
From: req.URL.Query().Get("from"),
|
||||||
Only: req.URL.Query().Get("only"),
|
Limit: int(limit),
|
||||||
|
Only: req.URL.Query().Get("only"),
|
||||||
}, &queryRes)
|
}, &queryRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
|
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
|
||||||
|
|
|
@ -575,10 +575,11 @@ type QueryPushRulesResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryNotificationsRequest struct {
|
type QueryNotificationsRequest struct {
|
||||||
Localpart string `json:"localpart"` // Required.
|
Localpart string `json:"localpart"` // Required.
|
||||||
From string `json:"from,omitempty"`
|
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
|
||||||
Limit int `json:"limit,omitempty"`
|
From string `json:"from,omitempty"`
|
||||||
Only string `json:"only,omitempty"`
|
Limit int `json:"limit,omitempty"`
|
||||||
|
Only string `json:"only,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryNotificationsResponse struct {
|
type QueryNotificationsResponse struct {
|
||||||
|
|
|
@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
|
updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("userapi EDU consumer")
|
log.WithError(err).Error("userapi EDU consumer")
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -527,7 +527,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
||||||
RoomID: event.RoomID(),
|
RoomID: event.RoomID(),
|
||||||
TS: gomatrixserverlib.AsTimestamp(time.Now()),
|
TS: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||||
}
|
}
|
||||||
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
|
if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil {
|
||||||
return fmt.Errorf("s.db.InsertNotification: %w", err)
|
return fmt.Errorf("s.db.InsertNotification: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -536,7 +536,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
||||||
}
|
}
|
||||||
|
|
||||||
// We do this after InsertNotification. Thus, this should always return >=1.
|
// We do this after InsertNotification. Thus, this should always return >=1.
|
||||||
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
|
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("s.db.GetNotificationCount: %w", err)
|
return fmt.Errorf("s.db.GetNotificationCount: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
|
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
|
||||||
return err
|
return err
|
||||||
|
@ -789,7 +789,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query
|
||||||
if req.Only == "highlight" {
|
if req.Only == "highlight" {
|
||||||
filter = tables.HighlightNotifications
|
filter = tables.HighlightNotifications
|
||||||
}
|
}
|
||||||
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
|
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err
|
||||||
// GetAndSendNotificationData reads the database and sends data about unread
|
// GetAndSendNotificationData reads the database and sends data about unread
|
||||||
// notifications to the Sync API server.
|
// notifications to the Sync API server.
|
||||||
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
|
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
|
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,12 +121,12 @@ type ThreePID interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Notification interface {
|
type Notification interface {
|
||||||
InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
|
InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
|
||||||
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
|
DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error)
|
||||||
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
|
SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error)
|
||||||
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||||
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
|
GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error)
|
||||||
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
|
GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
|
||||||
DeleteOldNotifications(ctx context.Context) error
|
DeleteOldNotifications(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -53,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||||
read BOOLEAN NOT NULL DEFAULT FALSE
|
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertNotificationSQL = "" +
|
const insertNotificationSQL = "" +
|
||||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||||
|
|
||||||
const deleteNotificationsUpToSQL = "" +
|
const deleteNotificationsUpToSQL = "" +
|
||||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
|
||||||
|
|
||||||
const updateNotificationReadSQL = "" +
|
const updateNotificationReadSQL = "" +
|
||||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
|
||||||
|
|
||||||
const selectNotificationSQL = "" +
|
const selectNotificationSQL = "" +
|
||||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
|
||||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
|
||||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
") AND NOT read ORDER BY localpart, id LIMIT $5"
|
||||||
|
|
||||||
const selectNotificationCountSQL = "" +
|
const selectNotificationCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
|
||||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||||
") AND NOT read"
|
") AND NOT read"
|
||||||
|
|
||||||
const selectRoomNotificationCountsSQL = "" +
|
const selectRoomNotificationCountsSQL = "" +
|
||||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
|
||||||
|
|
||||||
const cleanNotificationsSQL = "" +
|
const cleanNotificationsSQL = "" +
|
||||||
"DELETE FROM userapi_notifications WHERE" +
|
"DELETE FROM userapi_notifications WHERE" +
|
||||||
|
@ -112,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert inserts a notification into the database.
|
// Insert inserts a notification into the database.
|
||||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||||
roomID, tsMS := n.RoomID, n.TS
|
roomID, tsMS := n.RoomID, n.TS
|
||||||
nn := *n
|
nn := *n
|
||||||
// Clears out fields that have their own columns to (1) shrink the
|
// Clears out fields that have their own columns to (1) shrink the
|
||||||
|
@ -123,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
|
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
|
||||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -142,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRead updates the "read" value for an event.
|
// UpdateRead updates the "read" value for an event.
|
||||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -155,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
|
||||||
return nrows > 0, nil
|
return nrows > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
|
@ -198,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
|
||||||
return notifs, maxID, rows.Err()
|
return notifs, maxID, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
|
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
|
||||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
|
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
|
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
|
||||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
|
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -727,38 +727,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
|
||||||
return d.LoginTokens.SelectLoginToken(ctx, token)
|
return d.LoginTokens.SelectLoginToken(ctx, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
|
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
|
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
|
func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
|
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||||
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
|
return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
|
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) {
|
||||||
return d.Notifications.SelectCount(ctx, nil, localpart, filter)
|
return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
|
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) {
|
||||||
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
|
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
||||||
|
|
|
@ -53,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||||
read BOOLEAN NOT NULL DEFAULT FALSE
|
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertNotificationSQL = "" +
|
const insertNotificationSQL = "" +
|
||||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||||
|
|
||||||
const deleteNotificationsUpToSQL = "" +
|
const deleteNotificationsUpToSQL = "" +
|
||||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
|
||||||
|
|
||||||
const updateNotificationReadSQL = "" +
|
const updateNotificationReadSQL = "" +
|
||||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
|
||||||
|
|
||||||
const selectNotificationSQL = "" +
|
const selectNotificationSQL = "" +
|
||||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
|
||||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
|
||||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
") AND NOT read ORDER BY localpart, id LIMIT $5"
|
||||||
|
|
||||||
const selectNotificationCountSQL = "" +
|
const selectNotificationCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
|
||||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||||
") AND NOT read"
|
") AND NOT read"
|
||||||
|
|
||||||
const selectRoomNotificationCountsSQL = "" +
|
const selectRoomNotificationCountsSQL = "" +
|
||||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
|
||||||
|
|
||||||
const cleanNotificationsSQL = "" +
|
const cleanNotificationsSQL = "" +
|
||||||
"DELETE FROM userapi_notifications WHERE" +
|
"DELETE FROM userapi_notifications WHERE" +
|
||||||
|
@ -112,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert inserts a notification into the database.
|
// Insert inserts a notification into the database.
|
||||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||||
roomID, tsMS := n.RoomID, n.TS
|
roomID, tsMS := n.RoomID, n.TS
|
||||||
nn := *n
|
nn := *n
|
||||||
// Clears out fields that have their own columns to (1) shrink the
|
// Clears out fields that have their own columns to (1) shrink the
|
||||||
|
@ -123,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
|
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
|
||||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -142,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRead updates the "read" value for an event.
|
// UpdateRead updates the "read" value for an event.
|
||||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -155,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
|
||||||
return nrows > 0, nil
|
return nrows > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
|
@ -198,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
|
||||||
return notifs, maxID, rows.Err()
|
return notifs, maxID, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
|
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
|
||||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
|
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
|
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
|
||||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
|
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -498,7 +498,7 @@ func Test_ThreePID(t *testing.T) {
|
||||||
|
|
||||||
func Test_Notification(t *testing.T) {
|
func Test_Notification(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
room2 := test.NewRoom(t, alice)
|
room2 := test.NewRoom(t, alice)
|
||||||
|
@ -526,34 +526,34 @@ func Test_Notification(t *testing.T) {
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
TS: gomatrixserverlib.AsTimestamp(ts),
|
TS: gomatrixserverlib.AsTimestamp(ts),
|
||||||
}
|
}
|
||||||
err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
|
err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification)
|
||||||
assert.NoError(t, err, "unable to insert notification")
|
assert.NoError(t, err, "unable to insert notification")
|
||||||
}
|
}
|
||||||
|
|
||||||
// get notifications
|
// get notifications
|
||||||
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
|
count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications)
|
||||||
assert.NoError(t, err, "unable to get notification count")
|
assert.NoError(t, err, "unable to get notification count")
|
||||||
assert.Equal(t, int64(10), count)
|
assert.Equal(t, int64(10), count)
|
||||||
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
|
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications)
|
||||||
assert.NoError(t, err, "unable to get notifications")
|
assert.NoError(t, err, "unable to get notifications")
|
||||||
assert.Equal(t, int64(10), count)
|
assert.Equal(t, int64(10), count)
|
||||||
assert.Equal(t, 10, len(notifs))
|
assert.Equal(t, 10, len(notifs))
|
||||||
// ... for a specific room
|
// ... for a specific room
|
||||||
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||||
assert.NoError(t, err, "unable to get notifications for room")
|
assert.NoError(t, err, "unable to get notifications for room")
|
||||||
assert.Equal(t, int64(4), total)
|
assert.Equal(t, int64(4), total)
|
||||||
|
|
||||||
// mark notification as read
|
// mark notification as read
|
||||||
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
|
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true)
|
||||||
assert.NoError(t, err, "unable to set notifications read")
|
assert.NoError(t, err, "unable to set notifications read")
|
||||||
assert.True(t, affected)
|
assert.True(t, affected)
|
||||||
|
|
||||||
// this should delete 2 notifications
|
// this should delete 2 notifications
|
||||||
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
|
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8)
|
||||||
assert.NoError(t, err, "unable to set notifications read")
|
assert.NoError(t, err, "unable to set notifications read")
|
||||||
assert.True(t, affected)
|
assert.True(t, affected)
|
||||||
|
|
||||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||||
assert.NoError(t, err, "unable to get notifications for room")
|
assert.NoError(t, err, "unable to get notifications for room")
|
||||||
assert.Equal(t, int64(2), total)
|
assert.Equal(t, int64(2), total)
|
||||||
|
|
||||||
|
@ -562,7 +562,7 @@ func Test_Notification(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// this should now return 0 notifications
|
// this should now return 0 notifications
|
||||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||||
assert.NoError(t, err, "unable to get notifications for room")
|
assert.NoError(t, err, "unable to get notifications for room")
|
||||||
assert.Equal(t, int64(0), total)
|
assert.Equal(t, int64(0), total)
|
||||||
})
|
})
|
||||||
|
|
|
@ -107,12 +107,12 @@ type PusherTable interface {
|
||||||
|
|
||||||
type NotificationTable interface {
|
type NotificationTable interface {
|
||||||
Clean(ctx context.Context, txn *sql.Tx) error
|
Clean(ctx context.Context, txn *sql.Tx) error
|
||||||
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
|
Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error
|
||||||
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
|
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error)
|
||||||
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
|
UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error)
|
||||||
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
|
Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
|
||||||
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
|
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error)
|
||||||
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
|
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StatsTable interface {
|
type StatsTable interface {
|
||||||
|
|
|
@ -27,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
|
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue