diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index b3ed5cd03..806b7d328 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -37,8 +37,8 @@ type Notifier struct { streamLock *sync.Mutex // The latest sync position currPos types.StreamingToken - // A map of user_id => UserStream which can be used to wake a given user's /sync request. - userStreams map[string]*UserStream + // A map of user_id => device_id => UserStream which can be used to wake a given user's /sync request. + userDeviceStreams map[string]map[string]*UserDeviceStream // The last time we cleaned out stale entries from the userStreams map lastCleanUpTime time.Time } @@ -50,7 +50,7 @@ func NewNotifier(pos types.StreamingToken) *Notifier { return &Notifier{ currPos: pos, roomIDToJoinedUsers: make(map[string]userIDSet), - userStreams: make(map[string]*UserStream), + userDeviceStreams: make(map[string]map[string]*UserDeviceStream), streamLock: &sync.Mutex{}, lastCleanUpTime: time.Now(), } @@ -123,7 +123,7 @@ func (n *Notifier) OnNewEvent( // GetListener returns a UserStreamListener that can be used to wait for // updates for a user. Must be closed. // notify for anything before sincePos -func (n *Notifier) GetListener(req syncRequest) UserStreamListener { +func (n *Notifier) GetListener(req syncRequest) UserDeviceStreamListener { // Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298 // - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID // - Incoming events wake requests for a matching room ID @@ -137,7 +137,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener { n.removeEmptyUserStreams() - return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx) + return n.fetchUserDeviceStream(req.device.UserID, req.device.ID, true).GetListener(req.ctx) } // Load the membership states required to notify users correctly. @@ -175,27 +175,56 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { for _, userID := range userIDs { - stream := n.fetchUserStream(userID, false) - if stream != nil { + for _, stream := range n.fetchUserStreams(userID, false) { + if stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } + } + } +} + +func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) { + for userID, deviceID := range userDevices { + if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil { stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } } } -// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, +// fetchUserDeviceStream retrieves a stream unique to the given user. If makeIfNotExists is true, // a stream will be made for this user if one doesn't exist and it will be returned. This // function does not wait for data to be available on the stream. // NB: Callers should have locked the mutex before calling this function. -func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream { - stream, ok := n.userStreams[userID] +func (n *Notifier) fetchUserDeviceStream(userID, deviceID string, makeIfNotExists bool) *UserDeviceStream { + user, ok := n.userDeviceStreams[userID] + if !ok && makeIfNotExists { + n.userDeviceStreams[userID] = map[string]*UserDeviceStream{} + } + stream, ok := user[deviceID] if !ok && makeIfNotExists { // TODO: Unbounded growth of streams (1 per user) - stream = NewUserStream(userID, n.currPos) - n.userStreams[userID] = stream + stream = NewUserDeviceStream(userID, deviceID, n.currPos) + n.userDeviceStreams[userID][deviceID] = stream } return stream } +// fetchUserDeviceStreams retrieves all streams for the given user. If makeIfNotExists is true, +// a stream will be made for this user if one doesn't exist and it will be returned. This +// function does not wait for data to be available on the stream. +// NB: Callers should have locked the mutex before calling this function. +func (n *Notifier) fetchUserStreams(userID string, makeIfNotExists bool) []*UserDeviceStream { + user, ok := n.userDeviceStreams[userID] + if !ok && makeIfNotExists { + return []*UserDeviceStream{} + } + streams := []*UserDeviceStream{} + for _, stream := range user { + streams = append(streams, stream) + } + return streams +} + // Not thread-safe: must be called on the OnNewEvent goroutine only func (n *Notifier) addJoinedUser(roomID, userID string) { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { @@ -236,9 +265,14 @@ func (n *Notifier) removeEmptyUserStreams() { n.lastCleanUpTime = now deleteBefore := now.Add(-5 * time.Minute) - for key, value := range n.userStreams { - if value.TimeOfLastNonEmpty().Before(deleteBefore) { - delete(n.userStreams, key) + for user, byUser := range n.userDeviceStreams { + for device, value := range byUser { + if value.TimeOfLastNonEmpty().Before(deleteBefore) { + delete(n.userDeviceStreams[user], device) + } + if len(n.userDeviceStreams[user]) == 0 { + delete(n.userDeviceStreams, user) + } } } } diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 7d979fcc9..7cf9b8448 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -132,7 +132,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, "", bob) waitForBlocking(stream, 1) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) @@ -158,7 +158,7 @@ func TestNewInviteEventForUser(t *testing.T) { wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, "", bob) waitForBlocking(stream, 1) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter) @@ -184,7 +184,7 @@ func TestEDUWakeup(t *testing.T) { wg.Done() }() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, "", bob) waitForBlocking(stream, 1) n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU) @@ -213,7 +213,7 @@ func TestMultipleRequestWakeup(t *testing.T) { go poll() go poll() - stream := lockedFetchUserStream(n, bob) + stream := lockedFetchUserStream(n, "", bob) waitForBlocking(stream, 3) n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) @@ -247,14 +247,14 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { mustEqualPositions(t, pos, syncPositionAfter) leaveWG.Done() }() - bobStream := lockedFetchUserStream(n, bob) + bobStream := lockedFetchUserStream(n, "", bob) waitForBlocking(bobStream, 1) n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter) leaveWG.Wait() // send an event into the room. Make sure alice gets it. Bob should not. var aliceWG sync.WaitGroup - aliceStream := lockedFetchUserStream(n, alice) + aliceStream := lockedFetchUserStream(n, "", alice) aliceWG.Add(1) go func() { pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter)) @@ -300,7 +300,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) { } // Wait until something is Wait()ing on the user stream. -func waitForBlocking(s *UserStream, numBlocking uint) { +func waitForBlocking(s *UserDeviceStream, numBlocking uint) { for numBlocking != s.NumWaiting() { // This is horrible but I don't want to add a signalling mechanism JUST for testing. time.Sleep(1 * time.Microsecond) @@ -309,11 +309,11 @@ func waitForBlocking(s *UserStream, numBlocking uint) { // lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock. // A new stream is made if it doesn't exist already. -func lockedFetchUserStream(n *Notifier, userID string) *UserStream { +func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStream { n.streamLock.Lock() defer n.streamLock.Unlock() - return n.fetchUserStream(userID, true) + return n.fetchUserStream(userID, deviceID, true) } func newTestSyncRequest(userID string, since types.StreamingToken) syncRequest { diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go index b2eafa3dc..ff9a4d003 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/sync/userstream.go @@ -23,12 +23,13 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" ) -// UserStream represents a communication mechanism between the /sync request goroutine +// UserDeviceStream represents a communication mechanism between the /sync request goroutine // and the underlying sync server goroutines. // Goroutines can get a UserStreamListener to wait for updates, and can Broadcast() // updates. -type UserStream struct { - UserID string +type UserDeviceStream struct { + UserID string + DeviceID string // The lock that protects changes to this struct lock sync.Mutex // Closed when there is an update. @@ -41,18 +42,19 @@ type UserStream struct { numWaiting uint } -// UserStreamListener allows a sync request to wait for updates for a user. -type UserStreamListener struct { - userStream *UserStream +// UserDeviceStreamListener allows a sync request to wait for updates for a user. +type UserDeviceStreamListener struct { + userStream *UserDeviceStream // Whether the stream has been closed hasClosed bool } -// NewUserStream creates a new user stream -func NewUserStream(userID string, currPos types.StreamingToken) *UserStream { - return &UserStream{ +// NewUserDeviceStream creates a new user stream +func NewUserDeviceStream(userID, deviceID string, currPos types.StreamingToken) *UserDeviceStream { + return &UserDeviceStream{ UserID: userID, + DeviceID: deviceID, timeOfLastChannel: time.Now(), pos: currPos, signalChannel: make(chan struct{}), @@ -62,18 +64,18 @@ func NewUserStream(userID string, currPos types.StreamingToken) *UserStream { // GetListener returns UserStreamListener that a sync request can use to wait // for new updates with. // UserStreamListener must be closed -func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { +func (s *UserDeviceStream) GetListener(ctx context.Context) UserDeviceStreamListener { s.lock.Lock() defer s.lock.Unlock() s.numWaiting++ // We decrement when UserStreamListener is closed - listener := UserStreamListener{ + listener := UserDeviceStreamListener{ userStream: s, } // Lets be a bit paranoid here and check that Close() is being called - runtime.SetFinalizer(&listener, func(l *UserStreamListener) { + runtime.SetFinalizer(&listener, func(l *UserDeviceStreamListener) { if !l.hasClosed { l.Close() } @@ -83,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { } // Broadcast a new sync position for this user. -func (s *UserStream) Broadcast(pos types.StreamingToken) { +func (s *UserDeviceStream) Broadcast(pos types.StreamingToken) { s.lock.Lock() defer s.lock.Unlock() @@ -96,7 +98,7 @@ func (s *UserStream) Broadcast(pos types.StreamingToken) { // NumWaiting returns the number of goroutines waiting for waiting for updates. // Used for metrics and testing. -func (s *UserStream) NumWaiting() uint { +func (s *UserDeviceStream) NumWaiting() uint { s.lock.Lock() defer s.lock.Unlock() return s.numWaiting @@ -105,7 +107,7 @@ func (s *UserStream) NumWaiting() uint { // TimeOfLastNonEmpty returns the last time that the number of waiting listeners // was non-empty, may be time.Now() if number of waiting listeners is currently // non-empty. -func (s *UserStream) TimeOfLastNonEmpty() time.Time { +func (s *UserDeviceStream) TimeOfLastNonEmpty() time.Time { s.lock.Lock() defer s.lock.Unlock() @@ -118,7 +120,7 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time { // GetSyncPosition returns last sync position which the UserStream was // notified about -func (s *UserStreamListener) GetSyncPosition() types.StreamingToken { +func (s *UserDeviceStreamListener) GetSyncPosition() types.StreamingToken { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -130,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.StreamingToken { // sincePos specifies from which point we want to be notified about. If there // has already been an update after sincePos we'll return a closed channel // immediately. -func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { +func (s *UserDeviceStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -147,7 +149,7 @@ func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-c } // Close cleans up resources used -func (s *UserStreamListener) Close() { +func (s *UserDeviceStreamListener) Close() { s.userStream.lock.Lock() defer s.userStream.lock.Unlock()