Always include notification data for rooms in our reponse

This is to avoid issues where we return notification data in one sync,
but not the other. Which could result in clients simply displaying a
"dot" as the count. (Observerd in Element Android, where the count would
disappear on subsequent syncs where no notification data would be
present, e.g. we only returned ephemeral data). It's also something Synapse is doing.
This commit is contained in:
Till Faelligen 2022-09-27 07:43:32 +02:00
parent db07e9b365
commit 472887fc32
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
6 changed files with 70 additions and 52 deletions

View file

@ -29,6 +29,7 @@ import (
type Database interface { type Database interface {
Presence Presence
SharedUsers SharedUsers
Notifications
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
@ -148,12 +149,6 @@ type Database interface {
// GetRoomReceipts gets all receipts for a given roomID // GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
@ -179,3 +174,11 @@ type SharedUsers interface {
// SharedUsers returns a subset of otherUserIDs that share a room with userID. // SharedUsers returns a subset of otherUserIDs that share a room with userID.
SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
} }
type Notifications interface {
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
}

View file

@ -18,6 +18,8 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -33,15 +35,15 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro
r := &notificationDataStatements{} r := &notificationDataStatements{}
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
} }
const notificationDataSchema = ` const notificationDataSchema = `
@ -61,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4 DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
RETURNING id` RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count FROM syncapi_notification_data
FROM syncapi_notification_data WHERE user_id = $1 AND
WHERE room_id = ANY($2)`
user_id = $1 AND
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -75,20 +75,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUnserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectUnserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
var roomID string
var notificationCount, highlightCount int
for rows.Next() { for rows.Next() {
var id types.StreamPosition if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
var roomID string
var notificationCount, highlightCount int
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -1036,8 +1036,15 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI
return return
} }
func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to) roomIDs := make([]string, 0, len(rooms))
for roomID, membership := range rooms {
if membership != gomatrixserverlib.Join {
continue
}
roomIDs = append(roomIDs, roomID)
}
return d.NotificationData.SelectUnserUnreadCountsForRooms(ctx, nil, userID, roomIDs)
} }
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {

View file

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -32,19 +33,21 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t
} }
r := &notificationDataStatements{ r := &notificationDataStatements{
streamIDStatements: streamID, streamIDStatements: streamID,
db: db,
} }
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL},
// {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime
}.Prepare(db) }.Prepare(db)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
streamIDStatements *StreamIDStatements db *sql.DB
upsertRoomUnreadCounts *sql.Stmt streamIDStatements *StreamIDStatements
selectUserUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt
selectMaxID *sql.Stmt
} }
const notificationDataSchema = ` const notificationDataSchema = `
@ -63,12 +66,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
ON CONFLICT (user_id, room_id) ON CONFLICT (user_id, room_id)
DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7` DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count
id, room_id, notification_count, highlight_count FROM syncapi_notification_data
FROM syncapi_notification_data WHERE user_id = $1 AND
WHERE room_id IN ($2)`
user_id = $1 AND
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
@ -81,20 +82,26 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return return
} }
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { func (r *notificationDataStatements) SelectUnserUnreadCountsForRooms(
rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) ctx context.Context, txn *sql.Tx, userID string, roomIDs []string,
) (map[string]*eventutil.NotificationData, error) {
params := make([]interface{}, len(roomIDs)+1)
params[0] = userID
for i := range roomIDs {
params[i+1] = roomIDs[i]
}
sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($1)", sqlutil.QueryVariadic(len(params)), 1)
rows, err := r.db.QueryContext(ctx, sql, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectUnserUnreadCountsForRooms: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{} roomCounts := map[string]*eventutil.NotificationData{}
var roomID string
var notificationCount, highlightCount int
for rows.Next() { for rows.Next() {
var id types.StreamPosition if err = rows.Scan(&roomID, &notificationCount, &highlightCount); err != nil {
var roomID string
var notificationCount, highlightCount int
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
return nil, err return nil, err
} }

View file

@ -190,7 +190,7 @@ type Memberships interface {
type NotificationData interface { type NotificationData interface {
UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) SelectUnserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error)
SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error)
} }

View file

@ -32,12 +32,12 @@ func (p *NotificationDataStreamProvider) IncrementalSync(
req *types.SyncRequest, req *types.SyncRequest,
from, _ types.StreamPosition, from, _ types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// Always get the latest data, as this might have advanced while waiting // Get the unread notifications for rooms in our join response.
// for other streams to prepare their responses and add/updated notifications. // This is to ensure clients always have an unread notification section
to := p.LatestPosition(ctx) // and can display the correct numbers.
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
if err != nil { if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
return from return from
} }
@ -53,5 +53,6 @@ func (p *NotificationDataStreamProvider) IncrementalSync(
} }
req.Response.Rooms.Join[roomID] = jr req.Response.Rooms.Join[roomID] = jr
} }
return to
return p.LatestPosition(ctx)
} }