diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index b6168d38b..ba10a4332 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -13,6 +13,9 @@ package auth import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) @@ -22,6 +25,7 @@ import ( // IsServerAllowed returns true if the server is allowed to see events in the room // at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87 func IsServerAllowed( + ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, serverCurrentlyInRoom bool, authEvents []gomatrixserverlib.PDU, @@ -37,7 +41,7 @@ func IsServerAllowed( return true } // 2. If the user's membership was join, allow. - joinedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Join) + joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join) if joinedUserExists { return true } @@ -46,7 +50,7 @@ func IsServerAllowed( return true } // 4. If the user's membership was invite, and the history_visibility was set to invited, allow. - invitedUserExists := IsAnyUserOnServerWithMembership(serverName, authEvents, spec.Invite) + invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite) if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited { return true } @@ -70,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver return visibility } -func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { +func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool { for _, ev := range authEvents { if ev.Type() != spec.MRoomMember { continue @@ -85,12 +89,12 @@ func IsAnyUserOnServerWithMembership(serverName spec.ServerName, authEvents []go continue } - _, domain, err := gomatrixserverlib.SplitID('@', *stateKey) + userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey)) if err != nil { continue } - if domain == serverName { + if userID.Domain() == serverName { return true } } diff --git a/roomserver/auth/auth_test.go b/roomserver/auth/auth_test.go index e3eea5d8b..192d9e5da 100644 --- a/roomserver/auth/auth_test.go +++ b/roomserver/auth/auth_test.go @@ -1,13 +1,23 @@ package auth import ( + "context" "testing" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" ) +type FakeStorageDB struct { + storage.RoomDatabase +} + +func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) +} + func TestIsServerAllowed(t *testing.T) { alice := test.NewUser(t) @@ -77,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) { authEvents = append(authEvents, ev.PDU) } - if got := IsServerAllowed(tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { + if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want { t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want) } }) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index eee8e1a55..d263353e3 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -95,7 +95,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam for i := range events { gmslEvents[i] = events[i].PDU } - return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, spec.Join), nil + return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil } func IsInvitePending( @@ -289,7 +289,7 @@ func CheckServerAllowedToSeeEvent( return false, err } } - return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil + return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil } func slowGetHistoryVisibilityState( diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 388150936..8e87359a3 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -582,7 +582,7 @@ func joinEventsFromHistoryVisibility( } // Can we see events in the room? - canSeeEvents := auth.IsServerAllowed(thisServer, true, events) + canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events) visibility := auth.HistoryVisibilityForRoom(events) if !canSeeEvents { logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)