diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 17d19f5f3..c220d35cb 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -99,7 +99,12 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return true } - updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, readPos, true) + metadata, err := msg.Metadata() + if err != nil { + return false + } + + updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) if err != nil { log.WithError(err).Error("userapi EDU consumer") return false diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 52338b9b7..fbcb3856a 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -92,7 +92,12 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms "event_type": event.Type(), }).Tracef("Received message from roomserver: %#v", output) - if err := s.processMessage(ctx, event); err != nil { + metadata, err := msg.Metadata() + if err != nil { + return true + } + + if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil { log.WithFields(log.Fields{ "event_id": event.EventID(), }).WithError(err).Errorf("userapi consumer: process room event failure") @@ -101,7 +106,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return true } -func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { +func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) if err != nil { return fmt.Errorf("s.localRoomMembers: %w", err) @@ -141,7 +146,7 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gom // removing it means we can send all notifications to // e.g. Element's Push gateway in one go. for _, mem := range members { - if err := s.notifyLocal(ctx, event, mem, roomSize, roomName); err != nil { + if err := s.notifyLocal(ctx, event, mem, roomSize, roomName, streamPos); err != nil { log.WithFields(log.Fields{ "localpart": mem.Localpart, }).WithError(err).Debugf("Unable to push to local user") @@ -290,7 +295,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er } // notifyLocal finds the right push actions for a local user, given an event. -func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string) error { +func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) if err != nil { return err @@ -328,7 +333,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr RoomID: event.RoomID(), TS: gomatrixserverlib.AsTimestamp(time.Now()), } - if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), tweaks, n); err != nil { + if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { return err } diff --git a/userapi/internal/api.go b/userapi/internal/api.go index a9a8c30ac..3e761a886 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -107,7 +107,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } - deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, output.FullyRead) + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) if err != nil { logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") return err diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 42245638c..02efe7afe 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -119,9 +119,9 @@ type ThreePID interface { } type Notification interface { - InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart, roomID, eventID string, read bool) (affected bool, err error) + InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index 59b404258..24a30b2f5 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -58,13 +58,13 @@ CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_noti ` const insertNotificationSQL = "" + - "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, 0, $4, $5, $6)" + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" const deleteNotificationsUpToSQL = "" + - "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $3 AND stream_pos = 0)" + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" const updateNotificationReadSQL = "" + - "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $4 AND stream_pos = 0) AND read <> $1" + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" const selectNotificationSQL = "" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + @@ -111,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -122,13 +122,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local if err != nil { return err } - _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, tsMS, highlight, string(bs)) + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) return err } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, eventID) +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) if err != nil { return false, err } @@ -136,13 +136,13 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l if err != nil { return true, err } - log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("DeleteUpTo: %d rows affected", nrows) + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) return nrows > 0, nil } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, v bool) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, eventID) +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) if err != nil { return false, err } @@ -150,7 +150,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l if err != nil { return true, err } - log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("UpdateRead: %d rows affected", nrows) + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) return nrows > 0, nil } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 36e31fe46..3ff299f1b 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -700,23 +700,23 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } -func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error { +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Notifications.Insert(ctx, txn, localpart, eventID, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) }) } -func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, err error) { +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, eventID) + affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) return err }) return } -func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID, eventID string, b bool) (affected bool, err error) { +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, eventID, b) + affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) return err }) return diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index 0ad9a59bf..a35ec7be5 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -58,13 +58,13 @@ CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_noti ` const insertNotificationSQL = "" + - "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, 0, $4, $5, $6)" + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" const deleteNotificationsUpToSQL = "" + - "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $3 AND stream_pos = 0)" + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" const updateNotificationReadSQL = "" + - "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $4 AND stream_pos = 0) AND read <> $1" + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" const selectNotificationSQL = "" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + @@ -111,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -122,13 +122,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local if err != nil { return err } - _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, tsMS, highlight, string(bs)) + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) return err } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, eventID) +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) if err != nil { return false, err } @@ -136,13 +136,13 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l if err != nil { return true, err } - log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("DeleteUpTo: %d rows affected", nrows) + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) return nrows > 0, nil } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, v bool) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, eventID) +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) if err != nil { return false, err } @@ -150,7 +150,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l if err != nil { return true, err } - log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("UpdateRead: %d rows affected", nrows) + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) return nrows > 0, nil } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index c4d843baf..ca7c1bfd2 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -494,7 +494,6 @@ func Test_Notification(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() // generate some dummy notifications - eventIDs := make([]string, 0, 10) for i := 0; i < 10; i++ { eventID := util.RandomString(16) roomID := room.ID @@ -515,9 +514,8 @@ func Test_Notification(t *testing.T) { RoomID: roomID, TS: gomatrixserverlib.AsTimestamp(ts), } - err = db.InsertNotification(ctx, aliceLocalpart, eventID, nil, notification) + err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") - eventIDs = append(eventIDs, eventID) } // get notifications @@ -534,12 +532,12 @@ func Test_Notification(t *testing.T) { assert.Equal(t, int64(4), total) // mark notification as read - affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, eventIDs[6], true) + affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) // this should delete 2 notifications - affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, eventIDs[7]) + affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 3e761f806..cc4287997 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -105,9 +105,9 @@ type PusherTable interface { type NotificationTable interface { Clean(ctx context.Context, txn *sql.Tx) error - Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, highlight bool, n *api.Notification) error - DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string) (affected bool, _ error) - UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, v bool) (affected bool, _ error) + Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)