diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index 88334b654..01e89b8a9 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -18,8 +18,6 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/internal" @@ -87,8 +85,8 @@ func (s *OutputReceiptEventConsumer) onMessage(msg *sarama.ConsumerMessage) erro if err != nil { return err } - // update stream position - s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) + + s.db.TypingStream().StreamAdvance(streamPos) return nil } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index db11837df..adf534cf4 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -34,6 +34,7 @@ type Database interface { PDUStream() types.StreamProvider PDUTopology() types.TopologyProvider TypingStream() types.StreamProvider + ReceiptStream() types.StreamProvider // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) diff --git a/syncapi/storage/shared/stream_pdu.go b/syncapi/storage/shared/stream_pdu.go index ed6631563..e2796376e 100644 --- a/syncapi/storage/shared/stream_pdu.go +++ b/syncapi/storage/shared/stream_pdu.go @@ -35,8 +35,10 @@ func (p *PDUStreamProvider) StreamAdvance( p.latestMutex.Lock() defer p.latestMutex.Unlock() - p.latest = latest - p.update.Broadcast() + if latest > p.latest { + p.latest = latest + p.update.Broadcast() + } } func (p *PDUStreamProvider) StreamRange( @@ -50,7 +52,7 @@ func (p *PDUStreamProvider) StreamRange( Backwards: from.IsAfter(to), } newPos = types.StreamingToken{ - PDUPosition: from.PDUPosition, + PDUPosition: to.PDUPosition, } var err error @@ -72,7 +74,7 @@ func (p *PDUStreamProvider) StreamRange( } for _, roomID := range joinedRooms { - req.Rooms[roomID] = "join" + req.Rooms[roomID] = gomatrixserverlib.Join } for _, stateDelta := range stateDeltas { @@ -110,7 +112,13 @@ func (p *PDUStreamProvider) StreamRange( gomatrixserverlib.FormatSync, ) - // TODO: fill in prev_batch + if len(events) > 0 { + prevBatch, err := p.DB.getBackwardTopologyPos(ctx, nil, events) + if err != nil { + return + } + room.Timeline.PrevBatch = &prevBatch + } req.Response.Rooms.Join[roomID] = room } diff --git a/syncapi/storage/shared/stream_receipt.go b/syncapi/storage/shared/stream_receipt.go new file mode 100644 index 000000000..edaafd173 --- /dev/null +++ b/syncapi/storage/shared/stream_receipt.go @@ -0,0 +1,166 @@ +package shared + +import ( + "context" + "encoding/json" + "sync" + + eduAPI "github.com/matrix-org/dendrite/eduserver/api" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type ReceiptStreamProvider struct { + DB *Database + latest types.StreamPosition + latestMutex sync.RWMutex + update *sync.Cond +} + +func (p *ReceiptStreamProvider) StreamSetup() { + locker := &sync.Mutex{} + p.update = sync.NewCond(locker) + + latest, err := p.DB.Receipts.SelectMaxReceiptID(context.Background(), nil) + if err != nil { + return + } + + p.latest = types.StreamPosition(latest) +} + +func (p *ReceiptStreamProvider) StreamAdvance( + latest types.StreamPosition, +) { + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + if latest > p.latest { + p.latest = latest + p.update.Broadcast() + } +} + +func (p *ReceiptStreamProvider) StreamRange( + ctx context.Context, + req *types.StreamRangeRequest, + from, to types.StreamingToken, +) types.StreamingToken { + var joinedRooms []string + for roomID, membership := range req.Rooms { + if membership == gomatrixserverlib.Join { + joinedRooms = append(joinedRooms, roomID) + } + } + + lastPos, receipts, err := p.DB.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRooms, from.ReceiptPosition) + if err != nil { + return types.StreamingToken{} //fmt.Errorf("unable to select receipts for rooms: %w", err) + } + + // Group receipts by room, so we can create one ClientEvent for every room + receiptsByRoom := make(map[string][]eduAPI.OutputReceiptEvent) + for _, receipt := range receipts { + receiptsByRoom[receipt.RoomID] = append(receiptsByRoom[receipt.RoomID], receipt) + } + + for roomID, receipts := range receiptsByRoom { + jr := req.Response.Rooms.Join[roomID] + var ok bool + + ev := gomatrixserverlib.ClientEvent{ + Type: gomatrixserverlib.MReceipt, + RoomID: roomID, + } + content := make(map[string]eduAPI.ReceiptMRead) + for _, receipt := range receipts { + var read eduAPI.ReceiptMRead + if read, ok = content[receipt.EventID]; !ok { + read = eduAPI.ReceiptMRead{ + User: make(map[string]eduAPI.ReceiptTS), + } + } + read.User[receipt.UserID] = eduAPI.ReceiptTS{TS: receipt.Timestamp} + content[receipt.EventID] = read + } + ev.Content, err = json.Marshal(content) + if err != nil { + return types.StreamingToken{} // err + } + + jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) + req.Response.Rooms.Join[roomID] = jr + } + + if lastPos > 0 { + return types.StreamingToken{ + ReceiptPosition: lastPos, + } + } else { + return types.StreamingToken{ + ReceiptPosition: to.ReceiptPosition, + } + } +} + +func (p *ReceiptStreamProvider) StreamNotifyAfter( + ctx context.Context, + from types.StreamingToken, +) chan struct{} { + ch := make(chan struct{}) + + check := func() bool { + p.latestMutex.RLock() + defer p.latestMutex.RUnlock() + if p.latest > from.ReceiptPosition { + close(ch) + return true + } + return false + } + + // If we've already advanced past the specified position + // then return straight away. + if check() { + return ch + } + + // If we haven't, then we'll subscribe to updates. The + // sync.Cond will fire every time the latest position + // updates, so we can check and see if we've advanced + // past it. + go func(p *ReceiptStreamProvider) { + p.update.L.Lock() + defer p.update.L.Unlock() + + for { + select { + case <-ctx.Done(): + // The context has expired, so there's no point + // in continuing to wait for the update. + return + default: + // The latest position has been advanced. Let's + // see if it's advanced to the position we care + // about. If it has then we'll return. + p.update.Wait() + if check() { + return + } + } + } + }(p) + + return ch +} + +func (p *ReceiptStreamProvider) StreamLatestPosition( + ctx context.Context, +) types.StreamingToken { + p.latestMutex.RLock() + defer p.latestMutex.RUnlock() + + return types.StreamingToken{ + ReceiptPosition: p.latest, + } +} diff --git a/syncapi/storage/shared/stream_typing.go b/syncapi/storage/shared/stream_typing.go index 5c28712ce..2f304176a 100644 --- a/syncapi/storage/shared/stream_typing.go +++ b/syncapi/storage/shared/stream_typing.go @@ -3,7 +3,6 @@ package shared import ( "context" "encoding/json" - "fmt" "sync" "github.com/matrix-org/dendrite/syncapi/types" @@ -28,8 +27,10 @@ func (p *TypingStreamProvider) StreamAdvance( p.latestMutex.Lock() defer p.latestMutex.Unlock() - p.latest = latest - p.update.Broadcast() + if latest > p.latest { + p.latest = latest + p.update.Broadcast() + } } func (p *TypingStreamProvider) StreamRange( @@ -56,12 +57,8 @@ func (p *TypingStreamProvider) StreamRange( return types.StreamingToken{} } - fmt.Println("Typing", roomID, "users", users) - jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) req.Response.Rooms.Join[roomID] = jr - } else { - fmt.Println("Typing", roomID, "not updated") } } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 89a078099..07e753544 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -51,9 +51,10 @@ type Database struct { Receipts tables.Receipts EDUCache *cache.EDUCache - PDUStreamProvider types.StreamProvider - PDUTopologyProvider types.TopologyProvider - TypingStreamProvider types.StreamProvider + PDUStreamProvider types.StreamProvider + PDUTopologyProvider types.TopologyProvider + TypingStreamProvider types.StreamProvider + ReceiptStreamProvider types.StreamProvider } // ConfigureProviders creates instances of the various @@ -62,9 +63,11 @@ type Database struct { func (d *Database) ConfigureProviders() { d.PDUStreamProvider = &PDUStreamProvider{DB: d} d.TypingStreamProvider = &TypingStreamProvider{DB: d} + d.ReceiptStreamProvider = &ReceiptStreamProvider{DB: d} d.PDUStreamProvider.StreamSetup() d.TypingStreamProvider.StreamSetup() + d.ReceiptStreamProvider.StreamSetup() d.PDUTopologyProvider = &PDUTopologyProvider{DB: d} } @@ -81,6 +84,10 @@ func (d *Database) TypingStream() types.StreamProvider { return d.TypingStreamProvider } +func (d *Database) ReceiptStream() types.StreamProvider { + return d.ReceiptStreamProvider +} + // Events lookups a list of event by their event ID. // Returns a list of events matching the requested IDs found in the database. // If an event is not found in the database then it will be omitted from the list. diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 032031193..79aa8b56f 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -70,7 +70,7 @@ func NewRequestPool( lastseen: sync.Map{}, pduStream: db.PDUStream(), typingStream: db.TypingStream(), - receiptStream: nil, // TODO + receiptStream: db.ReceiptStream(), sendToDeviceStream: nil, // TODO inviteStream: nil, // TODO deviceListStream: nil, // TODO @@ -188,7 +188,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. case <-rp.pduStream.StreamNotifyAfter(syncReq.ctx, syncReq.since): case <-rp.typingStream.StreamNotifyAfter(syncReq.ctx, syncReq.since): - // case <-rp.receiptStream.StreamNotifyAfter(syncReq.ctx, syncReq.since): + case <-rp.receiptStream.StreamNotifyAfter(syncReq.ctx, syncReq.since): // case <-rp.sendToDeviceStream.StreamNotifyAfter(syncReq.ctx, syncReq.since): // case <-rp.inviteStream.StreamNotifyAfter(syncReq.ctx, syncReq.since): // case <-rp.deviceListStream.StreamNotifyAfter(syncReq.ctx, syncReq.since): @@ -198,7 +198,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. var latest types.StreamingToken latest.ApplyUpdates(rp.pduStream.StreamLatestPosition(syncReq.ctx)) latest.ApplyUpdates(rp.typingStream.StreamLatestPosition(syncReq.ctx)) - // latest.ApplyUpdates(rp.receiptStream.StreamLatestPosition(syncReq.ctx)) + latest.ApplyUpdates(rp.receiptStream.StreamLatestPosition(syncReq.ctx)) // latest.ApplyUpdates(rp.sendToDeviceStream.StreamLatestPosition(syncReq.ctx)) // latest.ApplyUpdates(rp.inviteStream.StreamLatestPosition(syncReq.ctx)) // latest.ApplyUpdates(rp.deviceListStream.StreamLatestPosition(syncReq.ctx)) @@ -212,7 +212,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. sr.Response.NextBatch.ApplyUpdates(rp.pduStream.StreamRange(syncReq.ctx, sr, syncReq.since, latest)) sr.Response.NextBatch.ApplyUpdates(rp.typingStream.StreamRange(syncReq.ctx, sr, syncReq.since, latest)) - // sr.Response.NextBatch.ApplyUpdates(rp.receiptStream.StreamRange(syncReq.ctx, sr, syncReq.since, latest)) + sr.Response.NextBatch.ApplyUpdates(rp.receiptStream.StreamRange(syncReq.ctx, sr, syncReq.since, latest)) // sr.Response.NextBatch.ApplyUpdates(rp.sendToDeviceStream.StreamRange(syncReq.ctx, sr, syncReq.since, latest)) // sr.Response.NextBatch.ApplyUpdates(rp.inviteStream.StreamRange(syncReq.ctx, sr, syncReq.since, latest)) // sr.Response.NextBatch.ApplyUpdates(rp.inviteStream.StreamRange(syncReq.ctx, sr, syncReq.since, latest))