Always include notification data for rooms in our reponse

This is to avoid issues where we return notification data in one sync,
but not the other. Which could result in clients simply displaying a
"dot" as the count. (Observerd in Element Android, where the count would
disappear on subsequent syncs where no notification data would be
present, e.g. we only returned ephemeral data). It's also something Synapse is doing.
This commit is contained in:
Till Faelligen 2022-09-27 07:43:32 +02:00
parent db07e9b365
commit 472887fc32
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
6 changed files with 70 additions and 52 deletions

View file

@ -29,6 +29,7 @@ import (
type Database interface { type Database interface {
Presence Presence
SharedUsers SharedUsers
Notifications
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
@ -148,12 +149,6 @@ type Database interface {
// GetRoomReceipts gets all receipts for a given roomID // GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
@ -179,3 +174,11 @@ type SharedUsers interface {
// SharedUsers returns a subset of otherUserIDs that share a room with userID. // SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
} }
type Notifications interface {
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
}

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -33,14 +35,14 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro
r := &notificationDataStatements{} r := &notificationDataStatements{}
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
} }
@ -61,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4 DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
RETURNING id` RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data FROM syncapi_notification_data
WHERE WHERE user_id = $1 AND
user_id = $1 AND room_id = ANY($2)`
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -75,20 +75,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUnserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectUnserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
for rows.Next() {
var id types.StreamPosition
var roomID string var roomID string
var notificationCount, highlightCount int var notificationCount, highlightCount int
for rows.Next() {
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil { if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -1036,8 +1036,15 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI
return return
} }
func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to) roomIDs := make([]string, 0, len(rooms))
for roomID, membership := range rooms {
if membership != gomatrixserverlib.Join {
continue
}
roomIDs = append(roomIDs, roomID)
}
return d.NotificationData.SelectUnserUnreadCountsForRooms(ctx, nil, userID, roomIDs)
} }
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {

View file

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -32,18 +33,20 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t
} }
r := &notificationDataStatements{ r := &notificationDataStatements{
streamIDStatements: streamID, streamIDStatements: streamID,
db: db,
} }
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
// {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements streamIDStatements *StreamIDStatements
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
} }
@ -63,12 +66,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
ON CONFLICT (user_id, room_id) ON CONFLICT (user_id, room_id)
DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7` DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data FROM syncapi_notification_data
WHERE WHERE user_id = $1 AND
user_id = $1 AND room_id IN ($2)`
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -81,20 +82,26 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUnserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
params := make([]interface{}, len(roomIDs)+1)
params[0] = userID
for i := range roomIDs {
params[i+1] = roomIDs[i]
}
sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($1)", sqlutil.QueryVariadic(len(params)), 1)
rows, err := r.db.QueryContext(ctx, sql, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectUnserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
for rows.Next() {
var id types.StreamPosition
var roomID string var roomID string
var notificationCount, highlightCount int var notificationCount, highlightCount int
for rows.Next() {
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil { if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -190,7 +190,7 @@ type Memberships interface {
type NotificationData interface { type NotificationData interface {
UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) SelectUnserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error)
SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error)
} }

View file

@ -32,12 +32,12 @@ func (p *NotificationDataStreamProvider) IncrementalSync(
req *types.SyncRequest, req *types.SyncRequest,
from, _ types.StreamPosition, from, _ types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// Always get the latest data, as this might have advanced while waiting // Get the unread notifications for rooms in our join response.
// for other streams to prepare their responses and add/updated notifications. // This is to ensure clients always have an unread notification section
to := p.LatestPosition(ctx) // and can display the correct numbers.
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
if err != nil { if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
return from return from
} }
@ -53,5 +53,6 @@ func (p *NotificationDataStreamProvider) IncrementalSync(
} }
req.Response.Rooms.Join[roomID] = jr req.Response.Rooms.Join[roomID] = jr
} }
return to
return p.LatestPosition(ctx)
} }