Allow querying all networks

This commit is contained in:
Till Faelligen 2022-10-27 12:00:18 +02:00
parent fea2fa49d1
commit e7808392cd
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
8 changed files with 17 additions and 16 deletions

View file

@ -259,7 +259,8 @@ type QueryRoomVersionForRoomResponse struct {
type QueryPublishedRoomsRequest struct { type QueryPublishedRoomsRequest struct {
// Optional. If specified, returns whether this room is published or not. // Optional. If specified, returns whether this room is published or not.
RoomID string RoomID string
NetworkdID string NetworkID string
IncludeAllNetworks bool
} }
type QueryPublishedRoomsResponse struct { type QueryPublishedRoomsResponse struct {

View file

@ -702,7 +702,7 @@ func (r *Queryer) QueryPublishedRooms(
} }
return err return err
} }
rooms, err := r.DB.GetPublishedRooms(ctx, req.NetworkdID) rooms, err := r.DB.GetPublishedRooms(ctx, req.NetworkID, req.IncludeAllNetworks)
if err != nil { if err != nil {
return err return err
} }

View file

@ -141,7 +141,7 @@ type Database interface {
// Publish or unpublish a room from the room directory. // Publish or unpublish a room from the room directory.
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
// Returns a list of room IDs for rooms which are published. // Returns a list of room IDs for rooms which are published.
GetPublishedRooms(ctx context.Context, networkID string) ([]string, error) GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error)
// Returns whether a given room is published or not. // Returns whether a given room is published or not.
GetPublishedRoom(ctx context.Context, roomID string) (bool, error) GetPublishedRoom(ctx context.Context, roomID string) (bool, error)

View file

@ -44,7 +44,7 @@ const upsertPublishedSQL = "" +
"ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published=$4" "ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published=$4"
const selectAllPublishedSQL = "" + const selectAllPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" "SELECT room_id FROM roomserver_published WHERE published = $1 AND CASE WHEN $2 THEN 1=1 ELSE network_id = '' END ORDER BY room_id ASC"
const selectNetworkPublishedSQL = "" + const selectNetworkPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC" "SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC"
@ -103,7 +103,7 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, txn *sql.Tx, networkID string, published bool, ctx context.Context, txn *sql.Tx, networkID string, published, includeAllNetworks bool,
) ([]string, error) { ) ([]string, error) {
var rows *sql.Rows var rows *sql.Rows
var err error var err error
@ -112,7 +112,7 @@ func (s *publishedStatements) SelectAllPublishedRooms(
rows, err = stmt.QueryContext(ctx, published, networkID) rows, err = stmt.QueryContext(ctx, published, networkID)
} else { } else {
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err = stmt.QueryContext(ctx, published) rows, err = stmt.QueryContext(ctx, published, includeAllNetworks)
} }
if err != nil { if err != nil {

View file

@ -732,8 +732,8 @@ func (d *Database) GetPublishedRoom(ctx context.Context, roomID string) (bool, e
return d.PublishedTable.SelectPublishedFromRoomID(ctx, nil, roomID) return d.PublishedTable.SelectPublishedFromRoomID(ctx, nil, roomID)
} }
func (d *Database) GetPublishedRooms(ctx context.Context, networkID string) ([]string, error) { func (d *Database) GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, networkID, true) return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, networkID, true, includeAllNetworks)
} }
func (d *Database) MissingAuthPrevEvents( func (d *Database) MissingAuthPrevEvents(

View file

@ -44,7 +44,7 @@ const upsertPublishedSQL = "" +
" ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published = $4" " ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published = $4"
const selectAllPublishedSQL = "" + const selectAllPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" "SELECT room_id FROM roomserver_published WHERE published = $1 AND CASE WHEN $2 THEN 1=1 ELSE network_id = '' END ORDER BY room_id ASC"
const selectNetworkPublishedSQL = "" + const selectNetworkPublishedSQL = "" +
"SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC" "SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC"
@ -106,7 +106,7 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
} }
func (s *publishedStatements) SelectAllPublishedRooms( func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, txn *sql.Tx, networkID string, published bool, ctx context.Context, txn *sql.Tx, networkID string, published, includeAllNetworks bool,
) ([]string, error) { ) ([]string, error) {
var rows *sql.Rows var rows *sql.Rows
var err error var err error
@ -115,7 +115,7 @@ func (s *publishedStatements) SelectAllPublishedRooms(
rows, err = stmt.QueryContext(ctx, published, networkID) rows, err = stmt.QueryContext(ctx, published, networkID)
} else { } else {
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err = stmt.QueryContext(ctx, published) rows, err = stmt.QueryContext(ctx, published, includeAllNetworks)
} }
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -148,7 +148,7 @@ type Membership interface {
type Published interface { type Published interface {
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool) (err error) UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool) (err error)
SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error) SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error)
SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, networkdID string, published bool) ([]string, error) SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, networkdID string, published, includeAllNetworks bool) ([]string, error)
} }
type RedactionInfo struct { type RedactionInfo struct {

View file

@ -65,7 +65,7 @@ func TestPublishedTable(t *testing.T) {
sort.Strings(publishedRooms) sort.Strings(publishedRooms)
// check that we get the expected published rooms // check that we get the expected published rooms
roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, "", true) roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, "", true, true)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, publishedRooms, roomIDs) assert.Equal(t, publishedRooms, roomIDs)
@ -88,12 +88,12 @@ func TestPublishedTable(t *testing.T) {
publishedRooms = append(publishedRooms, room.ID) publishedRooms = append(publishedRooms, room.ID)
sort.Strings(publishedRooms) sort.Strings(publishedRooms)
// should only return the room for network "irc" // should only return the room for network "irc"
allNWPublished, err := tab.SelectAllPublishedRooms(ctx, nil, nwID, true) allNWPublished, err := tab.SelectAllPublishedRooms(ctx, nil, nwID, true, true)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []string{room.ID}, allNWPublished) assert.Equal(t, []string{room.ID}, allNWPublished)
// check that we still get all published rooms regardless networkID // check that we still get all published rooms regardless networkID
roomIDs, err = tab.SelectAllPublishedRooms(ctx, nil, "", true) roomIDs, err = tab.SelectAllPublishedRooms(ctx, nil, "", true, true)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, publishedRooms, roomIDs) assert.Equal(t, publishedRooms, roomIDs)
}) })