From 6f22cfa68f5548b8d07f6848160b2c41c9957e98 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 4 Aug 2022 18:10:08 +0200 Subject: [PATCH] Also for SQLite --- .../storage/sqlite3/notification_data_table.go | 16 +++++++++++----- syncapi/storage/sqlite3/stream_id_table.go | 8 ++++++++ syncapi/storage/sqlite3/syncserver.go | 2 +- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index 4b3f074db..eaa11a8c0 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -25,12 +25,14 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" ) -func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) { +func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.NotificationData, error) { _, err := db.Exec(notificationDataSchema) if err != nil { return nil, err } - r := ¬ificationDataStatements{} + r := ¬ificationDataStatements{ + streamIDStatements: streamID, + } return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, @@ -39,6 +41,7 @@ func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) } type notificationDataStatements struct { + streamIDStatements *StreamIDStatements upsertRoomUnreadCounts *sql.Stmt selectUserUnreadCounts *sql.Stmt selectMaxID *sql.Stmt @@ -58,8 +61,7 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_ (user_id, room_id, notification_count, highlight_count) VALUES ($1, $2, $3, $4) ON CONFLICT (user_id, room_id) - DO UPDATE SET notification_count = $3, highlight_count = $4 - RETURNING id` + DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7` const selectUserUnreadNotificationCountsSQL = `SELECT id, room_id, notification_count, highlight_count @@ -71,7 +73,11 @@ const selectUserUnreadNotificationCountsSQL = `SELECT const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { - err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) + pos, err = r.streamIDStatements.nextNotificationID(ctx, nil) + if err != nil { + return + } + _, err = r.upsertRoomUnreadCounts.ExecContext(ctx, userID, roomID, notificationCount, highlightCount, pos, notificationCount, highlightCount) return } diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index 71980b806..1160a437e 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -26,6 +26,8 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0) ON CONFLICT DO NOTHING; INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("presence", 0) ON CONFLICT DO NOTHING; +INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("notification", 0) + ON CONFLICT DO NOTHING; ` const increaseStreamIDStmt = "" + @@ -78,3 +80,9 @@ func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (p err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos) return } + +func (s *StreamIDStatements) nextNotificationID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { + increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + err = increaseStmt.QueryRowContext(ctx, "notification").Scan(&pos) + return +} diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 65b2bb38a..5c5eb0f55 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -95,7 +95,7 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } - notificationData, err := NewSqliteNotificationDataTable(d.db) + notificationData, err := NewSqliteNotificationDataTable(d.db, &d.streamID) if err != nil { return err }