diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index f76456859..4ee7c8605 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" @@ -36,7 +37,8 @@ import ( // the event, but the token has already advanced by the time they fetch it, resulting // in missed events. type Notifier struct { - lock *sync.RWMutex + lock *sync.RWMutex + rsAPI api.SyncRoomserverAPI // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine roomIDToJoinedUsers map[string]*userIDSet // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine @@ -55,8 +57,9 @@ type Notifier struct { // NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier() *Notifier { +func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier { return &Notifier{ + rsAPI: rsAPI, roomIDToJoinedUsers: make(map[string]*userIDSet), roomIDToPeekingDevices: make(map[string]peekingDeviceSet), userDeviceStreams: make(map[string]map[string]*UserDeviceStream), @@ -104,26 +107,32 @@ func (n *Notifier) OnNewEvent( peekingDevicesToNotify := n._peekingDevices(ev.RoomID()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { - targetUserID := *ev.StateKey() - membership, err := ev.Membership() + targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey())) if err != nil { log.WithError(err).WithField("event_id", ev.EventID()).Errorf( - "Notifier.OnNewEvent: Failed to unmarshal member event", + "Notifier.OnNewEvent: Failed to find the userID for this event", ) } else { - // Keep the joined user map up-to-date - switch membership { - case spec.Invite: - usersToNotify = append(usersToNotify, targetUserID) - case spec.Join: - // Manually append the new user's ID so they get notified - // along all members in the room - usersToNotify = append(usersToNotify, targetUserID) - n._addJoinedUser(ev.RoomID(), targetUserID) - case spec.Leave: - fallthrough - case spec.Ban: - n._removeJoinedUser(ev.RoomID(), targetUserID) + membership, err := ev.Membership() + if err != nil { + log.WithError(err).WithField("event_id", ev.EventID()).Errorf( + "Notifier.OnNewEvent: Failed to unmarshal member event", + ) + } else { + // Keep the joined user map up-to-date + switch membership { + case spec.Invite: + usersToNotify = append(usersToNotify, targetUserID.String()) + case spec.Join: + // Manually append the new user's ID so they get notified + // along all members in the room + usersToNotify = append(usersToNotify, targetUserID.String()) + n._addJoinedUser(ev.RoomID(), targetUserID.String()) + case spec.Leave: + fallthrough + case spec.Ban: + n._removeJoinedUser(ev.RoomID(), targetUserID.String()) + } } } } diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go index 36577a0ee..7076f7134 100644 --- a/syncapi/notifier/notifier_test.go +++ b/syncapi/notifier/notifier_test.go @@ -22,9 +22,11 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" ) @@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) { } } +type TestRoomServer struct{ api.SyncRoomserverAPI } + +func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld)) if err != nil { @@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) { // Test that new events to a joined room unblocks the request. func TestNewEventAndJoinedToRoom(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { } func TestCorrectStream(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) stream := lockedFetchUserStream(n, bob, bobDev) if stream.UserID != bob { @@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) { } func TestCorrectStreamWakeup(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) awoken := make(chan string) @@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) { // Test that an invite unblocks the request func TestNewInviteEventForUser(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) { // Test that all blocked requests get woken up on a new event. func TestMultipleRequestWakeup(t *testing.T) { - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, @@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // listen as bob. Make bob leave room. Make alice send event to room. // Make sure alice gets woken up only and not bob as well. - n := NewNotifier() + n := NewNotifier(&TestRoomServer{}) n.SetCurrentPosition(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index ecbe05dd8..64a4af757 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -60,7 +60,7 @@ func AddPublicRoutes( } eduCache := caching.NewTypingCache() - notifier := notifier.NewNotifier() + notifier := notifier.NewNotifier(rsAPI) streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier) notifier.SetCurrentPosition(streams.Latest(context.Background())) if err = notifier.Load(context.Background(), syncDB); err != nil {