Only wake up /sync requests which the event is for (#101)
This commit is contained in:
parent
0a3d44a80a
commit
d5a44fd3e8
|
@ -78,6 +78,9 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
n := sync.NewNotifier(types.StreamPosition(pos))
|
n := sync.NewNotifier(types.StreamPosition(pos))
|
||||||
|
if err := n.Load(db); err != nil {
|
||||||
|
log.Panicf("startup: failed to set up notifier: %s", err)
|
||||||
|
}
|
||||||
server, err := consumers.NewServer(cfg, n, db)
|
server, err := consumers.NewServer(cfg, n, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panicf("startup: failed to create sync server: %s", err)
|
log.Panicf("startup: failed to create sync server: %s", err)
|
||||||
|
|
|
@ -61,11 +61,15 @@ const selectRoomIDsWithMembershipSQL = "" +
|
||||||
const selectCurrentStateSQL = "" +
|
const selectCurrentStateSQL = "" +
|
||||||
"SELECT event_json FROM current_room_state WHERE room_id = $1"
|
"SELECT event_json FROM current_room_state WHERE room_id = $1"
|
||||||
|
|
||||||
|
const selectJoinedUsersSQL = "" +
|
||||||
|
"SELECT room_id, state_key FROM current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||||
selectRoomIDsWithMembershipStmt *sql.Stmt
|
selectRoomIDsWithMembershipStmt *sql.Stmt
|
||||||
selectCurrentStateStmt *sql.Stmt
|
selectCurrentStateStmt *sql.Stmt
|
||||||
|
selectJoinedUsersStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) {
|
func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
@ -85,8 +89,33 @@ func (s *currentRoomStateStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
|
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinedMemberLists returns a map of room ID to a list of joined user IDs.
|
||||||
|
func (s *currentRoomStateStatements) JoinedMemberLists() (map[string][]string, error) {
|
||||||
|
rows, err := s.selectJoinedUsersStmt.Query()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
result := make(map[string][]string)
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&roomID, &userID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
users := result[roomID]
|
||||||
|
users = append(users, userID)
|
||||||
|
result[roomID] = users
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
|
||||||
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) {
|
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(txn *sql.Tx, userID, membership string) ([]string, error) {
|
||||||
|
|
|
@ -61,6 +61,11 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) {
|
||||||
return &SyncServerDatabase{db, partitions, events, state}, nil
|
return &SyncServerDatabase{db, partitions, events, state}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
|
||||||
|
func (d *SyncServerDatabase) AllJoinedUsersInRooms() (map[string][]string, error) {
|
||||||
|
return d.roomstate.JoinedMemberLists()
|
||||||
|
}
|
||||||
|
|
||||||
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
|
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
|
||||||
// when generating the stream position for this event. Returns the sync stream position for the inserted event.
|
// when generating the stream position for this event. Returns the sync stream position for the inserted event.
|
||||||
// Returns an error if there was a problem inserting this event.
|
// Returns an error if there was a problem inserting this event.
|
||||||
|
|
|
@ -15,27 +15,42 @@
|
||||||
package sync
|
package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/events"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Notifier will wake up sleeping requests in the request pool when there
|
// Notifier will wake up sleeping requests when there is some new data.
|
||||||
// is some new data. It does not tell requests what that data is, only the
|
// It does not tell requests what that data is, only the stream position which
|
||||||
// stream position which they can use to get at it.
|
// they can use to get at it. This is done to prevent races whereby we tell the caller
|
||||||
|
// the event, but the token has already advanced by the time they fetch it, resulting
|
||||||
|
// in missed events.
|
||||||
type Notifier struct {
|
type Notifier struct {
|
||||||
// The latest sync stream position: guarded by 'cond'.
|
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
|
||||||
|
roomIDToJoinedUsers map[string]set
|
||||||
|
// Protects currPos and userStreams.
|
||||||
|
streamLock *sync.Mutex
|
||||||
|
// The latest sync stream position: guarded by 'currPosMutex' which is RW to allow
|
||||||
|
// for concurrent reads on /sync requests
|
||||||
currPos types.StreamPosition
|
currPos types.StreamPosition
|
||||||
// A condition variable to notify all waiting goroutines of a new sync stream position
|
// A map of user_id => UserStream which can be used to wake a given user's /sync request.
|
||||||
cond *sync.Cond
|
userStreams map[string]*UserStream
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNotifier creates a new notifier set to the given stream position.
|
// NewNotifier creates a new notifier set to the given stream position.
|
||||||
|
// In order for this to be of any use, the Notifier needs to be told all rooms and
|
||||||
|
// the joined users within each of them by calling Notifier.LoadFromDatabase().
|
||||||
func NewNotifier(pos types.StreamPosition) *Notifier {
|
func NewNotifier(pos types.StreamPosition) *Notifier {
|
||||||
return &Notifier{
|
return &Notifier{
|
||||||
pos,
|
currPos: pos,
|
||||||
sync.NewCond(&sync.Mutex{}),
|
roomIDToJoinedUsers: make(map[string]set),
|
||||||
|
userStreams: make(map[string]*UserStream),
|
||||||
|
streamLock: &sync.Mutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,25 +58,157 @@ func NewNotifier(pos types.StreamPosition) *Notifier {
|
||||||
// called from a single goroutine, to avoid races between updates which could set the
|
// called from a single goroutine, to avoid races between updates which could set the
|
||||||
// current position in the stream incorrectly.
|
// current position in the stream incorrectly.
|
||||||
func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, pos types.StreamPosition) {
|
func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, pos types.StreamPosition) {
|
||||||
// update the current position in a guard and then notify all /sync streams
|
// update the current position then notify relevant /sync streams.
|
||||||
n.cond.L.Lock()
|
// This needs to be done PRIOR to waking up users as they will read this value.
|
||||||
|
n.streamLock.Lock()
|
||||||
|
defer n.streamLock.Unlock()
|
||||||
n.currPos = pos
|
n.currPos = pos
|
||||||
n.cond.L.Unlock()
|
|
||||||
|
|
||||||
n.cond.Broadcast() // notify ALL waiting goroutines
|
// Map this event's room_id to a list of joined users, and wake them up.
|
||||||
|
userIDs := n.joinedUsers(ev.RoomID())
|
||||||
|
// If this is an invite, also add in the invitee to this list.
|
||||||
|
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||||
|
userID := *ev.StateKey()
|
||||||
|
var memberContent events.MemberContent
|
||||||
|
if err := json.Unmarshal(ev.Content(), &memberContent); err != nil {
|
||||||
|
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
|
||||||
|
"Notifier.OnNewEvent: Failed to unmarshal member event",
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Keep the joined user map up-to-date
|
||||||
|
switch memberContent.Membership {
|
||||||
|
case "invite":
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
case "join":
|
||||||
|
n.addJoinedUser(ev.RoomID(), userID)
|
||||||
|
case "leave":
|
||||||
|
fallthrough
|
||||||
|
case "ban":
|
||||||
|
n.removeJoinedUser(ev.RoomID(), userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, userID := range userIDs {
|
||||||
|
n.wakeupUser(userID, pos)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WaitForEvents blocks until there are new events for this request.
|
// WaitForEvents blocks until there are new events for this request.
|
||||||
func (n *Notifier) WaitForEvents(req syncRequest) types.StreamPosition {
|
func (n *Notifier) WaitForEvents(req syncRequest) types.StreamPosition {
|
||||||
// In a guard, check if the /sync request should block, and block it until we get a new position
|
// Do what synapse does: https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/notifier.py#L298
|
||||||
n.cond.L.Lock()
|
// - Bucket request into a lookup map keyed off a list of joined room IDs and separately a user ID
|
||||||
|
// - Incoming events wake requests for a matching room ID
|
||||||
|
// - Incoming events wake requests for a matching user ID (needed for invites)
|
||||||
|
|
||||||
|
// TODO: v1 /events 'peeking' has an 'explicit room ID' which is also tracked,
|
||||||
|
// but given we don't do /events, let's pretend it doesn't exist.
|
||||||
|
|
||||||
|
// In a guard, check if the /sync request should block, and block it until we get woken up
|
||||||
|
n.streamLock.Lock()
|
||||||
currentPos := n.currPos
|
currentPos := n.currPos
|
||||||
for req.since == currentPos {
|
|
||||||
// we need to wait for a new event.
|
// TODO: We increment the stream position for any event, so it's possible that we return immediately
|
||||||
// TODO: This waits for ANY new event, we need to only wait for events which we care about.
|
// with a pos which contains no new events for this user. We should probably re-wait for events
|
||||||
n.cond.Wait() // atomically unlocks and blocks goroutine, then re-acquires lock on unblock
|
// automatically in this case.
|
||||||
currentPos = n.currPos
|
if req.since != currentPos {
|
||||||
}
|
n.streamLock.Unlock()
|
||||||
n.cond.L.Unlock()
|
|
||||||
return currentPos
|
return currentPos
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wait to be woken up, and then re-check the stream position
|
||||||
|
req.log.WithField("user_id", req.userID).Info("Waiting for event")
|
||||||
|
|
||||||
|
// give up the stream lock prior to waiting on the user lock
|
||||||
|
stream := n.fetchUserStream(req.userID, true)
|
||||||
|
n.streamLock.Unlock()
|
||||||
|
return stream.Wait(currentPos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the membership states required to notify users correctly.
|
||||||
|
func (n *Notifier) Load(db *storage.SyncServerDatabase) error {
|
||||||
|
roomToUsers, err := db.AllJoinedUsersInRooms()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n.setUsersJoinedToRooms(roomToUsers)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setUsersJoinedToRooms marks the given users as 'joined' to the given rooms, such that new events from
|
||||||
|
// these rooms will wake the given users /sync requests. This should be called prior to ANY calls to
|
||||||
|
// OnNewEvent (eg on startup) to prevent racing.
|
||||||
|
func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
|
||||||
|
// This is just the bulk form of addJoinedUser
|
||||||
|
for roomID, userIDs := range roomIDToUserIDs {
|
||||||
|
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||||
|
n.roomIDToJoinedUsers[roomID] = make(set)
|
||||||
|
}
|
||||||
|
for _, userID := range userIDs {
|
||||||
|
n.roomIDToJoinedUsers[roomID].add(userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) wakeupUser(userID string, newPos types.StreamPosition) {
|
||||||
|
stream := n.fetchUserStream(userID, false)
|
||||||
|
if stream == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stream.Broadcast(newPos) // wakeup all goroutines Wait()ing on this stream
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true,
|
||||||
|
// a stream will be made for this user if one doesn't exist and it will be returned. This
|
||||||
|
// function does not wait for data to be available on the stream.
|
||||||
|
func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream {
|
||||||
|
stream, ok := n.userStreams[userID]
|
||||||
|
if !ok {
|
||||||
|
// TODO: Unbounded growth of streams (1 per user)
|
||||||
|
stream = NewUserStream(userID)
|
||||||
|
n.userStreams[userID] = stream
|
||||||
|
}
|
||||||
|
return stream
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||||
|
func (n *Notifier) addJoinedUser(roomID, userID string) {
|
||||||
|
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||||
|
n.roomIDToJoinedUsers[roomID] = make(set)
|
||||||
|
}
|
||||||
|
n.roomIDToJoinedUsers[roomID].add(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||||
|
func (n *Notifier) removeJoinedUser(roomID, userID string) {
|
||||||
|
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||||
|
n.roomIDToJoinedUsers[roomID] = make(set)
|
||||||
|
}
|
||||||
|
n.roomIDToJoinedUsers[roomID].remove(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not thread-safe: must be called on the OnNewEvent goroutine only
|
||||||
|
func (n *Notifier) joinedUsers(roomID string) (userIDs []string) {
|
||||||
|
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return n.roomIDToJoinedUsers[roomID].values()
|
||||||
|
}
|
||||||
|
|
||||||
|
// A string set, mainly existing for improving clarity of structs in this file.
|
||||||
|
type set map[string]bool
|
||||||
|
|
||||||
|
func (s set) add(str string) {
|
||||||
|
s[str] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s set) remove(str string) {
|
||||||
|
delete(s, str)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s set) values() (vals []string) {
|
||||||
|
for str := range s {
|
||||||
|
vals = append(vals, str)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
292
src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go
Normal file
292
src/github.com/matrix-org/dendrite/syncapi/sync/notifier_test.go
Normal file
|
@ -0,0 +1,292 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
randomMessageEvent gomatrixserverlib.Event
|
||||||
|
aliceInviteBobEvent gomatrixserverlib.Event
|
||||||
|
bobLeaveEvent gomatrixserverlib.Event
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
streamPositionVeryOld = types.StreamPosition(5)
|
||||||
|
streamPositionBefore = types.StreamPosition(11)
|
||||||
|
streamPositionAfter = types.StreamPosition(12)
|
||||||
|
streamPositionAfter2 = types.StreamPosition(13)
|
||||||
|
roomID = "!test:localhost"
|
||||||
|
alice = "@alice:localhost"
|
||||||
|
bob = "@bob:localhost"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
randomMessageEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{
|
||||||
|
"type": "m.room.message",
|
||||||
|
"content": {
|
||||||
|
"body": "Hello World",
|
||||||
|
"msgtype": "m.text"
|
||||||
|
},
|
||||||
|
"sender": "@noone:localhost",
|
||||||
|
"room_id": "`+roomID+`",
|
||||||
|
"origin_server_ts": 12345,
|
||||||
|
"event_id": "$randomMessageEvent:localhost"
|
||||||
|
}`), false)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
aliceInviteBobEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{
|
||||||
|
"type": "m.room.member",
|
||||||
|
"state_key": "`+bob+`",
|
||||||
|
"content": {
|
||||||
|
"membership": "invite"
|
||||||
|
},
|
||||||
|
"sender": "`+alice+`",
|
||||||
|
"room_id": "`+roomID+`",
|
||||||
|
"origin_server_ts": 12345,
|
||||||
|
"event_id": "$aliceInviteBobEvent:localhost"
|
||||||
|
}`), false)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
bobLeaveEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{
|
||||||
|
"type": "m.room.member",
|
||||||
|
"state_key": "`+bob+`",
|
||||||
|
"content": {
|
||||||
|
"membership": "leave"
|
||||||
|
},
|
||||||
|
"sender": "`+bob+`",
|
||||||
|
"room_id": "`+roomID+`",
|
||||||
|
"origin_server_ts": 12345,
|
||||||
|
"event_id": "$bobLeaveEvent:localhost"
|
||||||
|
}`), false)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that the current position is returned if a request is already behind.
|
||||||
|
func TestImmediateNotification(t *testing.T) {
|
||||||
|
n := NewNotifier(streamPositionBefore)
|
||||||
|
pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionVeryOld))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TestImmediateNotification error: %s", err)
|
||||||
|
}
|
||||||
|
if pos != streamPositionBefore {
|
||||||
|
t.Fatalf("TestImmediateNotification want %d, got %d", streamPositionBefore, pos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that new events to a joined room unblocks the request.
|
||||||
|
func TestNewEventAndJoinedToRoom(t *testing.T) {
|
||||||
|
n := NewNotifier(streamPositionBefore)
|
||||||
|
n.setUsersJoinedToRooms(map[string][]string{
|
||||||
|
roomID: []string{alice, bob},
|
||||||
|
})
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("TestNewEventAndJoinedToRoom error: %s", err)
|
||||||
|
}
|
||||||
|
if pos != streamPositionAfter {
|
||||||
|
t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", streamPositionAfter, pos)
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
stream := n.fetchUserStream(bob, true)
|
||||||
|
waitForBlocking(stream, 1)
|
||||||
|
|
||||||
|
n.OnNewEvent(&randomMessageEvent, streamPositionAfter)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that an invite unblocks the request
|
||||||
|
func TestNewInviteEventForUser(t *testing.T) {
|
||||||
|
n := NewNotifier(streamPositionBefore)
|
||||||
|
n.setUsersJoinedToRooms(map[string][]string{
|
||||||
|
roomID: []string{alice, bob},
|
||||||
|
})
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("TestNewInviteEventForUser error: %s", err)
|
||||||
|
}
|
||||||
|
if pos != streamPositionAfter {
|
||||||
|
t.Errorf("TestNewInviteEventForUser want %d, got %d", streamPositionAfter, pos)
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
stream := n.fetchUserStream(bob, true)
|
||||||
|
waitForBlocking(stream, 1)
|
||||||
|
|
||||||
|
n.OnNewEvent(&aliceInviteBobEvent, streamPositionAfter)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that all blocked requests get woken up on a new event.
|
||||||
|
func TestMultipleRequestWakeup(t *testing.T) {
|
||||||
|
n := NewNotifier(streamPositionBefore)
|
||||||
|
n.setUsersJoinedToRooms(map[string][]string{
|
||||||
|
roomID: []string{alice, bob},
|
||||||
|
})
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(3)
|
||||||
|
poll := func() {
|
||||||
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("TestMultipleRequestWakeup error: %s", err)
|
||||||
|
}
|
||||||
|
if pos != streamPositionAfter {
|
||||||
|
t.Errorf("TestMultipleRequestWakeup want %d, got %d", streamPositionAfter, pos)
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
go poll()
|
||||||
|
go poll()
|
||||||
|
go poll()
|
||||||
|
|
||||||
|
stream := n.fetchUserStream(bob, true)
|
||||||
|
waitForBlocking(stream, 3)
|
||||||
|
|
||||||
|
n.OnNewEvent(&randomMessageEvent, streamPositionAfter)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
numWaiting := stream.NumWaiting()
|
||||||
|
if numWaiting != 0 {
|
||||||
|
t.Errorf("TestMultipleRequestWakeup NumWaiting() want 0, got %d", numWaiting)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that you stop getting woken up when you leave a room.
|
||||||
|
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
|
||||||
|
// listen as bob. Make bob leave room. Make alice send event to room.
|
||||||
|
// Make sure alice gets woken up only and not bob as well.
|
||||||
|
n := NewNotifier(streamPositionBefore)
|
||||||
|
n.setUsersJoinedToRooms(map[string][]string{
|
||||||
|
roomID: []string{alice, bob},
|
||||||
|
})
|
||||||
|
|
||||||
|
var leaveWG sync.WaitGroup
|
||||||
|
|
||||||
|
// Make bob leave the room
|
||||||
|
leaveWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
|
||||||
|
}
|
||||||
|
if pos != streamPositionAfter {
|
||||||
|
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter, pos)
|
||||||
|
}
|
||||||
|
leaveWG.Done()
|
||||||
|
}()
|
||||||
|
bobStream := n.fetchUserStream(bob, true)
|
||||||
|
waitForBlocking(bobStream, 1)
|
||||||
|
n.OnNewEvent(&bobLeaveEvent, streamPositionAfter)
|
||||||
|
leaveWG.Wait()
|
||||||
|
|
||||||
|
// send an event into the room. Make sure alice gets it. Bob should not.
|
||||||
|
var aliceWG sync.WaitGroup
|
||||||
|
aliceStream := n.fetchUserStream(alice, true)
|
||||||
|
aliceWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionAfter))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
|
||||||
|
}
|
||||||
|
if pos != streamPositionAfter2 {
|
||||||
|
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter2, pos)
|
||||||
|
}
|
||||||
|
aliceWG.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// this should timeout with an error (but the main goroutine won't wait for the timeout explicitly)
|
||||||
|
_, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionAfter))
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitForBlocking(aliceStream, 1)
|
||||||
|
waitForBlocking(bobStream, 1)
|
||||||
|
|
||||||
|
n.OnNewEvent(&randomMessageEvent, streamPositionAfter2)
|
||||||
|
aliceWG.Wait()
|
||||||
|
|
||||||
|
// it's possible that at this point alice has been informed and bob is about to be informed, so wait
|
||||||
|
// for a fraction of a second to account for this race
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// same as Notifier.WaitForEvents but with a timeout.
|
||||||
|
func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) {
|
||||||
|
done := make(chan types.StreamPosition, 1)
|
||||||
|
go func() {
|
||||||
|
newPos := n.WaitForEvents(req)
|
||||||
|
done <- newPos
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
return types.StreamPosition(0), fmt.Errorf(
|
||||||
|
"waitForEvents timed out waiting for %s (pos=%d)", req.userID, req.since,
|
||||||
|
)
|
||||||
|
case p := <-done:
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait until something is Wait()ing on the user stream.
|
||||||
|
func waitForBlocking(s *UserStream, numBlocking int) {
|
||||||
|
for numBlocking != s.NumWaiting() {
|
||||||
|
// This is horrible but I don't want to add a signalling mechanism JUST for testing.
|
||||||
|
time.Sleep(1 * time.Microsecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest {
|
||||||
|
return syncRequest{
|
||||||
|
userID: userID,
|
||||||
|
timeout: 1 * time.Minute,
|
||||||
|
since: since,
|
||||||
|
wantFullState: false,
|
||||||
|
limit: defaultTimelineLimit,
|
||||||
|
log: util.GetLogger(context.TODO()),
|
||||||
|
}
|
||||||
|
}
|
|
@ -15,10 +15,13 @@
|
||||||
package sync
|
package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultSyncTimeout = time.Duration(30) * time.Second
|
const defaultSyncTimeout = time.Duration(30) * time.Second
|
||||||
|
@ -31,6 +34,7 @@ type syncRequest struct {
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
since types.StreamPosition
|
since types.StreamPosition
|
||||||
wantFullState bool
|
wantFullState bool
|
||||||
|
log *log.Entry
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
|
func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
|
||||||
|
@ -48,6 +52,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
|
||||||
since: since,
|
since: since,
|
||||||
wantFullState: wantFullState,
|
wantFullState: wantFullState,
|
||||||
limit: defaultTimelineLimit, // TODO: read from filter
|
limit: defaultTimelineLimit, // TODO: read from filter
|
||||||
|
log: util.GetLogger(req.Context()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserStream represents a communication mechanism between the /sync request goroutine
|
||||||
|
// and the underlying sync server goroutines. Goroutines can Wait() for a stream position and
|
||||||
|
// goroutines can Broadcast(streamPosition) to other goroutines.
|
||||||
|
type UserStream struct {
|
||||||
|
UserID string
|
||||||
|
// Because this is a Cond, we can notify all waiting goroutines so this works
|
||||||
|
// across devices for the same user. Protects pos.
|
||||||
|
cond *sync.Cond
|
||||||
|
// The position to broadcast to callers of Wait().
|
||||||
|
pos types.StreamPosition
|
||||||
|
// The number of goroutines blocked on Wait() - used for testing and metrics
|
||||||
|
numWaiting int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserStream creates a new user stream
|
||||||
|
func NewUserStream(userID string) *UserStream {
|
||||||
|
return &UserStream{
|
||||||
|
UserID: userID,
|
||||||
|
cond: sync.NewCond(&sync.Mutex{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait blocks until there is a new stream position for this user, which is then returned.
|
||||||
|
// waitAtPos should be the position the stream thinks it should be waiting at.
|
||||||
|
func (s *UserStream) Wait(waitAtPos types.StreamPosition) (pos types.StreamPosition) {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
// Before we start blocking, we need to make sure that we didn't race with a call
|
||||||
|
// to Broadcast() between calling Wait() and actually sleeping. We check the last
|
||||||
|
// broadcast pos to see if it is newer than the pos we are meant to wait at. If it
|
||||||
|
// is newer, something has Broadcast to this stream more recently so return immediately.
|
||||||
|
if s.pos > waitAtPos {
|
||||||
|
pos = s.pos
|
||||||
|
s.cond.L.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.numWaiting++
|
||||||
|
s.cond.Wait()
|
||||||
|
pos = s.pos
|
||||||
|
s.numWaiting--
|
||||||
|
s.cond.L.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast a new stream position for this user.
|
||||||
|
func (s *UserStream) Broadcast(pos types.StreamPosition) {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
s.pos = pos
|
||||||
|
s.cond.L.Unlock()
|
||||||
|
s.cond.Broadcast()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumWaiting returns the number of goroutines waiting for Wait() to return. Used for metrics and testing.
|
||||||
|
func (s *UserStream) NumWaiting() int {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
return s.numWaiting
|
||||||
|
}
|
Loading…
Reference in a new issue