diff --git a/eduserver/cache/cache.go b/eduserver/cache/cache.go index 46f7a2b13..61e2cbc87 100644 --- a/eduserver/cache/cache.go +++ b/eduserver/cache/cache.go @@ -109,6 +109,17 @@ func (t *EDUCache) AddTypingUser( return t.GetLatestSyncPosition() } +// AddSendToDeviceMessage increases the sync position for +// send-to-device updates. +// Returns the latest sync position for typing after update. +func (t *EDUCache) AddSendToDeviceMessage() int64 { + t.Lock() + defer t.Unlock() + + t.latestSyncPosition++ + return t.latestSyncPosition +} + // addUser with mutex lock & replace the previous timer. // Returns the latest typing sync position after update. func (t *EDUCache) addUser( diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go index 0382c6083..08221a7dd 100644 --- a/syncapi/consumers/eduserver_sendtodevice.go +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -15,6 +15,7 @@ package consumers import ( + "context" "encoding/json" "github.com/Shopify/sarama" @@ -23,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/types" log "github.com/sirupsen/logrus" ) @@ -78,5 +80,17 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) "event_type": output.EventType, }).Debug("received send-to-device event from EDU server") + newPos, err := s.db.StoreNewSendForDeviceMessage(context.TODO(), output.SendToDeviceEvent) + if err != nil { + log.WithError(err).Errorf("failed to store send-to-device message") + return err + } + + s.notifier.OnNewSendToDevice( + output.UserID, + []string{output.DeviceID}, // TODO: support wildcard here as per spec + types.NewStreamToken(0, newPos), + ) + return nil } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index b6e2b195b..bfe1c53cf 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -108,4 +108,6 @@ type Database interface { // updates and deletions for previous events. The sync token should be supplied to this function so // that we can clean up old events properly. SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, error) + // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. + StoreNewSendForDeviceMessage(ctx context.Context, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index a989cc5a9..023fe8bf9 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1038,6 +1038,18 @@ func (d *Database) AddSendToDeviceEvent( ) } +func (d *Database) StoreNewSendForDeviceMessage( + ctx context.Context, event gomatrixserverlib.SendToDeviceEvent, +) (types.StreamPosition, error) { + err := d.AddSendToDeviceEvent( + ctx, nil, event.UserID, event.DeviceID, event.EventType, string(event.Message), + ) + if err != nil { + return 0, err + } + return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage()), nil +} + func (d *Database) SendToDeviceUpdatesForSync( ctx context.Context, userID, deviceID string, diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 9b410a0c4..325e75351 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -120,6 +120,18 @@ func (n *Notifier) OnNewEvent( } } +func (n *Notifier) OnNewSendToDevice( + userID string, deviceIDs []string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + latestPos := n.currPos.WithUpdates(posUpdate) + n.currPos = latestPos + + n.wakeupUserDevice(userID, deviceIDs, latestPos) +} + // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos @@ -189,8 +201,8 @@ func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { // wakeupUserDevice will wake up the sync stream for a specific user device. Other // device streams will be left alone. // nolint:unused -func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) { - for userID, deviceID := range userDevices { +func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) { + for _, deviceID := range deviceIDs { if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil { stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 14ddef20a..132315573 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -172,7 +172,7 @@ func TestCorrectStreamWakeup(t *testing.T) { time.Sleep(1 * time.Second) wake := "two" - n.wakeupUserDevice(map[string]string{alice: wake}, syncPositionAfter) + n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter) if result := <-awoken; result != wake { t.Fatalf("expected to wake %q, got %q", wake, result) diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 3d9d31553..c3af31994 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -247,7 +247,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { err = fmt.Errorf("token %s is not a streaming token", tok) return } - if len(t.Positions) != 2 { + if len(t.Positions) != 3 { err = fmt.Errorf("token %s wrong number of values, got %d want 2", tok, len(t.Positions)) return } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 1e27a8e32..f3d4f6b7e 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -4,9 +4,9 @@ import "testing" func TestNewSyncTokenFromString(t *testing.T) { shouldPass := map[string]syncToken{ - "s4_0": NewStreamToken(4, 0).syncToken, - "s3_1": NewStreamToken(3, 1).syncToken, - "t3_1": NewTopologyToken(3, 1).syncToken, + "s4_0_0": NewStreamToken(4, 0).syncToken, + "s3_1_0": NewStreamToken(3, 1).syncToken, + "t3_1": NewTopologyToken(3, 1).syncToken, } shouldFail := []string{