diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index 017682e05..4950e6231 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -32,8 +32,8 @@ import ( const MRoomServerACL = "m.room.server_acl" type ServerACLDatabase interface { - // GetKnownRooms returns a list of all rooms we know about. - GetKnownRooms(ctx context.Context) ([]string, error) + // RoomsWithACLs returns all room IDs for rooms with ACLs + RoomsWithACLs(ctx context.Context) ([]string, error) // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. @@ -57,7 +57,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { } // Look up all of the rooms that the current state server knows about. - rooms, err := db.GetKnownRooms(ctx) + rooms, err := db.RoomsWithACLs(ctx) if err != nil { logrus.WithError(err).Fatalf("Failed to get known rooms") } diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go index efe1d2093..09920308c 100644 --- a/roomserver/acls/acls_test.go +++ b/roomserver/acls/acls_test.go @@ -116,7 +116,7 @@ var ( type dummyACLDB struct{} -func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) { +func (d dummyACLDB) RoomsWithACLs(ctx context.Context) ([]string, error) { return []string{"1", "2"}, nil } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 88e335711..85312efd9 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -1284,3 +1284,38 @@ func TestRoomConsumerRecreation(t *testing.T) { wantAckWait := input.MaximumMissingProcessingTime + (time.Second * 10) assert.Equal(t, wantAckWait, info.Config.AckWait) } + +func TestRoomsWithACLs(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + noACLRoom := test.NewRoom(t, alice) + aclRoom := test.NewRoom(t, alice) + + aclRoom.CreateAndInsert(t, alice, "m.room.server_acl", map[string]any{ + "deny": []string{"evilhost.test"}, + "allow": []string{"*"}, + }, test.WithStateKey("")) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + // start JetStream listeners + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + + for _, room := range []*test.Room{noACLRoom, aclRoom} { + // Create the rooms + err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false) + assert.NoError(t, err) + } + + // Validate that we only have one ACLd room. + roomsWithACLs, err := rsAPI.RoomsWithACLs(ctx) + assert.NoError(t, err) + assert.Equal(t, []string{aclRoom.ID}, roomsWithACLs) + }) +}