mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-11 08:03:09 -06:00
Always include notification data for rooms in our reponse
This is to avoid issues where we return notification data in one sync, but not the other. Which could result in clients simply displaying a "dot" as the count. (Observerd in Element Android, where the count would disappear on subsequent syncs where no notification data would be present, e.g. we only returned ephemeral data). It's also something Synapse is doing.
This commit is contained in:
parent
db07e9b365
commit
472887fc32
|
|
@ -29,6 +29,7 @@ import (
|
|||
type Database interface {
|
||||
Presence
|
||||
SharedUsers
|
||||
Notifications
|
||||
|
||||
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
|
||||
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
|
||||
|
|
@ -148,12 +149,6 @@ type Database interface {
|
|||
// GetRoomReceipts gets all receipts for a given roomID
|
||||
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
|
||||
|
||||
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
|
||||
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
|
||||
|
||||
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
|
||||
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
|
||||
|
||||
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
|
||||
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
|
||||
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
|
||||
|
|
@ -179,3 +174,11 @@ type SharedUsers interface {
|
|||
// SharedUsers returns a subset of otherUserIDs that share a room with userID.
|
||||
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
|
||||
}
|
||||
|
||||
type Notifications interface {
|
||||
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
|
||||
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
|
||||
|
||||
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
|
||||
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
|
@ -33,15 +35,15 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro
|
|||
r := ¬ificationDataStatements{}
|
||||
return r, sqlutil.StatementList{
|
||||
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
|
||||
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
|
||||
{&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms},
|
||||
{&r.selectMaxID, selectMaxNotificationIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
type notificationDataStatements struct {
|
||||
upsertRoomUnreadCounts *sql.Stmt
|
||||
selectUserUnreadCounts *sql.Stmt
|
||||
selectMaxID *sql.Stmt
|
||||
upsertRoomUnreadCounts *sql.Stmt
|
||||
selectUserUnreadCountsForRooms *sql.Stmt
|
||||
selectMaxID *sql.Stmt
|
||||
}
|
||||
|
||||
const notificationDataSchema = `
|
||||
|
|
@ -61,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
|
|||
DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
|
||||
RETURNING id`
|
||||
|
||||
const selectUserUnreadNotificationCountsSQL = `SELECT
|
||||
id, room_id, notification_count, highlight_count
|
||||
FROM syncapi_notification_data
|
||||
WHERE
|
||||
user_id = $1 AND
|
||||
id BETWEEN $2 + 1 AND $3`
|
||||
const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
|
||||
FROM syncapi_notification_data
|
||||
WHERE user_id = $1 AND
|
||||
room_id = ANY($2)`
|
||||
|
||||
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
|
||||
|
||||
|
|
@ -75,20 +75,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
|
|||
return
|
||||
}
|
||||
|
||||
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl)
|
||||
func (r *notificationDataStatements) SelectUnserUnreadCountsForRooms(
|
||||
ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
|
||||
) (map[string]*eventutil.NotificationData, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectUnserUnreadCountsForRooms: rows.close() failed")
|
||||
|
||||
roomCounts := map[string]*eventutil.NotificationData{}
|
||||
var roomID string
|
||||
var notificationCount, highlightCount int
|
||||
for rows.Next() {
|
||||
var id types.StreamPosition
|
||||
var roomID string
|
||||
var notificationCount, highlightCount int
|
||||
|
||||
if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil {
|
||||
if err = rows.Scan(&roomID, ¬ificationCount, &highlightCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1036,8 +1036,15 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI
|
|||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
|
||||
return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to)
|
||||
func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
|
||||
roomIDs := make([]string, 0, len(rooms))
|
||||
for roomID, membership := range rooms {
|
||||
if membership != gomatrixserverlib.Join {
|
||||
continue
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
}
|
||||
return d.NotificationData.SelectUnserUnreadCountsForRooms(ctx, nil, userID, roomIDs)
|
||||
}
|
||||
|
||||
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
|
|
@ -32,19 +33,21 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t
|
|||
}
|
||||
r := ¬ificationDataStatements{
|
||||
streamIDStatements: streamID,
|
||||
db: db,
|
||||
}
|
||||
return r, sqlutil.StatementList{
|
||||
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
|
||||
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
|
||||
{&r.selectMaxID, selectMaxNotificationIDSQL},
|
||||
// {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
type notificationDataStatements struct {
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertRoomUnreadCounts *sql.Stmt
|
||||
selectUserUnreadCounts *sql.Stmt
|
||||
selectMaxID *sql.Stmt
|
||||
db *sql.DB
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertRoomUnreadCounts *sql.Stmt
|
||||
selectUserUnreadCountsForRooms *sql.Stmt
|
||||
selectMaxID *sql.Stmt
|
||||
}
|
||||
|
||||
const notificationDataSchema = `
|
||||
|
|
@ -63,12 +66,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
|
|||
ON CONFLICT (user_id, room_id)
|
||||
DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
|
||||
|
||||
const selectUserUnreadNotificationCountsSQL = `SELECT
|
||||
id, room_id, notification_count, highlight_count
|
||||
FROM syncapi_notification_data
|
||||
WHERE
|
||||
user_id = $1 AND
|
||||
id BETWEEN $2 + 1 AND $3`
|
||||
const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
|
||||
FROM syncapi_notification_data
|
||||
WHERE user_id = $1 AND
|
||||
room_id IN ($2)`
|
||||
|
||||
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
|
||||
|
||||
|
|
@ -81,20 +82,26 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
|
|||
return
|
||||
}
|
||||
|
||||
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl)
|
||||
func (r *notificationDataStatements) SelectUnserUnreadCountsForRooms(
|
||||
ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
|
||||
) (map[string]*eventutil.NotificationData, error) {
|
||||
params := make([]interface{}, len(roomIDs)+1)
|
||||
params[0] = userID
|
||||
for i := range roomIDs {
|
||||
params[i+1] = roomIDs[i]
|
||||
}
|
||||
sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($1)", sqlutil.QueryVariadic(len(params)), 1)
|
||||
rows, err := r.db.QueryContext(ctx, sql, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectUnserUnreadCountsForRooms: rows.close() failed")
|
||||
|
||||
roomCounts := map[string]*eventutil.NotificationData{}
|
||||
var roomID string
|
||||
var notificationCount, highlightCount int
|
||||
for rows.Next() {
|
||||
var id types.StreamPosition
|
||||
var roomID string
|
||||
var notificationCount, highlightCount int
|
||||
|
||||
if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil {
|
||||
if err = rows.Scan(&roomID, ¬ificationCount, &highlightCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -190,7 +190,7 @@ type Memberships interface {
|
|||
|
||||
type NotificationData interface {
|
||||
UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
|
||||
SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error)
|
||||
SelectUnserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error)
|
||||
SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,12 +32,12 @@ func (p *NotificationDataStreamProvider) IncrementalSync(
|
|||
req *types.SyncRequest,
|
||||
from, _ types.StreamPosition,
|
||||
) types.StreamPosition {
|
||||
// Always get the latest data, as this might have advanced while waiting
|
||||
// for other streams to prepare their responses and add/updated notifications.
|
||||
to := p.LatestPosition(ctx)
|
||||
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to)
|
||||
// Get the unread notifications for rooms in our join response.
|
||||
// This is to ensure clients always have an unread notification section
|
||||
// and can display the correct numbers.
|
||||
countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed")
|
||||
req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
|
||||
return from
|
||||
}
|
||||
|
||||
|
|
@ -53,5 +53,6 @@ func (p *NotificationDataStreamProvider) IncrementalSync(
|
|||
}
|
||||
req.Response.Rooms.Join[roomID] = jr
|
||||
}
|
||||
return to
|
||||
|
||||
return p.LatestPosition(ctx)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue