diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 92636198e..b62907f3c 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -258,8 +258,9 @@ type QueryRoomVersionForRoomResponse struct { type QueryPublishedRoomsRequest struct { // Optional. If specified, returns whether this room is published or not. - RoomID string - NetworkdID string + RoomID string + NetworkID string + IncludeAllNetworks bool } type QueryPublishedRoomsResponse struct { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 35a25a379..0db046a86 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -702,7 +702,7 @@ func (r *Queryer) QueryPublishedRooms( } return err } - rooms, err := r.DB.GetPublishedRooms(ctx, req.NetworkdID) + rooms, err := r.DB.GetPublishedRooms(ctx, req.NetworkID, req.IncludeAllNetworks) if err != nil { return err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 4b9c990ae..094537948 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -141,7 +141,7 @@ type Database interface { // Publish or unpublish a room from the room directory. PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error // Returns a list of room IDs for rooms which are published. - GetPublishedRooms(ctx context.Context, networkID string) ([]string, error) + GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error) // Returns whether a given room is published or not. GetPublishedRoom(ctx context.Context, roomID string) (bool, error) diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index d9e02b611..61caccb0e 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -44,7 +44,7 @@ const upsertPublishedSQL = "" + "ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published=$4" const selectAllPublishedSQL = "" + - "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" + "SELECT room_id FROM roomserver_published WHERE published = $1 AND CASE WHEN $2 THEN 1=1 ELSE network_id = '' END ORDER BY room_id ASC" const selectNetworkPublishedSQL = "" + "SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC" @@ -103,7 +103,7 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, txn *sql.Tx, networkID string, published bool, + ctx context.Context, txn *sql.Tx, networkID string, published, includeAllNetworks bool, ) ([]string, error) { var rows *sql.Rows var err error @@ -112,7 +112,7 @@ func (s *publishedStatements) SelectAllPublishedRooms( rows, err = stmt.QueryContext(ctx, published, networkID) } else { stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) - rows, err = stmt.QueryContext(ctx, published) + rows, err = stmt.QueryContext(ctx, published, includeAllNetworks) } if err != nil { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index b9047204b..ed86280bf 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -732,8 +732,8 @@ func (d *Database) GetPublishedRoom(ctx context.Context, roomID string) (bool, e return d.PublishedTable.SelectPublishedFromRoomID(ctx, nil, roomID) } -func (d *Database) GetPublishedRooms(ctx context.Context, networkID string) ([]string, error) { - return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, networkID, true) +func (d *Database) GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error) { + return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, networkID, true, includeAllNetworks) } func (d *Database) MissingAuthPrevEvents( diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index bbeb96bb9..34666552e 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -44,7 +44,7 @@ const upsertPublishedSQL = "" + " ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published = $4" const selectAllPublishedSQL = "" + - "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" + "SELECT room_id FROM roomserver_published WHERE published = $1 AND CASE WHEN $2 THEN 1=1 ELSE network_id = '' END ORDER BY room_id ASC" const selectNetworkPublishedSQL = "" + "SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC" @@ -106,7 +106,7 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, txn *sql.Tx, networkID string, published bool, + ctx context.Context, txn *sql.Tx, networkID string, published, includeAllNetworks bool, ) ([]string, error) { var rows *sql.Rows var err error @@ -115,7 +115,7 @@ func (s *publishedStatements) SelectAllPublishedRooms( rows, err = stmt.QueryContext(ctx, published, networkID) } else { stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) - rows, err = stmt.QueryContext(ctx, published) + rows, err = stmt.QueryContext(ctx, published, includeAllNetworks) } if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 203a4dcf9..8d6ca324c 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -148,7 +148,7 @@ type Membership interface { type Published interface { UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool) (err error) SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error) - SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, networkdID string, published bool) ([]string, error) + SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, networkdID string, published, includeAllNetworks bool) ([]string, error) } type RedactionInfo struct { diff --git a/roomserver/storage/tables/published_table_test.go b/roomserver/storage/tables/published_table_test.go index 4211249dc..e6289e9b1 100644 --- a/roomserver/storage/tables/published_table_test.go +++ b/roomserver/storage/tables/published_table_test.go @@ -65,7 +65,7 @@ func TestPublishedTable(t *testing.T) { sort.Strings(publishedRooms) // check that we get the expected published rooms - roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, "", true) + roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, "", true, true) assert.NoError(t, err) assert.Equal(t, publishedRooms, roomIDs) @@ -88,12 +88,12 @@ func TestPublishedTable(t *testing.T) { publishedRooms = append(publishedRooms, room.ID) sort.Strings(publishedRooms) // should only return the room for network "irc" - allNWPublished, err := tab.SelectAllPublishedRooms(ctx, nil, nwID, true) + allNWPublished, err := tab.SelectAllPublishedRooms(ctx, nil, nwID, true, true) assert.NoError(t, err) assert.Equal(t, []string{room.ID}, allNWPublished) // check that we still get all published rooms regardless networkID - roomIDs, err = tab.SelectAllPublishedRooms(ctx, nil, "", true) + roomIDs, err = tab.SelectAllPublishedRooms(ctx, nil, "", true, true) assert.NoError(t, err) assert.Equal(t, publishedRooms, roomIDs) })