diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index a93368701..e9a57d8da 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -132,12 +132,18 @@ type authEvents struct { events EventMap } +// Valid verifies that all auth events are from the same room. func (ae *authEvents) Valid() bool { - roomIDs := make(map[string]struct{}) - for _, ev := range ae.events { - roomIDs[ev.RoomID()] = struct{}{} + roomID := "" + for i := range ae.events { + if i == 0 { + roomID = ae.events[i].RoomID() + } + if roomID != ae.events[i].RoomID() { + return false + } } - return len(roomIDs) <= 1 + return true } // Create implements gomatrixserverlib.AuthEventProvider diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 60119b2ed..818e7715c 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/dendrite/test" ) -func Test_1(t *testing.T) { +func Test_EventAuth(t *testing.T) { alice := test.NewUser(t) bob := test.NewUser(t) @@ -35,7 +35,7 @@ func Test_1(t *testing.T) { } } - // Add the illegal auth event from room1 + // Add the illegal auth event from room1 (rooms are different) for _, x := range room1.Events() { if x.Type() == gomatrixserverlib.MRoomMember { authEventIDs = append(authEventIDs, x.EventID()) @@ -43,7 +43,7 @@ func Test_1(t *testing.T) { } } - // Craft the illegal join event + // Craft the illegal join event, with auth events from different rooms ev := room2.CreateEvent(t, bob, "m.room.member", map[string]interface{}{ "membership": "join", }, test.WithStateKey(bob.ID), test.WithAuthIDs(authEventIDs))