diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index b6d3ab391..4939aaf31 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -22,8 +22,10 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" + "github.com/matrix-org/dendrite/syncapi/types" log "github.com/sirupsen/logrus" ) @@ -32,6 +34,7 @@ type OutputClientDataConsumer struct { clientAPIConsumer *internal.ContinualConsumer db storage.Database streams *streams.Streams + notifier *notifier.Notifier } // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. @@ -39,6 +42,7 @@ func NewOutputClientDataConsumer( cfg *config.SyncAPI, kafkaConsumer sarama.Consumer, store storage.Database, + notifier *notifier.Notifier, streams *streams.Streams, ) *OutputClientDataConsumer { @@ -51,6 +55,7 @@ func NewOutputClientDataConsumer( s := &OutputClientDataConsumer{ clientAPIConsumer: &consumer, db: store, + notifier: notifier, streams: streams, } consumer.ProcessMessage = s.onMessage @@ -92,6 +97,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error } s.streams.AccountDataStreamProvider.Advance(pduPos) + s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.StreamingToken{AccountDataPosition: pduPos}) return nil } diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index f52ea4360..bd3ec9793 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -22,8 +22,10 @@ import ( "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" + "github.com/matrix-org/dendrite/syncapi/types" log "github.com/sirupsen/logrus" ) @@ -32,6 +34,7 @@ type OutputReceiptEventConsumer struct { receiptConsumer *internal.ContinualConsumer db storage.Database streams *streams.Streams + notifier *notifier.Notifier } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -40,6 +43,7 @@ func NewOutputReceiptEventConsumer( cfg *config.SyncAPI, kafkaConsumer sarama.Consumer, store storage.Database, + notifier *notifier.Notifier, streams *streams.Streams, ) *OutputReceiptEventConsumer { @@ -53,6 +57,7 @@ func NewOutputReceiptEventConsumer( s := &OutputReceiptEventConsumer{ receiptConsumer: &consumer, db: store, + notifier: notifier, streams: streams, } @@ -87,6 +92,7 @@ func (s *OutputReceiptEventConsumer) onMessage(msg *sarama.ConsumerMessage) erro } s.streams.ReceiptStreamProvider.Advance(streamPos) + s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) return nil } diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go index 10b31f10a..a22b66795 100644 --- a/syncapi/consumers/eduserver_sendtodevice.go +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -22,8 +22,10 @@ import ( "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -35,6 +37,7 @@ type OutputSendToDeviceEventConsumer struct { db storage.Database serverName gomatrixserverlib.ServerName // our server name streams *streams.Streams + notifier *notifier.Notifier } // NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer. @@ -43,6 +46,7 @@ func NewOutputSendToDeviceEventConsumer( cfg *config.SyncAPI, kafkaConsumer sarama.Consumer, store storage.Database, + notifier *notifier.Notifier, streams *streams.Streams, ) *OutputSendToDeviceEventConsumer { @@ -57,6 +61,7 @@ func NewOutputSendToDeviceEventConsumer( sendToDeviceConsumer: &consumer, db: store, serverName: cfg.Matrix.ServerName, + notifier: notifier, streams: streams, } @@ -102,6 +107,11 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) } s.streams.SendToDeviceStreamProvider.Advance(streamPos) + s.notifier.OnNewSendToDevice( + output.UserID, + []string{output.DeviceID}, + types.StreamingToken{SendToDevicePosition: streamPos}, + ) return nil } diff --git a/syncapi/consumers/eduserver_typing.go b/syncapi/consumers/eduserver_typing.go index ead6c7a88..a8a67d653 100644 --- a/syncapi/consumers/eduserver_typing.go +++ b/syncapi/consumers/eduserver_typing.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" @@ -33,6 +34,7 @@ type OutputTypingEventConsumer struct { typingConsumer *internal.ContinualConsumer eduCache *cache.EDUCache streams *streams.Streams + notifier *notifier.Notifier } // NewOutputTypingEventConsumer creates a new OutputTypingEventConsumer. @@ -42,6 +44,7 @@ func NewOutputTypingEventConsumer( kafkaConsumer sarama.Consumer, store storage.Database, eduCache *cache.EDUCache, + notifier *notifier.Notifier, streams *streams.Streams, ) *OutputTypingEventConsumer { @@ -55,6 +58,7 @@ func NewOutputTypingEventConsumer( s := &OutputTypingEventConsumer{ typingConsumer: &consumer, eduCache: eduCache, + notifier: notifier, streams: streams, } @@ -66,7 +70,8 @@ func NewOutputTypingEventConsumer( // Start consuming from EDU api func (s *OutputTypingEventConsumer) Start() error { s.eduCache.SetTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { - s.streams.TypingStreamProvider.Advance(types.StreamPosition(latestSyncPosition)) + pos := types.StreamPosition(latestSyncPosition) + s.notifier.OnNewTyping(roomID, types.StreamingToken{TypingPosition: pos}) }) return s.typingConsumer.Start() } @@ -98,6 +103,7 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error } s.streams.TypingStreamProvider.Advance(typingPos) + s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) return nil } diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 75b2ba721..919b7c94d 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" @@ -34,6 +35,7 @@ import ( type OutputKeyChangeEventConsumer struct { keyChangeConsumer *internal.ContinualConsumer db storage.Database + notifier *notifier.Notifier streams *streams.Streams serverName gomatrixserverlib.ServerName // our server name rsAPI roomserverAPI.RoomserverInternalAPI @@ -51,6 +53,7 @@ func NewOutputKeyChangeEventConsumer( keyAPI api.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, store storage.Database, + notifier *notifier.Notifier, streams *streams.Streams, ) *OutputKeyChangeEventConsumer { @@ -69,6 +72,7 @@ func NewOutputKeyChangeEventConsumer( rsAPI: rsAPI, partitionToOffset: make(map[int32]int64), partitionToOffsetMu: sync.Mutex{}, + notifier: notifier, streams: streams, } @@ -120,6 +124,9 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er } s.streams.DeviceListStreamProvider.Advance(posUpdate) + for userID := range queryRes.UserIDsToCount { + s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID) + } return nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index ac63f0b0c..69d0d8f65 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" @@ -37,6 +38,7 @@ type OutputRoomEventConsumer struct { rsConsumer *internal.ContinualConsumer db storage.Database streams *streams.Streams + notifier *notifier.Notifier } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -44,6 +46,7 @@ func NewOutputRoomEventConsumer( cfg *config.SyncAPI, kafkaConsumer sarama.Consumer, store storage.Database, + notifier *notifier.Notifier, streams *streams.Streams, rsAPI api.RoomserverInternalAPI, ) *OutputRoomEventConsumer { @@ -58,6 +61,7 @@ func NewOutputRoomEventConsumer( cfg: cfg, rsConsumer: &consumer, db: store, + notifier: notifier, streams: streams, rsAPI: rsAPI, } @@ -181,6 +185,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } s.streams.PDUStreamProvider.Advance(pduPos) + s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos}) return nil } @@ -220,6 +225,7 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( } s.streams.PDUStreamProvider.Advance(pduPos) + s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos}) return nil } @@ -276,6 +282,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( } s.streams.InviteStreamProvider.Advance(pduPos) + s.notifier.OnNewInvite(types.StreamingToken{PDUPosition: pduPos}, *msg.Event.StateKey()) return nil } @@ -296,6 +303,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( // Notify any active sync requests that the invite has been retired. // Invites share the same stream counter as PDUs s.streams.InviteStreamProvider.Advance(pduPos) + s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID) return nil } @@ -316,6 +324,7 @@ func (s *OutputRoomEventConsumer) onNewPeek( // TODO: This only works because the peeks table is reusing the same // index as PDUs, but we should fix this s.streams.PDUStreamProvider.Advance(sp) + s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.StreamingToken{PDUPosition: sp}) return nil } @@ -336,6 +345,7 @@ func (s *OutputRoomEventConsumer) onRetirePeek( // TODO: This only works because the peeks table is reusing the same // index as PDUs, but we should fix this s.streams.PDUStreamProvider.Advance(sp) + s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.StreamingToken{PDUPosition: sp}) return nil } diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go new file mode 100644 index 000000000..561c6f0c0 --- /dev/null +++ b/syncapi/notifier/notifier.go @@ -0,0 +1,469 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package notifier + +import ( + "context" + "sync" + "time" + + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +// Notifier will wake up sleeping requests when there is some new data. +// It does not tell requests what that data is, only the sync position which +// they can use to get at it. This is done to prevent races whereby we tell the caller +// the event, but the token has already advanced by the time they fetch it, resulting +// in missed events. +type Notifier struct { + // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine + roomIDToJoinedUsers map[string]userIDSet + // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine + roomIDToPeekingDevices map[string]peekingDeviceSet + // Protects currPos and userStreams. + streamLock *sync.Mutex + // The latest sync position + currPos types.StreamingToken + // A map of user_id => device_id => UserStream which can be used to wake a given user's /sync request. + userDeviceStreams map[string]map[string]*UserDeviceStream + // The last time we cleaned out stale entries from the userStreams map + lastCleanUpTime time.Time +} + +// NewNotifier creates a new notifier set to the given sync position. +// In order for this to be of any use, the Notifier needs to be told all rooms and +// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). +func NewNotifier(currPos types.StreamingToken) *Notifier { + return &Notifier{ + currPos: currPos, + roomIDToJoinedUsers: make(map[string]userIDSet), + roomIDToPeekingDevices: make(map[string]peekingDeviceSet), + userDeviceStreams: make(map[string]map[string]*UserDeviceStream), + streamLock: &sync.Mutex{}, + lastCleanUpTime: time.Now(), + } +} + +// OnNewEvent is called when a new event is received from the room server. Must only be +// called from a single goroutine, to avoid races between updates which could set the +// current sync position incorrectly. +// Chooses which user sync streams to update by a provided *gomatrixserverlib.Event +// (based on the users in the event's room), +// a roomID directly, or a list of user IDs, prioritised by parameter ordering. +// posUpdate contains the latest position(s) for one or more types of events. +// If a position in posUpdate is 0, it means no updates are available of that type. +// Typically a consumer supplies a posUpdate with the latest sync position for the +// event type it handles, leaving other fields as 0. +func (n *Notifier) OnNewEvent( + ev *gomatrixserverlib.HeaderedEvent, roomID string, userIDs []string, + posUpdate types.StreamingToken, +) { + // update the current position then notify relevant /sync streams. + // This needs to be done PRIOR to waking up users as they will read this value. + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.removeEmptyUserStreams() + + if ev != nil { + // Map this event's room_id to a list of joined users, and wake them up. + usersToNotify := n.joinedUsers(ev.RoomID()) + // Map this event's room_id to a list of peeking devices, and wake them up. + peekingDevicesToNotify := n.PeekingDevices(ev.RoomID()) + // If this is an invite, also add in the invitee to this list. + if ev.Type() == "m.room.member" && ev.StateKey() != nil { + targetUserID := *ev.StateKey() + membership, err := ev.Membership() + if err != nil { + log.WithError(err).WithField("event_id", ev.EventID()).Errorf( + "Notifier.OnNewEvent: Failed to unmarshal member event", + ) + } else { + // Keep the joined user map up-to-date + switch membership { + case gomatrixserverlib.Invite: + usersToNotify = append(usersToNotify, targetUserID) + case gomatrixserverlib.Join: + // Manually append the new user's ID so they get notified + // along all members in the room + usersToNotify = append(usersToNotify, targetUserID) + n.addJoinedUser(ev.RoomID(), targetUserID) + case gomatrixserverlib.Leave: + fallthrough + case gomatrixserverlib.Ban: + n.removeJoinedUser(ev.RoomID(), targetUserID) + } + } + } + + n.wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos) + } else if roomID != "" { + n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), n.currPos) + } else if len(userIDs) > 0 { + n.wakeupUsers(userIDs, nil, n.currPos) + } else { + log.WithFields(log.Fields{ + "posUpdate": posUpdate.String, + }).Warn("Notifier.OnNewEvent called but caller supplied no user to wake up") + } +} + +func (n *Notifier) OnNewPeek( + roomID, userID, deviceID string, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.addPeekingDevice(roomID, userID, deviceID) + //n.streams.PDUStreamProvider.Advance(posUpdate.PDUPosition) + + // we don't wake up devices here given the roomserver consumer will do this shortly afterwards + // by calling OnNewEvent. +} + +func (n *Notifier) OnRetirePeek( + roomID, userID, deviceID string, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.removePeekingDevice(roomID, userID, deviceID) + //n.streams.PDUStreamProvider.Advance(posUpdate.PDUPosition) + + // we don't wake up devices here given the roomserver consumer will do this shortly afterwards + // by calling OnRetireEvent. +} + +func (n *Notifier) OnNewSendToDevice( + userID string, deviceIDs []string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.wakeupUserDevice(userID, deviceIDs, n.currPos) +} + +// OnNewReceipt updates the current position +func (n *Notifier) OnNewTyping( + roomID string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos) +} + +// OnNewReceipt updates the current position +func (n *Notifier) OnNewReceipt( + roomID string, + posUpdate types.StreamingToken, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos) +} + +func (n *Notifier) OnNewKeyChange( + posUpdate types.StreamingToken, wakeUserID, keyChangeUserID string, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.wakeupUsers([]string{wakeUserID}, nil, n.currPos) +} + +func (n *Notifier) OnNewInvite( + posUpdate types.StreamingToken, wakeUserID string, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.currPos.ApplyUpdates(posUpdate) + n.wakeupUsers([]string{wakeUserID}, nil, n.currPos) +} + +// GetListener returns a UserStreamListener that can be used to wait for +// updates for a user. Must be closed. +// notify for anything before sincePos +func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener { + // Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298 + // - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID + // - Incoming events wake requests for a matching room ID + // - Incoming events wake requests for a matching user ID (needed for invites) + + // TODO: v1 /events 'peeking' has an 'explicit room ID' which is also tracked, + // but given we don't do /events, let's pretend it doesn't exist. + + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.removeEmptyUserStreams() + + return n.fetchUserDeviceStream(req.Device.UserID, req.Device.ID, true).GetListener(req.Context) +} + +// Load the membership states required to notify users correctly. +func (n *Notifier) Load(ctx context.Context, db storage.Database) error { + roomToUsers, err := db.AllJoinedUsersInRooms(ctx) + if err != nil { + return err + } + n.setUsersJoinedToRooms(roomToUsers) + + roomToPeekingDevices, err := db.AllPeekingDevicesInRooms(ctx) + if err != nil { + return err + } + n.setPeekingDevices(roomToPeekingDevices) + + return nil +} + +// CurrentPosition returns the current sync position +func (n *Notifier) CurrentPosition() types.StreamingToken { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + return n.currPos +} + +// setUsersJoinedToRooms marks the given users as 'joined' to the given rooms, such that new events from +// these rooms will wake the given users /sync requests. This should be called prior to ANY calls to +// OnNewEvent (eg on startup) to prevent racing. +func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { + // This is just the bulk form of addJoinedUser + for roomID, userIDs := range roomIDToUserIDs { + if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { + n.roomIDToJoinedUsers[roomID] = make(userIDSet) + } + for _, userID := range userIDs { + n.roomIDToJoinedUsers[roomID].add(userID) + } + } +} + +// setPeekingDevices marks the given devices as peeking in the given rooms, such that new events from +// these rooms will wake the given devices' /sync requests. This should be called prior to ANY calls to +// OnNewEvent (eg on startup) to prevent racing. +func (n *Notifier) setPeekingDevices(roomIDToPeekingDevices map[string][]types.PeekingDevice) { + // This is just the bulk form of addPeekingDevice + for roomID, peekingDevices := range roomIDToPeekingDevices { + if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { + n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) + } + for _, peekingDevice := range peekingDevices { + n.roomIDToPeekingDevices[roomID].add(peekingDevice) + } + } +} + +// wakeupUsers will wake up the sync strems for all of the devices for all of the +// specified user IDs, and also the specified peekingDevices +func (n *Notifier) wakeupUsers(userIDs []string, peekingDevices []types.PeekingDevice, newPos types.StreamingToken) { + for _, userID := range userIDs { + for _, stream := range n.fetchUserStreams(userID) { + if stream == nil { + continue + } + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } + } + + for _, peekingDevice := range peekingDevices { + // TODO: don't bother waking up for devices whose users we already woke up + if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.DeviceID, false); stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } + } +} + +// wakeupUserDevice will wake up the sync stream for a specific user device. Other +// device streams will be left alone. +// nolint:unused +func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) { + for _, deviceID := range deviceIDs { + if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } + } +} + +// fetchUserDeviceStream retrieves a stream unique to the given device. If makeIfNotExists is true, +// a stream will be made for this device if one doesn't exist and it will be returned. This +// function does not wait for data to be available on the stream. +// NB: Callers should have locked the mutex before calling this function. +func (n *Notifier) fetchUserDeviceStream(userID, deviceID string, makeIfNotExists bool) *UserDeviceStream { + _, ok := n.userDeviceStreams[userID] + if !ok { + if !makeIfNotExists { + return nil + } + n.userDeviceStreams[userID] = map[string]*UserDeviceStream{} + } + stream, ok := n.userDeviceStreams[userID][deviceID] + if !ok { + if !makeIfNotExists { + return nil + } + // TODO: Unbounded growth of streams (1 per user) + if stream = NewUserDeviceStream(userID, deviceID, n.currPos); stream != nil { + n.userDeviceStreams[userID][deviceID] = stream + } + } + return stream +} + +// fetchUserStreams retrieves all streams for the given user. If makeIfNotExists is true, +// a stream will be made for this user if one doesn't exist and it will be returned. This +// function does not wait for data to be available on the stream. +// NB: Callers should have locked the mutex before calling this function. +func (n *Notifier) fetchUserStreams(userID string) []*UserDeviceStream { + user, ok := n.userDeviceStreams[userID] + if !ok { + return []*UserDeviceStream{} + } + streams := []*UserDeviceStream{} + for _, stream := range user { + streams = append(streams, stream) + } + return streams +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) addJoinedUser(roomID, userID string) { + if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { + n.roomIDToJoinedUsers[roomID] = make(userIDSet) + } + n.roomIDToJoinedUsers[roomID].add(userID) +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) removeJoinedUser(roomID, userID string) { + if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { + n.roomIDToJoinedUsers[roomID] = make(userIDSet) + } + n.roomIDToJoinedUsers[roomID].remove(userID) +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) joinedUsers(roomID string) (userIDs []string) { + if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { + return + } + return n.roomIDToJoinedUsers[roomID].values() +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) { + if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { + n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) + } + n.roomIDToPeekingDevices[roomID].add(types.PeekingDevice{UserID: userID, DeviceID: deviceID}) +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +// nolint:unused +func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) { + if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { + n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) + } + // XXX: is this going to work as a key? + n.roomIDToPeekingDevices[roomID].remove(types.PeekingDevice{UserID: userID, DeviceID: deviceID}) +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) PeekingDevices(roomID string) (peekingDevices []types.PeekingDevice) { + if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { + return + } + return n.roomIDToPeekingDevices[roomID].values() +} + +// removeEmptyUserStreams iterates through the user stream map and removes any +// that have been empty for a certain amount of time. This is a crude way of +// ensuring that the userStreams map doesn't grow forver. +// This should be called when the notifier gets called for whatever reason, +// the function itself is responsible for ensuring it doesn't iterate too +// often. +// NB: Callers should have locked the mutex before calling this function. +func (n *Notifier) removeEmptyUserStreams() { + // Only clean up now and again + now := time.Now() + if n.lastCleanUpTime.Add(time.Minute).After(now) { + return + } + n.lastCleanUpTime = now + + deleteBefore := now.Add(-5 * time.Minute) + for user, byUser := range n.userDeviceStreams { + for device, stream := range byUser { + if stream.TimeOfLastNonEmpty().Before(deleteBefore) { + delete(n.userDeviceStreams[user], device) + } + if len(n.userDeviceStreams[user]) == 0 { + delete(n.userDeviceStreams, user) + } + } + } +} + +// A string set, mainly existing for improving clarity of structs in this file. +type userIDSet map[string]bool + +func (s userIDSet) add(str string) { + s[str] = true +} + +func (s userIDSet) remove(str string) { + delete(s, str) +} + +func (s userIDSet) values() (vals []string) { + for str := range s { + vals = append(vals, str) + } + return +} + +// A set of PeekingDevices, similar to userIDSet + +type peekingDeviceSet map[types.PeekingDevice]bool + +func (s peekingDeviceSet) add(d types.PeekingDevice) { + s[d] = true +} + +// nolint:unused +func (s peekingDeviceSet) remove(d types.PeekingDevice) { + delete(s, d) +} + +func (s peekingDeviceSet) values() (vals []types.PeekingDevice) { + for d := range s { + vals = append(vals, d) + } + return +} diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go new file mode 100644 index 000000000..8b9425e37 --- /dev/null +++ b/syncapi/notifier/notifier_test.go @@ -0,0 +1,374 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package notifier + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "testing" + "time" + + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +var ( + randomMessageEvent gomatrixserverlib.HeaderedEvent + aliceInviteBobEvent gomatrixserverlib.HeaderedEvent + bobLeaveEvent gomatrixserverlib.HeaderedEvent + syncPositionVeryOld = types.StreamingToken{PDUPosition: 5} + syncPositionBefore = types.StreamingToken{PDUPosition: 11} + syncPositionAfter = types.StreamingToken{PDUPosition: 12} + //syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition, 1, 0, 0, nil) + syncPositionAfter2 = types.StreamingToken{PDUPosition: 13} +) + +var ( + roomID = "!test:localhost" + alice = "@alice:localhost" + aliceDev = "alicedevice" + bob = "@bob:localhost" + bobDev = "bobdev" +) + +func init() { + var err error + err = json.Unmarshal([]byte(`{ + "_room_version": "1", + "type": "m.room.message", + "content": { + "body": "Hello World", + "msgtype": "m.text" + }, + "sender": "@noone:localhost", + "room_id": "`+roomID+`", + "origin": "localhost", + "origin_server_ts": 12345, + "event_id": "$randomMessageEvent:localhost" + }`), &randomMessageEvent) + if err != nil { + panic(err) + } + err = json.Unmarshal([]byte(`{ + "_room_version": "1", + "type": "m.room.member", + "state_key": "`+bob+`", + "content": { + "membership": "invite" + }, + "sender": "`+alice+`", + "room_id": "`+roomID+`", + "origin": "localhost", + "origin_server_ts": 12345, + "event_id": "$aliceInviteBobEvent:localhost" + }`), &aliceInviteBobEvent) + if err != nil { + panic(err) + } + err = json.Unmarshal([]byte(`{ + "_room_version": "1", + "type": "m.room.member", + "state_key": "`+bob+`", + "content": { + "membership": "leave" + }, + "sender": "`+bob+`", + "room_id": "`+roomID+`", + "origin": "localhost", + "origin_server_ts": 12345, + "event_id": "$bobLeaveEvent:localhost" + }`), &bobLeaveEvent) + if err != nil { + panic(err) + } +} + +func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { + if got.String() != want.String() { + t.Fatalf("mustEqualPositions got %s want %s", got.String(), want.String()) + } +} + +// Test that the current position is returned if a request is already behind. +func TestImmediateNotification(t *testing.T) { + n := NewNotifier(syncPositionBefore) + pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) + if err != nil { + t.Fatalf("TestImmediateNotification error: %s", err) + } + mustEqualPositions(t, pos, syncPositionBefore) +} + +// Test that new events to a joined room unblocks the request. +func TestNewEventAndJoinedToRoom(t *testing.T) { + n := NewNotifier(syncPositionBefore) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: {alice, bob}, + }) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) + if err != nil { + t.Errorf("TestNewEventAndJoinedToRoom error: %w", err) + } + mustEqualPositions(t, pos, syncPositionAfter) + wg.Done() + }() + + stream := lockedFetchUserStream(n, bob, bobDev) + waitForBlocking(stream, 1) + + n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) + + wg.Wait() +} + +func TestCorrectStream(t *testing.T) { + n := NewNotifier(syncPositionBefore) + stream := lockedFetchUserStream(n, bob, bobDev) + if stream.UserID != bob { + t.Fatalf("expected user %q, got %q", bob, stream.UserID) + } + if stream.DeviceID != bobDev { + t.Fatalf("expected device %q, got %q", bobDev, stream.DeviceID) + } +} + +func TestCorrectStreamWakeup(t *testing.T) { + n := NewNotifier(syncPositionBefore) + awoken := make(chan string) + + streamone := lockedFetchUserStream(n, alice, "one") + streamtwo := lockedFetchUserStream(n, alice, "two") + + go func() { + select { + case <-streamone.signalChannel: + awoken <- "one" + case <-streamtwo.signalChannel: + awoken <- "two" + } + }() + + time.Sleep(1 * time.Second) + + wake := "two" + n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter) + + if result := <-awoken; result != wake { + t.Fatalf("expected to wake %q, got %q", wake, result) + } +} + +// Test that an invite unblocks the request +func TestNewInviteEventForUser(t *testing.T) { + n := NewNotifier(syncPositionBefore) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: {alice, bob}, + }) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) + if err != nil { + t.Errorf("TestNewInviteEventForUser error: %w", err) + } + mustEqualPositions(t, pos, syncPositionAfter) + wg.Done() + }() + + stream := lockedFetchUserStream(n, bob, bobDev) + waitForBlocking(stream, 1) + + n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter) + + wg.Wait() +} + +// Test an EDU-only update wakes up the request. +// TODO: Fix this test, invites wake up with an incremented +// PDU position, not EDU position +/* +func TestEDUWakeup(t *testing.T) { + n := NewNotifier(syncPositionAfter) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: {alice, bob}, + }) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) + if err != nil { + t.Errorf("TestNewInviteEventForUser error: %w", err) + } + mustEqualPositions(t, pos, syncPositionNewEDU) + wg.Done() + }() + + stream := lockedFetchUserStream(n, bob, bobDev) + waitForBlocking(stream, 1) + + n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU) + + wg.Wait() +} +*/ + +// Test that all blocked requests get woken up on a new event. +func TestMultipleRequestWakeup(t *testing.T) { + n := NewNotifier(syncPositionBefore) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: {alice, bob}, + }) + + var wg sync.WaitGroup + wg.Add(3) + poll := func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) + if err != nil { + t.Errorf("TestMultipleRequestWakeup error: %w", err) + } + mustEqualPositions(t, pos, syncPositionAfter) + wg.Done() + } + go poll() + go poll() + go poll() + + stream := lockedFetchUserStream(n, bob, bobDev) + waitForBlocking(stream, 3) + + n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) + + wg.Wait() + + numWaiting := stream.NumWaiting() + if numWaiting != 0 { + t.Errorf("TestMultipleRequestWakeup NumWaiting() want 0, got %d", numWaiting) + } +} + +// Test that you stop getting woken up when you leave a room. +func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { + // listen as bob. Make bob leave room. Make alice send event to room. + // Make sure alice gets woken up only and not bob as well. + n := NewNotifier(syncPositionBefore) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: {alice, bob}, + }) + + var leaveWG sync.WaitGroup + + // Make bob leave the room + leaveWG.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionBefore)) + if err != nil { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) + } + mustEqualPositions(t, pos, syncPositionAfter) + leaveWG.Done() + }() + bobStream := lockedFetchUserStream(n, bob, bobDev) + waitForBlocking(bobStream, 1) + n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter) + leaveWG.Wait() + + // send an event into the room. Make sure alice gets it. Bob should not. + var aliceWG sync.WaitGroup + aliceStream := lockedFetchUserStream(n, alice, aliceDev) + aliceWG.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionAfter)) + if err != nil { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) + } + mustEqualPositions(t, pos, syncPositionAfter2) + aliceWG.Done() + }() + + go func() { + // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) + _, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) + if err == nil { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") + } + }() + + waitForBlocking(aliceStream, 1) + waitForBlocking(bobStream, 1) + + n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter2) + aliceWG.Wait() + + // it's possible that at this point alice has been informed and bob is about to be informed, so wait + // for a fraction of a second to account for this race + time.Sleep(1 * time.Millisecond) +} + +func waitForEvents(n *Notifier, req types.SyncRequest) (types.StreamingToken, error) { + listener := n.GetListener(req) + defer listener.Close() + + select { + case <-time.After(5 * time.Second): + return types.StreamingToken{}, fmt.Errorf( + "waitForEvents timed out waiting for %s (pos=%v)", req.Device.UserID, req.Since, + ) + case <-listener.GetNotifyChannel(req.Since): + p := listener.GetSyncPosition() + return p, nil + } +} + +// Wait until something is Wait()ing on the user stream. +func waitForBlocking(s *UserDeviceStream, numBlocking uint) { + for numBlocking != s.NumWaiting() { + // This is horrible but I don't want to add a signalling mechanism JUST for testing. + time.Sleep(1 * time.Microsecond) + } +} + +// lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock. +// A new stream is made if it doesn't exist already. +func lockedFetchUserStream(n *Notifier, userID, deviceID string) *UserDeviceStream { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + return n.fetchUserDeviceStream(userID, deviceID, true) +} + +func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) types.SyncRequest { + return types.SyncRequest{ + Device: &userapi.Device{ + UserID: userID, + ID: deviceID, + }, + Timeout: 1 * time.Minute, + Since: since, + WantFullState: false, + Limit: 20, + Log: util.GetLogger(context.TODO()), + Context: context.TODO(), + } +} diff --git a/syncapi/sync/userstream.go b/syncapi/notifier/userstream.go similarity index 99% rename from syncapi/sync/userstream.go rename to syncapi/notifier/userstream.go index ff9a4d003..720185d52 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/notifier/userstream.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sync +package notifier import ( "context" diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 3de6ed15b..2e85e9e9b 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -13,10 +13,6 @@ type AccountDataStreamProvider struct { userAPI userapi.UserInternalAPI } -func (p *AccountDataStreamProvider) Setup() { - p.StreamProvider.Setup() -} - func (p *AccountDataStreamProvider) CompleteSync( ctx context.Context, req *types.SyncRequest, diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index 07cb724cd..ba4118df5 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -1,6 +1,8 @@ package streams import ( + "context" + "github.com/matrix-org/dendrite/eduserver/cache" keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -62,3 +64,15 @@ func NewSyncStreamProviders( return streams } + +func (s *Streams) Latest(ctx context.Context) types.StreamingToken { + return types.StreamingToken{ + PDUPosition: s.PDUStreamProvider.LatestPosition(ctx), + TypingPosition: s.TypingStreamProvider.LatestPosition(ctx), + ReceiptPosition: s.PDUStreamProvider.LatestPosition(ctx), + InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), + SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), + AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), + DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx), + } +} diff --git a/syncapi/streams/template_pstream.go b/syncapi/streams/template_pstream.go index 1c249b767..265e22a20 100644 --- a/syncapi/streams/template_pstream.go +++ b/syncapi/streams/template_pstream.go @@ -6,51 +6,26 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - userapi "github.com/matrix-org/dendrite/userapi/api" ) type PartitionedStreamProvider struct { - DB storage.Database - latest types.LogPosition - latestMutex sync.RWMutex - subscriptions map[string]*partitionedStreamSubscription // userid+deviceid - subscriptionsMutex sync.Mutex -} - -type partitionedStreamSubscription struct { - ctx context.Context - from types.LogPosition - ch chan struct{} + DB storage.Database + latest types.LogPosition + latestMutex sync.RWMutex } func (p *PartitionedStreamProvider) Setup() { - p.subscriptions = make(map[string]*partitionedStreamSubscription) } func (p *PartitionedStreamProvider) Advance( latest types.LogPosition, ) { p.latestMutex.Lock() + defer p.latestMutex.Unlock() + if latest.IsAfter(&p.latest) { p.latest = latest } - p.latestMutex.Unlock() - - p.subscriptionsMutex.Lock() - defer p.subscriptionsMutex.Unlock() - - for id, s := range p.subscriptions { - select { - case <-s.ctx.Done(): - close(s.ch) - delete(p.subscriptions, id) - default: - if latest.IsAfter(&s.from) { - close(s.ch) - delete(p.subscriptions, id) - } - } - } } func (p *PartitionedStreamProvider) LatestPosition( @@ -61,37 +36,3 @@ func (p *PartitionedStreamProvider) LatestPosition( return p.latest } - -func (p *PartitionedStreamProvider) NotifyAfter( - ctx context.Context, - device *userapi.Device, - from types.LogPosition, -) chan struct{} { - ch := make(chan struct{}) - - check := func() bool { - p.latestMutex.RLock() - defer p.latestMutex.RUnlock() - if p.latest.IsAfter(&from) { - close(ch) - return true - } - return false - } - - // If we've already advanced past the specified position - // then return straight away. - if check() { - return ch - } - - id := device.UserID + device.ID - p.subscriptionsMutex.Lock() - if s, ok := p.subscriptions[id]; ok { - close(s.ch) - } - p.subscriptions[id] = &partitionedStreamSubscription{ctx, from, ch} - p.subscriptionsMutex.Unlock() - - return ch -} diff --git a/syncapi/streams/template_stream.go b/syncapi/streams/template_stream.go index 84a59d315..15074cc10 100644 --- a/syncapi/streams/template_stream.go +++ b/syncapi/streams/template_stream.go @@ -6,51 +6,26 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - userapi "github.com/matrix-org/dendrite/userapi/api" ) type StreamProvider struct { - DB storage.Database - latest types.StreamPosition - latestMutex sync.RWMutex - subscriptions map[string]*streamSubscription // userid+deviceid - subscriptionsMutex sync.Mutex -} - -type streamSubscription struct { - ctx context.Context - from types.StreamPosition - ch chan struct{} + DB storage.Database + latest types.StreamPosition + latestMutex sync.RWMutex } func (p *StreamProvider) Setup() { - p.subscriptions = make(map[string]*streamSubscription) } func (p *StreamProvider) Advance( latest types.StreamPosition, ) { p.latestMutex.Lock() + defer p.latestMutex.Unlock() + if latest > p.latest { p.latest = latest } - p.latestMutex.Unlock() - - p.subscriptionsMutex.Lock() - defer p.subscriptionsMutex.Unlock() - - for id, s := range p.subscriptions { - select { - case <-s.ctx.Done(): - close(s.ch) - delete(p.subscriptions, id) - default: - if latest > s.from { - close(s.ch) - delete(p.subscriptions, id) - } - } - } } func (p *StreamProvider) LatestPosition( @@ -61,37 +36,3 @@ func (p *StreamProvider) LatestPosition( return p.latest } - -func (p *StreamProvider) NotifyAfter( - ctx context.Context, - device *userapi.Device, - from types.StreamPosition, -) chan struct{} { - ch := make(chan struct{}) - - check := func() bool { - p.latestMutex.RLock() - defer p.latestMutex.RUnlock() - if p.latest > from { - close(ch) - return true - } - return false - } - - // If we've already advanced past the specified position - // then return straight away. - if check() { - return ch - } - - id := device.UserID + device.ID - p.subscriptionsMutex.Lock() - if s, ok := p.subscriptions[id]; ok { - close(s.ch) - } - p.subscriptions[id] = &streamSubscription{ctx, from, ch} - p.subscriptionsMutex.Unlock() - - return ch -} diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 43ac01ec5..785a3af82 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -17,7 +17,6 @@ package sync import ( - "context" "net" "net/http" "strings" @@ -29,6 +28,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" @@ -46,6 +46,7 @@ type RequestPool struct { rsAPI roomserverAPI.RoomserverInternalAPI lastseen sync.Map streams *streams.Streams + notifier *notifier.Notifier } // NewRequestPool makes a new RequestPool @@ -53,7 +54,7 @@ func NewRequestPool( db storage.Database, cfg *config.SyncAPI, userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, - streams *streams.Streams, + streams *streams.Streams, notifier *notifier.Notifier, ) *RequestPool { rp := &RequestPool{ db: db, @@ -63,6 +64,7 @@ func NewRequestPool( rsAPI: rsAPI, lastseen: sync.Map{}, streams: streams, + notifier: notifier, } go rp.cleanLastSeen() return rp @@ -152,15 +154,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. waitingSyncRequests.Inc() defer waitingSyncRequests.Dec() + currentPos := rp.notifier.CurrentPosition() + if !rp.shouldReturnImmediately(syncReq) { timer := time.NewTimer(syncReq.Timeout) // case of timeout=0 is handled above defer timer.Stop() - // Use a subcontext so that we don't keep the StreamNotifyAfter - // goroutines alive any longer than they really need to be. - waitctx, waitcancel := context.WithCancel(syncReq.Context) + userStreamListener := rp.notifier.GetListener(*syncReq) + defer userStreamListener.Close() + giveup := func() util.JSONResponse { - waitcancel() syncReq.Response.NextBatch = syncReq.Since return util.JSONResponse{ Code: http.StatusOK, @@ -169,29 +172,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. } select { - case <-waitctx.Done(): // Caller gave up + case <-syncReq.Context.Done(): // Caller gave up return giveup() case <-timer.C: // Timeout reached return giveup() - case <-rp.streams.PDUStreamProvider.NotifyAfter(waitctx, device, syncReq.Since.PDUPosition): - syncReq.Log.Debugln("Responding to sync after PDU") - case <-rp.streams.TypingStreamProvider.NotifyAfter(waitctx, device, syncReq.Since.TypingPosition): - syncReq.Log.Debugln("Responding to sync after typing notification") - case <-rp.streams.ReceiptStreamProvider.NotifyAfter(waitctx, device, syncReq.Since.ReceiptPosition): - syncReq.Log.Debugln("Responding to sync after read receipt") - case <-rp.streams.InviteStreamProvider.NotifyAfter(waitctx, device, syncReq.Since.InvitePosition): - syncReq.Log.Debugln("Responding to sync after invite") - case <-rp.streams.SendToDeviceStreamProvider.NotifyAfter(waitctx, device, syncReq.Since.SendToDevicePosition): - syncReq.Log.Debugln("Responding to sync after send-to-device message") - case <-rp.streams.AccountDataStreamProvider.NotifyAfter(waitctx, device, syncReq.Since.AccountDataPosition): - syncReq.Log.Debugln("Responding to sync after account data") - case <-rp.streams.DeviceListStreamProvider.NotifyAfter(waitctx, device, syncReq.Since.DeviceListPosition): - syncReq.Log.Debugln("Responding to sync after device list update") + case <-userStreamListener.GetNotifyChannel(syncReq.Since): + syncReq.Log.Debugln("Responding to sync after wake-up") + currentPos = userStreamListener.GetSyncPosition() } - - waitcancel() } else { syncReq.Log.Debugln("Responding to sync immediately") } @@ -226,31 +216,31 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. syncReq.Response.NextBatch = types.StreamingToken{ PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync( syncReq.Context, syncReq, - syncReq.Since.PDUPosition, rp.streams.PDUStreamProvider.LatestPosition(syncReq.Context), + syncReq.Since.PDUPosition, currentPos.PDUPosition, ), TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync( syncReq.Context, syncReq, - syncReq.Since.TypingPosition, rp.streams.TypingStreamProvider.LatestPosition(syncReq.Context), + syncReq.Since.TypingPosition, currentPos.TypingPosition, ), ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync( syncReq.Context, syncReq, - syncReq.Since.ReceiptPosition, rp.streams.ReceiptStreamProvider.LatestPosition(syncReq.Context), + syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition, ), InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync( syncReq.Context, syncReq, - syncReq.Since.InvitePosition, rp.streams.InviteStreamProvider.LatestPosition(syncReq.Context), + syncReq.Since.InvitePosition, currentPos.InvitePosition, ), SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync( syncReq.Context, syncReq, - syncReq.Since.SendToDevicePosition, rp.streams.SendToDeviceStreamProvider.LatestPosition(syncReq.Context), + syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition, ), AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync( syncReq.Context, syncReq, - syncReq.Since.AccountDataPosition, rp.streams.AccountDataStreamProvider.LatestPosition(syncReq.Context), + syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, ), DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( syncReq.Context, syncReq, - syncReq.Since.DeviceListPosition, rp.streams.DeviceListStreamProvider.LatestPosition(syncReq.Context), + syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, ), } } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index b58f7e5f4..7addcb9bb 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -15,6 +15,8 @@ package syncapi import ( + "context" + "github.com/gorilla/mux" "github.com/sirupsen/logrus" @@ -27,6 +29,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/syncapi/consumers" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" @@ -52,47 +55,51 @@ func AddPublicRoutes( eduCache := cache.New() streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache) + notifier := notifier.NewNotifier(streams.Latest(context.Background())) + if err = notifier.Load(context.Background(), syncDB); err != nil { + logrus.WithError(err).Panicf("failed to load notifier") + } - requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams) + requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier) keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputKeyChangeEvent)), - consumer, keyAPI, rsAPI, syncDB, streams, + consumer, keyAPI, rsAPI, syncDB, notifier, streams, ) if err = keyChangeConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start key change consumer") } roomConsumer := consumers.NewOutputRoomEventConsumer( - cfg, consumer, syncDB, streams, rsAPI, + cfg, consumer, syncDB, notifier, streams, rsAPI, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") } clientConsumer := consumers.NewOutputClientDataConsumer( - cfg, consumer, syncDB, streams, + cfg, consumer, syncDB, notifier, streams, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") } typingConsumer := consumers.NewOutputTypingEventConsumer( - cfg, consumer, syncDB, eduCache, streams, + cfg, consumer, syncDB, eduCache, notifier, streams, ) if err = typingConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start typing consumer") } sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer( - cfg, consumer, syncDB, streams, + cfg, consumer, syncDB, notifier, streams, ) if err = sendToDeviceConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start send-to-device consumer") } receiptConsumer := consumers.NewOutputReceiptEventConsumer( - cfg, consumer, syncDB, streams, + cfg, consumer, syncDB, notifier, streams, ) if err = receiptConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start receipts consumer") diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index eb6087905..24b453a80 100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -40,10 +40,6 @@ type StreamProvider interface { // making no changes if the range contains no updates. IncrementalSync(ctx context.Context, req *SyncRequest, from, to StreamPosition) StreamPosition - // NotifyAfter returns a channel which will be closed once the - // stream advances past the "from" position. - NotifyAfter(ctx context.Context, device *userapi.Device, from StreamPosition) chan struct{} - // LatestPosition returns the latest stream position for this stream. LatestPosition(ctx context.Context) StreamPosition } @@ -53,6 +49,5 @@ type PartitionedStreamProvider interface { Advance(latest LogPosition) CompleteSync(ctx context.Context, req *SyncRequest) LogPosition IncrementalSync(ctx context.Context, req *SyncRequest, from, to LogPosition) LogPosition - NotifyAfter(ctx context.Context, device *userapi.Device, from LogPosition) chan struct{} LatestPosition(ctx context.Context) LogPosition }