Handle incoming send-to-device messages, count them with EDU stream pos

This commit is contained in:
Neil Alexander 2020-05-28 15:35:49 +01:00
parent d3bf9cb31b
commit 4d6347b21a
8 changed files with 58 additions and 7 deletions

View file

@ -109,6 +109,17 @@ func (t *EDUCache) AddTypingUser(
return t.GetLatestSyncPosition()
}
// AddSendToDeviceMessage increases the sync position for
// send-to-device updates.
// Returns the latest sync position for typing after update.
func (t *EDUCache) AddSendToDeviceMessage() int64 {
t.Lock()
defer t.Unlock()
t.latestSyncPosition++
return t.latestSyncPosition
}
// addUser with mutex lock & replace the previous timer.
// Returns the latest typing sync position after update.
func (t *EDUCache) addUser(

View file

@ -15,6 +15,7 @@
package consumers
import (
"context"
"encoding/json"
"github.com/Shopify/sarama"
@ -23,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
log "github.com/sirupsen/logrus"
)
@ -78,5 +80,17 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
"event_type": output.EventType,
}).Debug("received send-to-device event from EDU server")
newPos, err := s.db.StoreNewSendForDeviceMessage(context.TODO(), output.SendToDeviceEvent)
if err != nil {
log.WithError(err).Errorf("failed to store send-to-device message")
return err
}
s.notifier.OnNewSendToDevice(
output.UserID,
[]string{output.DeviceID}, // TODO: support wildcard here as per spec
types.NewStreamToken(0, newPos),
)
return nil
}

View file

@ -108,4 +108,6 @@ type Database interface {
// updates and deletions for previous events. The sync token should be supplied to this function so
// that we can clean up old events properly.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) ([]types.SendToDeviceEvent, error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
}

View file

@ -1038,6 +1038,18 @@ func (d *Database) AddSendToDeviceEvent(
)
}
func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, event gomatrixserverlib.SendToDeviceEvent,
) (types.StreamPosition, error) {
err := d.AddSendToDeviceEvent(
ctx, nil, event.UserID, event.DeviceID, event.EventType, string(event.Message),
)
if err != nil {
return 0, err
}
return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage()), nil
}
func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context,
userID, deviceID string,

View file

@ -120,6 +120,18 @@ func (n *Notifier) OnNewEvent(
}
}
func (n *Notifier) OnNewSendToDevice(
userID string, deviceIDs []string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUserDevice(userID, deviceIDs, latestPos)
}
// GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed.
// notify for anything before sincePos
@ -189,8 +201,8 @@ func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) {
// wakeupUserDevice will wake up the sync stream for a specific user device. Other
// device streams will be left alone.
// nolint:unused
func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) {
for userID, deviceID := range userDevices {
func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) {
for _, deviceID := range deviceIDs {
if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil {
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
}

View file

@ -172,7 +172,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
time.Sleep(1 * time.Second)
wake := "two"
n.wakeupUserDevice(map[string]string{alice: wake}, syncPositionAfter)
n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter)
if result := <-awoken; result != wake {
t.Fatalf("expected to wake %q, got %q", wake, result)

View file

@ -247,7 +247,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
err = fmt.Errorf("token %s is not a streaming token", tok)
return
}
if len(t.Positions) != 2 {
if len(t.Positions) != 3 {
err = fmt.Errorf("token %s wrong number of values, got %d want 2", tok, len(t.Positions))
return
}

View file

@ -4,9 +4,9 @@ import "testing"
func TestNewSyncTokenFromString(t *testing.T) {
shouldPass := map[string]syncToken{
"s4_0": NewStreamToken(4, 0).syncToken,
"s3_1": NewStreamToken(3, 1).syncToken,
"t3_1": NewTopologyToken(3, 1).syncToken,
"s4_0_0": NewStreamToken(4, 0).syncToken,
"s3_1_0": NewStreamToken(3, 1).syncToken,
"t3_1": NewTopologyToken(3, 1).syncToken,
}
shouldFail := []string{