Use timestamp as streamPos

This commit is contained in:
Till Faelligen 2022-09-26 13:27:08 +02:00
parent 76e78cb6e4
commit b66cf465e5
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
9 changed files with 54 additions and 46 deletions

View file

@ -99,7 +99,12 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
return true return true
} }
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, readPos, true) metadata, err := msg.Metadata()
if err != nil {
return false
}
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
if err != nil { if err != nil {
log.WithError(err).Error("userapi EDU consumer") log.WithError(err).Error("userapi EDU consumer")
return false return false

View file

@ -92,7 +92,12 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
"event_type": event.Type(), "event_type": event.Type(),
}).Tracef("Received message from roomserver: %#v", output) }).Tracef("Received message from roomserver: %#v", output)
if err := s.processMessage(ctx, event); err != nil { metadata, err := msg.Metadata()
if err != nil {
return true
}
if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": event.EventID(), "event_id": event.EventID(),
}).WithError(err).Errorf("userapi consumer: process room event failure") }).WithError(err).Errorf("userapi consumer: process room event failure")
@ -101,7 +106,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
return true return true
} }
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
if err != nil { if err != nil {
return fmt.Errorf("s.localRoomMembers: %w", err) return fmt.Errorf("s.localRoomMembers: %w", err)
@ -141,7 +146,7 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gom
// removing it means we can send all notifications to // removing it means we can send all notifications to
// e.g. Element's Push gateway in one go. // e.g. Element's Push gateway in one go.
for _, mem := range members { for _, mem := range members {
if err := s.notifyLocal(ctx, event, mem, roomSize, roomName); err != nil { if err := s.notifyLocal(ctx, event, mem, roomSize, roomName, streamPos); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": mem.Localpart, "localpart": mem.Localpart,
}).WithError(err).Debugf("Unable to push to local user") }).WithError(err).Debugf("Unable to push to local user")
@ -290,7 +295,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
} }
// notifyLocal finds the right push actions for a local user, given an event. // notifyLocal finds the right push actions for a local user, given an event.
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string) error { func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil { if err != nil {
return err return err
@ -328,7 +333,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
RoomID: event.RoomID(), RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()), TS: gomatrixserverlib.AsTimestamp(time.Now()),
} }
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), tweaks, n); err != nil { if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
return err return err
} }

View file

@ -107,7 +107,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil return nil
} }
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, output.FullyRead) deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
if err != nil { if err != nil {
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
return err return err

View file

@ -119,9 +119,9 @@ type ThreePID interface {
} }
type Notification interface { type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, err error) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID, eventID string, read bool) (affected bool, err error) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@ -58,13 +58,13 @@ CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_noti
` `
const insertNotificationSQL = "" + const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, 0, $4, $5, $6)" "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
const deleteNotificationsUpToSQL = "" + const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $3 AND stream_pos = 0)" "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
const updateNotificationReadSQL = "" + const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $4 AND stream_pos = 0) AND read <> $1" "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
const selectNotificationSQL = "" + const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
@ -111,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -122,13 +122,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil { if err != nil {
return err return err
} }
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, tsMS, highlight, string(bs)) _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
return err return err
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, eventID) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -136,13 +136,13 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
if err != nil { if err != nil {
return true, err return true, err
} }
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("DeleteUpTo: %d rows affected", nrows) log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
return nrows > 0, nil return nrows > 0, nil
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, eventID) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -150,7 +150,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
if err != nil { if err != nil {
return true, err return true, err
} }
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("UpdateRead: %d rows affected", nrows) log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
return nrows > 0, nil return nrows > 0, nil
} }

View file

@ -700,23 +700,23 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
return d.LoginTokens.SelectLoginToken(ctx, token) return d.LoginTokens.SelectLoginToken(ctx, token)
} }
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error { func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, eventID, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
}) })
} }
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, err error) { func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, eventID) affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
return err return err
}) })
return return
} }
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID, eventID string, b bool) (affected bool, err error) { func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, eventID, b) affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
return err return err
}) })
return return

View file

@ -58,13 +58,13 @@ CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_noti
` `
const insertNotificationSQL = "" + const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, 0, $4, $5, $6)" "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
const deleteNotificationsUpToSQL = "" + const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $3 AND stream_pos = 0)" "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
const updateNotificationReadSQL = "" + const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $4 AND stream_pos = 0) AND read <> $1" "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
const selectNotificationSQL = "" + const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
@ -111,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -122,13 +122,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil { if err != nil {
return err return err
} }
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, tsMS, highlight, string(bs)) _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
return err return err
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, eventID) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -136,13 +136,13 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
if err != nil { if err != nil {
return true, err return true, err
} }
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("DeleteUpTo: %d rows affected", nrows) log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
return nrows > 0, nil return nrows > 0, nil
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, eventID) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -150,7 +150,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
if err != nil { if err != nil {
return true, err return true, err
} }
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("UpdateRead: %d rows affected", nrows) log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
return nrows > 0, nil return nrows > 0, nil
} }

View file

@ -494,7 +494,6 @@ func Test_Notification(t *testing.T) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
// generate some dummy notifications // generate some dummy notifications
eventIDs := make([]string, 0, 10)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
eventID := util.RandomString(16) eventID := util.RandomString(16)
roomID := room.ID roomID := room.ID
@ -515,9 +514,8 @@ func Test_Notification(t *testing.T) {
RoomID: roomID, RoomID: roomID,
TS: gomatrixserverlib.AsTimestamp(ts), TS: gomatrixserverlib.AsTimestamp(ts),
} }
err = db.InsertNotification(ctx, aliceLocalpart, eventID, nil, notification) err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
assert.NoError(t, err, "unable to insert notification") assert.NoError(t, err, "unable to insert notification")
eventIDs = append(eventIDs, eventID)
} }
// get notifications // get notifications
@ -534,12 +532,12 @@ func Test_Notification(t *testing.T) {
assert.Equal(t, int64(4), total) assert.Equal(t, int64(4), total)
// mark notification as read // mark notification as read
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, eventIDs[6], true) affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
assert.NoError(t, err, "unable to set notifications read") assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected) assert.True(t, affected)
// this should delete 2 notifications // this should delete 2 notifications
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, eventIDs[7]) affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
assert.NoError(t, err, "unable to set notifications read") assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected) assert.True(t, affected)

View file

@ -105,9 +105,9 @@ type PusherTable interface {
type NotificationTable interface { type NotificationTable interface {
Clean(ctx context.Context, txn *sql.Tx) error Clean(ctx context.Context, txn *sql.Tx) error
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, highlight bool, n *api.Notification) error Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string) (affected bool, _ error) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, v bool) (affected bool, _ error) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)