diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 4ebf2295a..6c21af80d 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -39,14 +39,17 @@ var ( ) type PublicRoomReq struct { - Since string `json:"since,omitempty"` - Limit int16 `json:"limit,omitempty"` - Filter filter `json:"filter,omitempty"` - Server string `json:"server,omitempty"` + Since string `json:"since,omitempty"` + Limit int16 `json:"limit,omitempty"` + Filter filter `json:"filter,omitempty"` + Server string `json:"server,omitempty"` + IncludeAllNetworks bool `json:"include_all_networks,omitempty"` + NetworkID string `json:"third_party_instance_id,omitempty"` } type filter struct { - SearchTerms string `json:"generic_search_term,omitempty"` + SearchTerms string `json:"generic_search_term,omitempty"` + RoomTypes []string `json:"room_types,omitempty"` // TODO: Implement filter on this } // GetPostPublicRooms implements GET and POST /publicRooms @@ -61,6 +64,13 @@ func GetPostPublicRooms( return *fillErr } + if request.IncludeAllNetworks && request.NetworkID != "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidParam("include_all_networks and third_party_instance_id can not be used together"), + } + } + serverName := gomatrixserverlib.ServerName(request.Server) if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( diff --git a/federationapi/routing/publicrooms.go b/federationapi/routing/publicrooms.go index 1a54f5a7d..b4c4717eb 100644 --- a/federationapi/routing/publicrooms.go +++ b/federationapi/routing/publicrooms.go @@ -2,24 +2,29 @@ package routing import ( "context" + "fmt" "net/http" "strconv" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) type PublicRoomReq struct { - Since string `json:"since,omitempty"` - Limit int16 `json:"limit,omitempty"` - Filter filter `json:"filter,omitempty"` + Since string `json:"since,omitempty"` + Limit int16 `json:"limit,omitempty"` + Filter filter `json:"filter,omitempty"` + IncludeAllNetworks bool `json:"include_all_networks,omitempty"` + NetworkID string `json:"third_party_instance_id,omitempty"` } type filter struct { - SearchTerms string `json:"generic_search_term,omitempty"` + SearchTerms string `json:"generic_search_term,omitempty"` + RoomTypes []string `json:"room_types,omitempty"` } // GetPostPublicRooms implements GET and POST /publicRooms @@ -57,8 +62,14 @@ func publicRooms( return nil, err } + if request.IncludeAllNetworks && request.NetworkID != "" { + return nil, fmt.Errorf("include_all_networks and third_party_instance_id can not be used together") + } + var queryRes roomserverAPI.QueryPublishedRoomsResponse - err = rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{}, &queryRes) + err = rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{ + NetworkdID: request.NetworkID, + }, &queryRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed") return nil, err diff --git a/roomserver/api/query.go b/roomserver/api/query.go index d63c24785..92636198e 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -21,8 +21,9 @@ import ( "fmt" "strings" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) // QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState @@ -257,7 +258,8 @@ type QueryRoomVersionForRoomResponse struct { type QueryPublishedRoomsRequest struct { // Optional. If specified, returns whether this room is published or not. - RoomID string + RoomID string + NetworkdID string } type QueryPublishedRoomsResponse struct { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index f41132403..35a25a379 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -702,7 +702,7 @@ func (r *Queryer) QueryPublishedRooms( } return err } - rooms, err := r.DB.GetPublishedRooms(ctx) + rooms, err := r.DB.GetPublishedRooms(ctx, req.NetworkdID) if err != nil { return err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 6b5062531..4b9c990ae 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -141,7 +141,7 @@ type Database interface { // Publish or unpublish a room from the room directory. PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error // Returns a list of room IDs for rooms which are published. - GetPublishedRooms(ctx context.Context) ([]string, error) + GetPublishedRooms(ctx context.Context, networkID string) ([]string, error) // Returns whether a given room is published or not. GetPublishedRoom(ctx context.Context, roomID string) (bool, error) diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index c0302c85c..d9e02b611 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -46,13 +46,17 @@ const upsertPublishedSQL = "" + const selectAllPublishedSQL = "" + "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" +const selectNetworkPublishedSQL = "" + + "SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC" + const selectPublishedSQL = "" + "SELECT published FROM roomserver_published WHERE room_id = $1" type publishedStatements struct { - upsertPublishedStmt *sql.Stmt - selectAllPublishedStmt *sql.Stmt - selectPublishedStmt *sql.Stmt + upsertPublishedStmt *sql.Stmt + selectAllPublishedStmt *sql.Stmt + selectPublishedStmt *sql.Stmt + selectNetworkPublishedStmt *sql.Stmt } func CreatePublishedTable(db *sql.DB) error { @@ -75,6 +79,7 @@ func PreparePublishedTable(db *sql.DB) (tables.Published, error) { {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL}, + {&s.selectNetworkPublishedStmt, selectNetworkPublishedSQL}, }.Prepare(db) } @@ -98,10 +103,18 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, txn *sql.Tx, published bool, + ctx context.Context, txn *sql.Tx, networkID string, published bool, ) ([]string, error) { - stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) - rows, err := stmt.QueryContext(ctx, published) + var rows *sql.Rows + var err error + if networkID != "" { + stmt := sqlutil.TxStmt(txn, s.selectNetworkPublishedStmt) + rows, err = stmt.QueryContext(ctx, published, networkID) + } else { + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err = stmt.QueryContext(ctx, published) + + } if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 07413a316..b9047204b 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -732,8 +732,8 @@ func (d *Database) GetPublishedRoom(ctx context.Context, roomID string) (bool, e return d.PublishedTable.SelectPublishedFromRoomID(ctx, nil, roomID) } -func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { - return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) +func (d *Database) GetPublishedRooms(ctx context.Context, networkID string) ([]string, error) { + return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, networkID, true) } func (d *Database) MissingAuthPrevEvents( diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index c82266ff2..bbeb96bb9 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -46,14 +46,18 @@ const upsertPublishedSQL = "" + const selectAllPublishedSQL = "" + "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" +const selectNetworkPublishedSQL = "" + + "SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC" + const selectPublishedSQL = "" + "SELECT published FROM roomserver_published WHERE room_id = $1" type publishedStatements struct { - db *sql.DB - upsertPublishedStmt *sql.Stmt - selectAllPublishedStmt *sql.Stmt - selectPublishedStmt *sql.Stmt + db *sql.DB + upsertPublishedStmt *sql.Stmt + selectAllPublishedStmt *sql.Stmt + selectPublishedStmt *sql.Stmt + selectNetworkPublishedStmt *sql.Stmt } func CreatePublishedTable(db *sql.DB) error { @@ -78,6 +82,7 @@ func PreparePublishedTable(db *sql.DB) (tables.Published, error) { {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL}, + {&s.selectNetworkPublishedStmt, selectNetworkPublishedSQL}, }.Prepare(db) } @@ -101,10 +106,17 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, txn *sql.Tx, published bool, + ctx context.Context, txn *sql.Tx, networkID string, published bool, ) ([]string, error) { - stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) - rows, err := stmt.QueryContext(ctx, published) + var rows *sql.Rows + var err error + if networkID != "" { + stmt := sqlutil.TxStmt(txn, s.selectNetworkPublishedStmt) + rows, err = stmt.QueryContext(ctx, published, networkID) + } else { + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err = stmt.QueryContext(ctx, published) + } if err != nil { return nil, err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 7fc68974c..203a4dcf9 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -148,7 +148,7 @@ type Membership interface { type Published interface { UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool) (err error) SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error) - SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error) + SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, networkdID string, published bool) ([]string, error) } type RedactionInfo struct { diff --git a/roomserver/storage/tables/published_table_test.go b/roomserver/storage/tables/published_table_test.go index 5e6437004..4211249dc 100644 --- a/roomserver/storage/tables/published_table_test.go +++ b/roomserver/storage/tables/published_table_test.go @@ -65,7 +65,7 @@ func TestPublishedTable(t *testing.T) { sort.Strings(publishedRooms) // check that we get the expected published rooms - roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, true) + roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, "", true) assert.NoError(t, err) assert.Equal(t, publishedRooms, roomIDs) @@ -79,5 +79,22 @@ func TestPublishedTable(t *testing.T) { publishedRes, err := tab.SelectPublishedFromRoomID(ctx, nil, room.ID) assert.NoError(t, err) assert.False(t, publishedRes, fmt.Sprintf("expected room %s to be unpublished", room.ID)) + + // network specific test + nwID = "irc" + room = test.NewRoom(t, alice) + err = tab.UpsertRoomPublished(ctx, nil, room.ID, asID, nwID, true) + assert.NoError(t, err) + publishedRooms = append(publishedRooms, room.ID) + sort.Strings(publishedRooms) + // should only return the room for network "irc" + allNWPublished, err := tab.SelectAllPublishedRooms(ctx, nil, nwID, true) + assert.NoError(t, err) + assert.Equal(t, []string{room.ID}, allNWPublished) + + // check that we still get all published rooms regardless networkID + roomIDs, err = tab.SelectAllPublishedRooms(ctx, nil, "", true) + assert.NoError(t, err) + assert.Equal(t, publishedRooms, roomIDs) }) }