Refactor notifications

This commit is contained in:
Till Faelligen 2022-08-31 17:19:14 +02:00
parent a84b273955
commit e4f0d5bd92
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
18 changed files with 135 additions and 250 deletions

View file

@ -16,7 +16,6 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -113,7 +112,7 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return false return false
} }
if err = s.sendReadUpdate(ctx, userID, output); err != nil { if err = s.sendReadUpdate(userID, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{ log.WithError(err).WithFields(logrus.Fields{
"user_id": userID, "user_id": userID,
"room_id": output.RoomID, "room_id": output.RoomID,
@ -137,7 +136,7 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M
return true return true
} }
func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error { func (s *OutputClientDataConsumer) sendReadUpdate(userID string, output eventutil.AccountData) error {
if output.Type != "m.fully_read" || output.ReadMarker == nil { if output.Type != "m.fully_read" || output.ReadMarker == nil {
return nil return nil
} }
@ -148,20 +147,8 @@ func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID st
if serverName != s.serverName { if serverName != s.serverName {
return nil return nil
} }
var readPos types.StreamPosition if output.ReadMarker.Read != "" || output.ReadMarker.FullyRead != "" {
var fullyReadPos types.StreamPosition if err := s.producer.SendReadUpdate(userID, output.RoomID, output.ReadMarker.Read, output.ReadMarker.FullyRead); err != nil {
if output.ReadMarker.Read != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if output.ReadMarker.FullyRead != "" {
if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
}
}
if readPos > 0 || fullyReadPos > 0 {
if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err) return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
} }
} }

View file

@ -16,11 +16,15 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"strconv" "strconv"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
@ -28,10 +32,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
) )
// OutputReceiptEventConsumer consumes events that originated in the EDU server. // OutputReceiptEventConsumer consumes events that originated in the EDU server.
@ -137,14 +137,8 @@ func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output
if serverName != s.serverName { if serverName != s.serverName {
return nil return nil
} }
var readPos types.StreamPosition
if output.EventID != "" { if output.EventID != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows { if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, output.EventID, ""); err != nil {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if readPos > 0 {
if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err) return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
} }
} }

View file

@ -21,17 +21,17 @@ import (
"fmt" "fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
) )
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
@ -46,7 +46,6 @@ type OutputRoomEventConsumer struct {
pduStream types.StreamProvider pduStream types.StreamProvider
inviteStream types.StreamProvider inviteStream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
producer *producers.UserAPIStreamEventProducer
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -59,7 +58,6 @@ func NewOutputRoomEventConsumer(
pduStream types.StreamProvider, pduStream types.StreamProvider,
inviteStream types.StreamProvider, inviteStream types.StreamProvider,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
producer *producers.UserAPIStreamEventProducer,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -72,7 +70,6 @@ func NewOutputRoomEventConsumer(
pduStream: pduStream, pduStream: pduStream,
inviteStream: inviteStream, inviteStream: inviteStream,
rsAPI: rsAPI, rsAPI: rsAPI,
producer: producer,
} }
} }
@ -255,12 +252,6 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
return nil return nil
} }
if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID())
sentry.CaptureException(err)
return err
}
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
sentry.CaptureException(err) sentry.CaptureException(err)

View file

@ -19,6 +19,9 @@ import (
"encoding/json" "encoding/json"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
@ -26,8 +29,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
) )
// OutputNotificationDataConsumer consumes events that originated in // OutputNotificationDataConsumer consumes events that originated in

View file

@ -15,12 +15,10 @@
package producers package producers
import ( import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/setup/jetstream"
) )
// UserAPIProducer produces events for the user API server to consume // UserAPIProducer produces events for the user API server to consume
@ -30,25 +28,12 @@ type UserAPIReadProducer struct {
} }
// SendData sends account data to the user API server // SendData sends account data to the user API server
func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error { func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos string) error {
m := &nats.Msg{ m := nats.NewMsg(p.Topic)
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID) m.Header.Set(jetstream.UserID, userID)
m.Header.Set(jetstream.RoomID, roomID) m.Header.Set(jetstream.RoomID, roomID)
m.Header.Set("read", readPos)
data := types.ReadUpdate{ m.Header.Set("fully_read", fullyReadPos)
UserID: userID,
RoomID: roomID,
Read: readPos,
FullyRead: fullyReadPos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"user_id": userID, "user_id": userID,
@ -57,6 +42,6 @@ func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, ful
"fully_read_pos": fullyReadPos, "fully_read_pos": fullyReadPos,
}).Tracef("Producing to topic '%s'", p.Topic) }).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m) _, err := p.JetStream.PublishMsg(m)
return err return err
} }

View file

@ -1,60 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIStreamEventProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.RoomID, roomID)
data := types.StreamedEvent{
Event: event,
StreamPosition: pos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"room_id": roomID,
"event_id": event.EventID(),
"event_type": event.Type(),
"stream_pos": pos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -30,9 +30,10 @@ func (p *NotificationDataStreamProvider) CompleteSync(
func (p *NotificationDataStreamProvider) IncrementalSync( func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, _ types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// We want counts for all possible rooms, so always start from zero. // We want counts for all possible rooms, so always start from zero.
to := p.LatestPosition(ctx)
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to)
if err != nil { if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed")
@ -41,14 +42,14 @@ func (p *NotificationDataStreamProvider) IncrementalSync(
// Add notification data to rooms. // Add notification data to rooms.
// Create an empty JoinResponse if the room isn't in this response // Create an empty JoinResponse if the room isn't in this response
for roomID, notificationData := range countsByRoom { for roomID, jr := range req.Response.Rooms.Join {
jr, exists := req.Response.Rooms.Join[roomID] counts := countsByRoom[roomID]
if !exists { if counts == nil {
jr = *types.NewJoinResponse() continue
} }
jr.UnreadNotifications = &types.UnreadNotifications{ jr.UnreadNotifications = &types.UnreadNotifications{
HighlightCount: notificationData.UnreadHighlightCount, HighlightCount: counts.UnreadHighlightCount,
NotificationCount: notificationData.UnreadNotificationCount, NotificationCount: counts.UnreadNotificationCount,
} }
req.Response.Rooms.Join[roomID] = jr req.Response.Rooms.Join[roomID] = jr
} }

View file

@ -17,9 +17,10 @@ package syncapi
import ( import (
"context" "context"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/caching"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
@ -76,11 +77,6 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start presence consumer") logrus.WithError(err).Panicf("failed to start presence consumer")
} }
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent),
}
userAPIReadUpdateProducer := &producers.UserAPIReadProducer{ userAPIReadUpdateProducer := &producers.UserAPIReadProducer{
JetStream: js, JetStream: js,
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate), Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate),
@ -97,7 +93,7 @@ func AddPublicRoutes(
roomConsumer := consumers.NewOutputRoomEventConsumer( roomConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider,
streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, streams.InviteStreamProvider, rsAPI,
) )
if err = roomConsumer.Start(); err != nil { if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer") logrus.WithError(err).Panicf("failed to start room server consumer")

View file

@ -505,19 +505,6 @@ type Peek struct {
Deleted bool Deleted bool
} }
type ReadUpdate struct {
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
Read StreamPosition `json:"read,omitempty"`
FullyRead StreamPosition `json:"fully_read,omitempty"`
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamedEvent struct {
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
StreamPosition StreamPosition `json:"stream_position"`
}
// OutputReceiptEvent is an entry in the receipt output kafka log // OutputReceiptEvent is an entry in the receipt output kafka log
type OutputReceiptEvent struct { type OutputReceiptEvent struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`

View file

@ -26,7 +26,7 @@ import (
"github.com/matrix-org/dendrite/userapi/util" "github.com/matrix-org/dendrite/userapi/util"
) )
type OutputStreamEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context ctx context.Context
cfg *config.UserAPI cfg *config.UserAPI
userAPI api.UserInternalAPI userAPI api.UserInternalAPI
@ -39,7 +39,7 @@ type OutputStreamEventConsumer struct {
syncProducer *producers.SyncAPI syncProducer *producers.SyncAPI
} }
func NewOutputStreamEventConsumer( func NewOutputRoomEventConsumer(
process *process.ProcessContext, process *process.ProcessContext,
cfg *config.UserAPI, cfg *config.UserAPI,
js nats.JetStreamContext, js nats.JetStreamContext,
@ -48,14 +48,14 @@ func NewOutputStreamEventConsumer(
userAPI api.UserInternalAPI, userAPI api.UserInternalAPI,
rsAPI rsapi.UserRoomserverAPI, rsAPI rsapi.UserRoomserverAPI,
syncProducer *producers.SyncAPI, syncProducer *producers.SyncAPI,
) *OutputStreamEventConsumer { ) *OutputRoomEventConsumer {
return &OutputStreamEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
cfg: cfg, cfg: cfg,
jetstream: js, jetstream: js,
db: store, db: store,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), durable: cfg.Matrix.JetStream.Durable("UserAPIRoomServerConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent),
pgClient: pgClient, pgClient: pgClient,
userAPI: userAPI, userAPI: userAPI,
rsAPI: rsAPI, rsAPI: rsAPI,
@ -63,7 +63,7 @@ func NewOutputStreamEventConsumer(
} }
} }
func (s *OutputStreamEventConsumer) Start() error { func (s *OutputRoomEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer( if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, 1, s.ctx, s.jetstream, s.topic, s.durable, 1,
s.onMessage, nats.DeliverAll(), nats.ManualAck(), s.onMessage, nats.DeliverAll(), nats.ManualAck(),
@ -73,35 +73,38 @@ func (s *OutputStreamEventConsumer) Start() error {
return nil return nil
} }
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called msg := msgs[0] // Guaranteed to exist if onMessage is called
var output types.StreamedEvent var output rsapi.OutputEvent
output.Event = &gomatrixserverlib.HeaderedEvent{}
if err := json.Unmarshal(msg.Data, &output); err != nil { if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("userapi consumer: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return true return true
} }
if output.Event.Event == nil { if output.Type != rsapi.OutputTypeNewRoomEvent {
return true
}
event := output.NewRoomEvent.Event
if event == nil {
log.Errorf("userapi consumer: expected event") log.Errorf("userapi consumer: expected event")
return true return true
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": output.Event.EventID(), "event_id": event.EventID(),
"event_type": output.Event.Type(), "event_type": event.Type(),
"stream_pos": output.StreamPosition, }).Tracef("Received message from roomserver: %#v", output)
}).Tracef("Received message from sync API: %#v", output)
if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { if err := s.processMessage(ctx, event); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": output.Event.EventID(), "event_id": event.EventID(),
}).WithError(err).Errorf("userapi consumer: process room event failure") }).WithError(err).Errorf("userapi consumer: process room event failure")
} }
return true return true
} }
func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
if err != nil { if err != nil {
return fmt.Errorf("s.localRoomMembers: %w", err) return fmt.Errorf("s.localRoomMembers: %w", err)
@ -132,7 +135,7 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g
"room_id": event.RoomID(), "room_id": event.RoomID(),
"num_members": len(members), "num_members": len(members),
"room_size": roomSize, "room_size": roomSize,
}).Tracef("Notifying members") }).Debugf("Notifying members")
// Notification.UserIsTarget is a per-member field, so we // Notification.UserIsTarget is a per-member field, so we
// cannot group all users in a single request. // cannot group all users in a single request.
@ -141,7 +144,7 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g
// removing it means we can send all notifications to // removing it means we can send all notifications to
// e.g. Element's Push gateway in one go. // e.g. Element's Push gateway in one go.
for _, mem := range members { for _, mem := range members {
if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { if err := s.notifyLocal(ctx, event, mem, roomSize, roomName); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": mem.Localpart, "localpart": mem.Localpart,
}).WithError(err).Debugf("Unable to push to local user") }).WithError(err).Debugf("Unable to push to local user")
@ -182,7 +185,7 @@ func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership,
// localRoomMembers fetches the current local members of a room, and // localRoomMembers fetches the current local members of a room, and
// the total number of members. // the total number of members.
func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
req := &rsapi.QueryMembershipsForRoomRequest{ req := &rsapi.QueryMembershipsForRoomRequest{
RoomID: roomID, RoomID: roomID,
JoinedOnly: true, JoinedOnly: true,
@ -222,7 +225,7 @@ func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID
// looks it up in roomserver. If there is no name, // looks it up in roomserver. If there is no name,
// m.room.canonical_alias is consulted. Returns an empty string if the // m.room.canonical_alias is consulted. Returns an empty string if the
// room has no name. // room has no name.
func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
if event.Type() == gomatrixserverlib.MRoomName { if event.Type() == gomatrixserverlib.MRoomName {
name, err := unmarshalRoomName(event) name, err := unmarshalRoomName(event)
if err != nil { if err != nil {
@ -290,7 +293,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
} }
// notifyLocal finds the right push actions for a local user, given an event. // notifyLocal finds the right push actions for a local user, given an event.
func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil { if err != nil {
return err return err
@ -328,7 +331,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
RoomID: event.RoomID(), RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()), TS: gomatrixserverlib.AsTimestamp(time.Now()),
} }
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), tweaks, n); err != nil {
return err return err
} }
@ -348,7 +351,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
"localpart": mem.Localpart, "localpart": mem.Localpart,
"num_urls": len(devicesByURLAndFormat), "num_urls": len(devicesByURLAndFormat),
"num_unread": userNumUnreadNotifs, "num_unread": userNumUnreadNotifs,
}).Tracef("Notifying single member") }).Debugf("Notifying single member")
// Push gateways are out of our control, and we cannot risk // Push gateways are out of our control, and we cannot risk
// looking up the server on a misbehaving push gateway. Each user // looking up the server on a misbehaving push gateway. Each user
@ -399,7 +402,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma
// evaluatePushRules fetches and evaluates the push rules of a local // evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify). // user. Returns actions (including dont_notify).
func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
if event.Sender() == mem.UserID { if event.Sender() == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for // SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves. // events that the user has sent themselves.
@ -450,7 +453,7 @@ func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event
"room_id": event.RoomID(), "room_id": event.RoomID(),
"localpart": mem.Localpart, "localpart": mem.Localpart,
"rule_id": rule.RuleID, "rule_id": rule.RuleID,
}).Tracef("Matched a push rule") }).Debugf("Matched a push rule")
return rule.Actions, nil return rule.Actions, nil
} }
@ -494,7 +497,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices pushes to the configured devices of a local // localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format]. // user. The map keys are [url][format].
func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@ -518,7 +521,7 @@ func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localp
} }
// notifyHTTP performs a notificatation to a Push Gateway. // notifyHTTP performs a notificatation to a Push Gateway.
func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"event_id": event.EventID(), "event_id": event.EventID(),
"url": url, "url": url,
@ -570,7 +573,7 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
logger.WithError(err).Errorf("Failed to notify push gateway %s", url) logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
return nil, err return nil, err
} }
logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") logger.WithField("num_rejected", len(res.Rejected)).Debugf("Push gateway result")
if len(res.Rejected) == 0 { if len(res.Rejected) == 0 {
return nil, nil return nil, nil
@ -592,7 +595,7 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat
} }
// deleteRejectedPushers deletes the pushers associated with the given devices. // deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": localpart, "localpart": localpart,
"app_id0": devices[0].AppID, "app_id0": devices[0].AppID,

View file

@ -2,20 +2,19 @@ package consumers
import ( import (
"context" "context"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
uapi "github.com/matrix-org/dendrite/userapi/api" uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/util" "github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
) )
type OutputReadUpdateConsumer struct { type OutputReadUpdateConsumer struct {
@ -66,17 +65,14 @@ func (s *OutputReadUpdateConsumer) Start() error {
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called msg := msgs[0] // Guaranteed to exist if onMessage is called
var read types.ReadUpdate userID := msg.Header.Get(jetstream.UserID)
if err := json.Unmarshal(msg.Data, &read); err != nil { roomID := msg.Header.Get(jetstream.RoomID)
log.WithError(err).Error("userapi clientapi consumer: message parse failure") readPos := msg.Header.Get("read")
return true fullyReadPos := msg.Header.Get("fully_read")
}
if read.FullyRead == 0 && read.Read == 0 {
return true
}
userID := string(msg.Header.Get(jetstream.UserID)) if readPos == "" && fullyReadPos == "" {
roomID := string(msg.Header.Get(jetstream.RoomID)) return true
}
localpart, domain, err := gomatrixserverlib.SplitID('@', userID) localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
@ -92,10 +88,9 @@ func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
"room_id": roomID, "room_id": roomID,
"user_id": userID, "user_id": userID,
}) })
log.Tracef("Received read update from sync API: %#v", read)
if read.Read > 0 { if readPos != "" {
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true) updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, readPos, true)
if err != nil { if err != nil {
log.WithError(err).Error("userapi EDU consumer") log.WithError(err).Error("userapi EDU consumer")
return false return false
@ -113,8 +108,8 @@ func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
} }
} }
if read.FullyRead > 0 { if fullyReadPos != "" {
deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead)) deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, fullyReadPos)
if err != nil { if err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed") log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
return false return false
@ -126,7 +121,7 @@ func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
return false return false
} }
if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil { if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed") log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
return false return false
} }

View file

@ -117,9 +117,9 @@ type ThreePID interface {
} }
type Notification interface { type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error) SetNotificationsRead(ctx context.Context, localpart, roomID, eventID string, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@ -58,13 +58,13 @@ CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_noti
` `
const insertNotificationSQL = "" + const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, 0, $4, $5, $6)"
const deleteNotificationsUpToSQL = "" + const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $3)"
const updateNotificationReadSQL = "" + const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $4) AND read <> $1"
const selectNotificationSQL = "" + const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
@ -111,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -122,13 +122,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil { if err != nil {
return err return err
} }
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, tsMS, highlight, string(bs))
return err return err
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, eventID)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -136,13 +136,13 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
if err != nil { if err != nil {
return true, err return true, err
} }
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("DeleteUpTo: %d rows affected", nrows)
return nrows > 0, nil return nrows > 0, nil
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, eventID)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -150,7 +150,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
if err != nil { if err != nil {
return true, err return true, err
} }
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("UpdateRead: %d rows affected", nrows)
return nrows > 0, nil return nrows > 0, nil
} }

View file

@ -26,10 +26,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -664,23 +665,23 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
return d.LoginTokens.SelectLoginToken(ctx, token) return d.LoginTokens.SelectLoginToken(ctx, token)
} }
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) return d.Notifications.Insert(ctx, txn, localpart, eventID, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
}) })
} }
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID, eventID string) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, eventID)
return err return err
}) })
return return
} }
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID, eventID string, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, eventID, b)
return err return err
}) })
return return

View file

@ -58,13 +58,13 @@ CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_noti
` `
const insertNotificationSQL = "" + const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, 0, $4, $5, $6)"
const deleteNotificationsUpToSQL = "" + const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $3)"
const updateNotificationReadSQL = "" + const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND id <= (SELECT id FROM userapi_notifications WHERE event_id = $4) AND read <> $1"
const selectNotificationSQL = "" + const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
@ -111,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -122,13 +122,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil { if err != nil {
return err return err
} }
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, tsMS, highlight, string(bs))
return err return err
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, eventID)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -136,13 +136,13 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
if err != nil { if err != nil {
return true, err return true, err
} }
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("DeleteUpTo: %d rows affected", nrows)
return nrows > 0, nil return nrows > 0, nil
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, eventID)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -150,7 +150,7 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
if err != nil { if err != nil {
return true, err return true, err
} }
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "event_id": eventID}).Tracef("UpdateRead: %d rows affected", nrows)
return nrows > 0, nil return nrows > 0, nil
} }

View file

@ -7,6 +7,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -14,10 +19,6 @@ import (
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
) )
const loginTokenLifetime = time.Minute const loginTokenLifetime = time.Minute
@ -493,6 +494,7 @@ func Test_Notification(t *testing.T) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
// generate some dummy notifications // generate some dummy notifications
eventIDs := make([]string, 0, 10)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
eventID := util.RandomString(16) eventID := util.RandomString(16)
roomID := room.ID roomID := room.ID
@ -513,8 +515,9 @@ func Test_Notification(t *testing.T) {
RoomID: roomID, RoomID: roomID,
TS: gomatrixserverlib.AsTimestamp(ts), TS: gomatrixserverlib.AsTimestamp(ts),
} }
err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification) err = db.InsertNotification(ctx, aliceLocalpart, eventID, nil, notification)
assert.NoError(t, err, "unable to insert notification") assert.NoError(t, err, "unable to insert notification")
eventIDs = append(eventIDs, eventID)
} }
// get notifications // get notifications
@ -531,12 +534,12 @@ func Test_Notification(t *testing.T) {
assert.Equal(t, int64(4), total) assert.Equal(t, int64(4), total)
// mark notification as read // mark notification as read
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true) affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, eventIDs[6], true)
assert.NoError(t, err, "unable to set notifications read") assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected) assert.True(t, affected)
// this should delete 2 notifications // this should delete 2 notifications
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8) affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, eventIDs[7])
assert.NoError(t, err, "unable to set notifications read") assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected) assert.True(t, affected)

View file

@ -105,9 +105,9 @@ type PusherTable interface {
type NotificationTable interface { type NotificationTable interface {
Clean(ctx context.Context, txn *sql.Tx) error Clean(ctx context.Context, txn *sql.Tx) error
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)

View file

@ -18,6 +18,8 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
@ -31,7 +33,6 @@ import (
"github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/util" "github.com/matrix-org/dendrite/userapi/util"
"github.com/sirupsen/logrus"
) )
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
@ -89,7 +90,7 @@ func NewInternalAPI(
logrus.WithError(err).Panic("failed to start user API read update consumer") logrus.WithError(err).Panic("failed to start user API read update consumer")
} }
eventConsumer := consumers.NewOutputStreamEventConsumer( eventConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer, base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer,
) )
if err := eventConsumer.Start(); err != nil { if err := eventConsumer.Start(); err != nil {