Export SharedUsers/SharedUsers

This commit is contained in:
S7evinK 2022-04-01 20:26:40 +02:00
parent f2c82aaf74
commit 21d1ac8610
5 changed files with 66 additions and 46 deletions

View file

@ -48,9 +48,8 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position. // NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and // In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier(currPos types.StreamingToken) *Notifier { func NewNotifier() *Notifier {
return &Notifier{ return &Notifier{
currPos: currPos,
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream), userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
@ -59,6 +58,12 @@ func NewNotifier(currPos types.StreamingToken) *Notifier {
} }
} }
// SetCurrentPosition sets the current streaming positions.
// This must be called directly after NewNotifier and initialising the streams.
func (n *Notifier) SetCurrentPosition(currPos types.StreamingToken) {
n.currPos = currPos
}
// OnNewEvent is called when a new event is received from the room server. Must only be // OnNewEvent is called when a new event is received from the room server. Must only be
// called from a single goroutine, to avoid races between updates which could set the // called from a single goroutine, to avoid races between updates which could set the
// current sync position incorrectly. // current sync position incorrectly.
@ -83,7 +88,7 @@ func (n *Notifier) OnNewEvent(
if ev != nil { if ev != nil {
// Map this event's room_id to a list of joined users, and wake them up. // Map this event's room_id to a list of joined users, and wake them up.
usersToNotify := n.joinedUsers(ev.RoomID()) usersToNotify := n.JoinedUsers(ev.RoomID())
// Map this event's room_id to a list of peeking devices, and wake them up. // Map this event's room_id to a list of peeking devices, and wake them up.
peekingDevicesToNotify := n.PeekingDevices(ev.RoomID()) peekingDevicesToNotify := n.PeekingDevices(ev.RoomID())
// If this is an invite, also add in the invitee to this list. // If this is an invite, also add in the invitee to this list.
@ -114,7 +119,7 @@ func (n *Notifier) OnNewEvent(
n.wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos) n.wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos)
} else if roomID != "" { } else if roomID != "" {
n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), n.currPos) n.wakeupUsers(n.JoinedUsers(roomID), n.PeekingDevices(roomID), n.currPos)
} else if len(userIDs) > 0 { } else if len(userIDs) > 0 {
n.wakeupUsers(userIDs, nil, n.currPos) n.wakeupUsers(userIDs, nil, n.currPos)
} else { } else {
@ -182,7 +187,7 @@ func (n *Notifier) OnNewTyping(
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate) n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos) n.wakeupUsers(n.JoinedUsers(roomID), nil, n.currPos)
} }
// OnNewReceipt updates the current position // OnNewReceipt updates the current position
@ -194,7 +199,7 @@ func (n *Notifier) OnNewReceipt(
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate) n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos) n.wakeupUsers(n.JoinedUsers(roomID), nil, n.currPos)
} }
func (n *Notifier) OnNewKeyChange( func (n *Notifier) OnNewKeyChange(
@ -235,16 +240,16 @@ func (n *Notifier) OnNewPresence(
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate) n.currPos.ApplyUpdates(posUpdate)
sharedUsers := n.sharedUsers(userID) sharedUsers := n.SharedUsers(userID)
sharedUsers = append(sharedUsers, userID) sharedUsers = append(sharedUsers, userID)
n.wakeupUsers(sharedUsers, nil, n.currPos) n.wakeupUsers(sharedUsers, nil, n.currPos)
} }
func (n *Notifier) sharedUsers(userID string) (sharedUsers []string) { func (n *Notifier) SharedUsers(userID string) (sharedUsers []string) {
for roomID, users := range n.roomIDToJoinedUsers { for roomID, users := range n.roomIDToJoinedUsers {
if _, ok := users[userID]; ok { if _, ok := users[userID]; ok {
sharedUsers = append(sharedUsers, n.joinedUsers(roomID)...) sharedUsers = append(sharedUsers, n.JoinedUsers(roomID)...)
} }
} }
return sharedUsers return sharedUsers
@ -414,7 +419,7 @@ func (n *Notifier) removeJoinedUser(roomID, userID string) {
} }
// Not thread-safe: must be called on the OnNewEvent goroutine only // Not thread-safe: must be called on the OnNewEvent goroutine only
func (n *Notifier) joinedUsers(roomID string) (userIDs []string) { func (n *Notifier) JoinedUsers(roomID string) (userIDs []string) {
if _, ok := n.roomIDToJoinedUsers[roomID]; !ok { if _, ok := n.roomIDToJoinedUsers[roomID]; !ok {
return return
} }

View file

@ -107,7 +107,8 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
// Test that the current position is returned if a request is already behind. // Test that the current position is returned if a request is already behind.
func TestImmediateNotification(t *testing.T) { func TestImmediateNotification(t *testing.T) {
n := NewNotifier(syncPositionBefore) n := NewNotifier()
n.SetCurrentPosition(syncPositionBefore)
pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld))
if err != nil { if err != nil {
t.Fatalf("TestImmediateNotification error: %s", err) t.Fatalf("TestImmediateNotification error: %s", err)
@ -117,7 +118,8 @@ func TestImmediateNotification(t *testing.T) {
// Test that new events to a joined room unblocks the request. // Test that new events to a joined room unblocks the request.
func TestNewEventAndJoinedToRoom(t *testing.T) { func TestNewEventAndJoinedToRoom(t *testing.T) {
n := NewNotifier(syncPositionBefore) n := NewNotifier()
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -142,7 +144,8 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
} }
func TestCorrectStream(t *testing.T) { func TestCorrectStream(t *testing.T) {
n := NewNotifier(syncPositionBefore) n := NewNotifier()
n.SetCurrentPosition(syncPositionBefore)
stream := lockedFetchUserStream(n, bob, bobDev) stream := lockedFetchUserStream(n, bob, bobDev)
if stream.UserID != bob { if stream.UserID != bob {
t.Fatalf("expected user %q, got %q", bob, stream.UserID) t.Fatalf("expected user %q, got %q", bob, stream.UserID)
@ -153,7 +156,8 @@ func TestCorrectStream(t *testing.T) {
} }
func TestCorrectStreamWakeup(t *testing.T) { func TestCorrectStreamWakeup(t *testing.T) {
n := NewNotifier(syncPositionBefore) n := NewNotifier()
n.SetCurrentPosition(syncPositionBefore)
awoken := make(chan string) awoken := make(chan string)
streamone := lockedFetchUserStream(n, alice, "one") streamone := lockedFetchUserStream(n, alice, "one")
@ -180,7 +184,8 @@ func TestCorrectStreamWakeup(t *testing.T) {
// Test that an invite unblocks the request // Test that an invite unblocks the request
func TestNewInviteEventForUser(t *testing.T) { func TestNewInviteEventForUser(t *testing.T) {
n := NewNotifier(syncPositionBefore) n := NewNotifier()
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -236,7 +241,8 @@ func TestEDUWakeup(t *testing.T) {
// Test that all blocked requests get woken up on a new event. // Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) { func TestMultipleRequestWakeup(t *testing.T) {
n := NewNotifier(syncPositionBefore) n := NewNotifier()
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })
@ -272,7 +278,8 @@ func TestMultipleRequestWakeup(t *testing.T) {
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// listen as bob. Make bob leave room. Make alice send event to room. // listen as bob. Make bob leave room. Make alice send event to room.
// Make sure alice gets woken up only and not bob as well. // Make sure alice gets woken up only and not bob as well.
n := NewNotifier(syncPositionBefore) n := NewNotifier()
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{ n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob}, roomID: {alice, bob},
}) })

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"sync" "sync"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -27,7 +28,8 @@ import (
type PresenceStreamProvider struct { type PresenceStreamProvider struct {
StreamProvider StreamProvider
// cache contains previously sent presence updates to avoid unneeded updates // cache contains previously sent presence updates to avoid unneeded updates
cache sync.Map cache sync.Map
notifier *notifier.Notifier
} }
func (p *PresenceStreamProvider) Setup() { func (p *PresenceStreamProvider) Setup() {
@ -63,39 +65,42 @@ func (p *PresenceStreamProvider) IncrementalSync(
} }
// get all joined users // get all joined users
rooms, err := p.DB.AllJoinedUsersInRooms(ctx) // TODO: SharedUsers might get out of syncf
if err != nil { sharedUsers := p.notifier.SharedUsers(req.Device.UserID)
req.Log.WithError(err).Error("unable to query joined users")
return from
}
sharedUsers := map[string]bool{ sharedUsersMap := map[string]bool{
req.Device.UserID: true, req.Device.UserID: true,
} }
for roomID := range req.Rooms { // convert array to a map for easier checking if a user exists
roomUsers := rooms[roomID] for i := range sharedUsers {
for i := range roomUsers { sharedUsersMap[sharedUsers[i]] = true
sharedUsers[roomUsers[i]] = true
}
} }
// add newly joined rooms user presences // add newly joined rooms user presences
newlyJoined := joinedRooms(req.Response, req.Device.UserID) newlyJoined := joinedRooms(req.Response, req.Device.UserID)
for _, roomID := range newlyJoined { if len(newlyJoined) > 0 {
roomUsers := rooms[roomID] // TODO: This refreshes all lists and is quite expensive
for i := range roomUsers { // The notifier should update the lists itself
sharedUsers[roomUsers[i]] = true if err = p.notifier.Load(ctx, p.DB); err != nil {
// we already got a presence from this user req.Log.WithError(err).Error("unable to refresh notifier lists")
if _, ok := presences[roomUsers[i]]; ok { return from
continue }
} for _, roomID := range newlyJoined {
presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i]) roomUsers := p.notifier.JoinedUsers(roomID)
if err != nil { for i := range roomUsers {
if err == sql.ErrNoRows { sharedUsersMap[roomUsers[i]] = true
// we already got a presence from this user
if _, ok := presences[roomUsers[i]]; ok {
continue continue
} }
req.Log.WithError(err).Error("unable to query presence for user") presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i])
return from if err != nil {
if err == sql.ErrNoRows {
continue
}
req.Log.WithError(err).Error("unable to query presence for user")
return from
}
} }
} }
} }
@ -104,7 +109,7 @@ func (p *PresenceStreamProvider) IncrementalSync(
for i := range presences { for i := range presences {
presence := presences[i] presence := presences[i]
// Ignore users we don't share a room with // Ignore users we don't share a room with
if !sharedUsers[presence.UserID] { if !sharedUsersMap[presence.UserID] {
continue continue
} }
cacheKey := req.Device.UserID + req.Device.ID + presence.UserID cacheKey := req.Device.UserID + req.Device.ID + presence.UserID

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -26,7 +27,7 @@ type Streams struct {
func NewSyncStreamProviders( func NewSyncStreamProviders(
d storage.Database, userAPI userapi.UserInternalAPI, d storage.Database, userAPI userapi.UserInternalAPI,
rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI, rsAPI rsapi.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI,
eduCache *caching.EDUCache, eduCache *caching.EDUCache, notifier *notifier.Notifier,
) *Streams { ) *Streams {
streams := &Streams{ streams := &Streams{
PDUStreamProvider: &PDUStreamProvider{ PDUStreamProvider: &PDUStreamProvider{
@ -59,6 +60,7 @@ func NewSyncStreamProviders(
}, },
PresenceStreamProvider: &PresenceStreamProvider{ PresenceStreamProvider: &PresenceStreamProvider{
StreamProvider: StreamProvider{DB: d}, StreamProvider: StreamProvider{DB: d},
notifier: notifier,
}, },
} }

View file

@ -57,8 +57,9 @@ func AddPublicRoutes(
} }
eduCache := caching.NewTypingCache() eduCache := caching.NewTypingCache()
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache) notifier := notifier.NewNotifier()
notifier := notifier.NewNotifier(streams.Latest(context.Background())) streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil { if err = notifier.Load(context.Background(), syncDB); err != nil {
logrus.WithError(err).Panicf("failed to load notifier ") logrus.WithError(err).Panicf("failed to load notifier ")
} }