Use a single lock for protecting currPos and userStreams

Kegan Dougal 2017-05-17 13:44:17 +01:00
parent 8e7297990c
commit f641e7af14


@@ -31,16 +31,15 @@ import (
 // the event, but the token has already advanced by the time they fetch it, resulting
 // in missed events.
 type Notifier struct {
+	// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
+	roomIDToJoinedUsers map[string]set
+	// Protects currPos and userStreams.
+	streamLock *sync.Mutex
 	// The latest sync stream position: guarded by 'currPosMutex' which is RW to allow
 	// for concurrent reads on /sync requests
 	currPos types.StreamPosition
-	currPosMutex *sync.RWMutex
-	// A map of RoomID => Set<UserID>
-	roomIDToJoinedUsers map[string]set
 	// A map of user_id => UserStream which can be used to wake a given user's /sync request.
-	// Map access is guarded by userStreamsMutex.
 	userStreams map[string]*UserStream
-	userStreamsMutex *sync.Mutex
 }
 
 // NewNotifier creates a new notifier set to the given stream position.
@@ -49,10 +48,9 @@ type Notifier struct {
 func NewNotifier(pos types.StreamPosition) *Notifier {
 	return &Notifier{
 		currPos: pos,
-		currPosMutex: &sync.RWMutex{},
 		roomIDToJoinedUsers: make(map[string]set),
 		userStreams: make(map[string]*UserStream),
-		userStreamsMutex: &sync.Mutex{},
+		streamLock: &sync.Mutex{},
 	}
 }
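Pieced together from the two hunks above, the resulting struct and constructor look roughly like the sketch below. The stub types (StreamPosition, set, UserStream) are stand-ins assumed here so the sketch compiles on its own; they are not the real dendrite definitions.

package notifier

import "sync"

// Stand-ins for the real dendrite types, assumed for illustration only.
type StreamPosition int64
type set map[string]struct{}
type UserStream struct{}

// Notifier after this commit: a single streamLock guards both currPos and
// userStreams, replacing the separate currPosMutex (an RWMutex) and
// userStreamsMutex.
type Notifier struct {
	// Only ever touched by the OnNewEvent goroutine, so it needs no lock.
	roomIDToJoinedUsers map[string]set
	// Protects currPos and userStreams.
	streamLock  *sync.Mutex
	currPos     StreamPosition
	userStreams map[string]*UserStream
}

func NewNotifier(pos StreamPosition) *Notifier {
	return &Notifier{
		currPos:             pos,
		roomIDToJoinedUsers: make(map[string]set),
		userStreams:         make(map[string]*UserStream),
		streamLock:          &sync.Mutex{},
	}
}

Swapping the RWMutex for a plain Mutex gives up concurrent reads of currPos, but it lets "read the position and register a waiter" become one atomic step, which is what the reworked WaitForEvents below depends on.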
@@ -60,11 +58,11 @@ func NewNotifier(pos types.StreamPosition) *Notifier {
 // called from a single goroutine, to avoid races between updates which could set the
 // current position in the stream incorrectly.
 func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, pos types.StreamPosition) {
-	// update the current position in a guard and then notify relevant /sync streams.
+	// update the current position then notify relevant /sync streams.
 	// This needs to be done PRIOR to waking up users as they will read this value.
-	n.currPosMutex.Lock()
+	n.streamLock.Lock()
+	defer n.streamLock.Unlock()
 	n.currPos = pos
-	n.currPosMutex.Unlock()
 
 	// Map this event's room_id to a list of joined users, and wake them up.
 	userIDs := n.joinedUsers(ev.RoomID())
@@ -107,20 +105,24 @@ func (n *Notifier) WaitForEvents(req syncRequest) types.StreamPosition {
 	// but given we don't do /events, let's pretend it doesn't exist.
 
 	// In a guard, check if the /sync request should block, and block it until we get woken up
-	n.currPosMutex.RLock()
+	n.streamLock.Lock()
 	currentPos := n.currPos
-	n.currPosMutex.RUnlock()
 
 	// TODO: We increment the stream position for any event, so it's possible that we return immediately
 	// with a pos which contains no new events for this user. We should probably re-wait for events
 	// automatically in this case.
 	if req.since != currentPos {
+		n.streamLock.Unlock()
 		return currentPos
 	}
 
 	// wait to be woken up, and then re-check the stream position
 	req.log.WithField("user_id", req.userID).Info("Waiting for event")
-	return n.blockUser(req.userID)
+
+	// give up the stream lock prior to waiting on the user lock
+	stream := n.fetchUserStream(req.userID, true)
+	n.streamLock.Unlock()
+	return stream.Wait()
 }
 
 // Load the membership states required to notify users correctly.
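The delicate part is the handoff at the end of WaitForEvents: the position check and the stream lookup happen under streamLock, but the lock is released before blocking, so a parked /sync request never stalls OnNewEvent or other requests. Below is a runnable sketch of that pattern under stated assumptions: UserStream here is a hypothetical channel-based stand-in (dendrite's real one differs), and the waiter snapshots the signal channel while still holding streamLock, so a Broadcast landing between the unlock and the wait cannot be lost.

package main

import (
	"fmt"
	"sync"
	"time"
)

type StreamPosition int64

// UserStream (stand-in): Broadcast closes the current signal channel and
// installs a fresh one, so it never blocks and wakes every waiter at once.
type UserStream struct {
	mu     sync.Mutex
	signal chan struct{}
	pos    StreamPosition
}

func newUserStream() *UserStream {
	return &UserStream{signal: make(chan struct{})}
}

// waitC snapshots the current signal channel; waiters receive on it.
func (s *UserStream) waitC() <-chan struct{} {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.signal
}

func (s *UserStream) Pos() StreamPosition {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.pos
}

// Broadcast never blocks, so it is safe to call with streamLock held
// (OnNewEvent holds it for its whole body via defer).
func (s *UserStream) Broadcast(pos StreamPosition) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.pos = pos
	close(s.signal)
	s.signal = make(chan struct{})
}

type Notifier struct {
	streamLock  *sync.Mutex
	currPos     StreamPosition
	userStreams map[string]*UserStream
}

// WaitForEvents mirrors the flow after this commit: check the position and
// register the waiter under streamLock, then release it BEFORE blocking.
func (n *Notifier) WaitForEvents(userID string, since StreamPosition) StreamPosition {
	n.streamLock.Lock()
	if since != n.currPos {
		pos := n.currPos
		n.streamLock.Unlock()
		return pos // the stream already advanced; don't block
	}
	stream, ok := n.userStreams[userID]
	if !ok {
		stream = newUserStream()
		n.userStreams[userID] = stream
	}
	ch := stream.waitC()  // snapshot the signal while still holding the lock
	n.streamLock.Unlock() // give up the stream lock prior to waiting
	<-ch                  // any later Broadcast closes exactly this channel
	return stream.Pos()
}

// OnNewEvent advances the position and wakes waiters in one critical section.
func (n *Notifier) OnNewEvent(userID string, pos StreamPosition) {
	n.streamLock.Lock()
	defer n.streamLock.Unlock()
	n.currPos = pos
	if stream, ok := n.userStreams[userID]; ok {
		stream.Broadcast(pos)
	}
}

func main() {
	n := &Notifier{streamLock: &sync.Mutex{}, userStreams: make(map[string]*UserStream)}
	woken := make(chan StreamPosition)
	go func() { woken <- n.WaitForEvents("@alice:example.org", 0) }()
	time.Sleep(10 * time.Millisecond) // crude way to let the waiter park (demo only)
	n.OnNewEvent("@alice:example.org", 1)
	fmt.Println("woken at position", <-woken) // prints: woken at position 1
}

The sleep in main only makes the demo's interleaving likely; the correctness comes from the snapshot-then-unlock step, which is what the "give up the stream lock prior to waiting on the user lock" comment is getting at. Without it, the waiter either sleeps while holding streamLock (stalling everyone) or risks missing a Broadcast that fires between the unlock and the wait.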
@@ -156,17 +158,10 @@ func (n *Notifier) wakeupUser(userID string, newPos types.StreamPosition) {
 	stream.Broadcast(newPos) // wakeup all goroutines Wait()ing on this stream
 }
 
-func (n *Notifier) blockUser(userID string) types.StreamPosition {
-	stream := n.fetchUserStream(userID, true)
-	return stream.Wait()
-}
-
 // fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true,
 // a stream will be made for this user if one doesn't exist and it will be returned. This
 // function does not wait for data to be available on the stream.
 func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream {
-	n.userStreamsMutex.Lock()
-	defer n.userStreamsMutex.Unlock()
 	stream, ok := n.userStreams[userID]
 	if !ok {
 		// TODO: Unbounded growth of streams (1 per user)
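With userStreamsMutex gone, fetchUserStream's contract changes: the map access is no longer synchronised internally, so callers must hold streamLock, which both remaining call sites do (WaitForEvents above, and the wakeup path that runs inside OnNewEvent's deferred critical section). A minimal sketch of that contract with stub types follows; the MUST-hold comment and the makeIfNotExists guard are spelled out here from the doc comment, not copied from the commit.

package notifier

import "sync"

type UserStream struct{} // stand-in

type Notifier struct {
	streamLock  *sync.Mutex // protects userStreams (and currPos)
	userStreams map[string]*UserStream
}

// fetchUserStream returns the stream for userID, creating one if
// makeIfNotExists is set. The caller MUST hold n.streamLock: after this
// commit the function does no locking of its own.
func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream {
	stream, ok := n.userStreams[userID]
	if !ok && makeIfNotExists {
		// NOTE: streams grow without bound (one per user); see the TODO
		// in the original code.
		stream = &UserStream{}
		n.userStreams[userID] = stream
	}
	return stream
}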