diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 33bc63d18..ce14745aa 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -18,14 +18,15 @@ import ( "fmt" "net/http" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) type roomDirectoryResponse struct { @@ -318,3 +319,43 @@ func SetVisibility( JSON: struct{}{}, } } + +func SetVisibilityAS( + req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI, dev *userapi.Device, + networkID, roomID string, +) util.JSONResponse { + if dev.AccountType != userapi.AccountTypeAppService { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Only appservice may use this endpoint"), + } + } + var v roomVisibility + + // If the method is delete, we simply mark the visibility as private + if req.Method == http.MethodDelete { + v.Visibility = "private" + } else { + if reqErr := httputil.UnmarshalJSONRequest(req, &v); reqErr != nil { + return *reqErr + } + } + var publishRes roomserverAPI.PerformPublishResponse + if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ + RoomID: roomID, + Visibility: v.Visibility, + NetworkID: networkID, + AppserviceID: dev.AppserviceID, + }, &publishRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } + if publishRes.Error != nil { + util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed") + return publishRes.Error.JSONResponse() + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 4ebf2295a..b1043e994 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -39,14 +39,17 @@ var ( ) type PublicRoomReq struct { - Since string `json:"since,omitempty"` - Limit int16 `json:"limit,omitempty"` - Filter filter `json:"filter,omitempty"` - Server string `json:"server,omitempty"` + Since string `json:"since,omitempty"` + Limit int64 `json:"limit,omitempty"` + Filter filter `json:"filter,omitempty"` + Server string `json:"server,omitempty"` + IncludeAllNetworks bool `json:"include_all_networks,omitempty"` + NetworkID string `json:"third_party_instance_id,omitempty"` } type filter struct { - SearchTerms string `json:"generic_search_term,omitempty"` + SearchTerms string `json:"generic_search_term,omitempty"` + RoomTypes []string `json:"room_types,omitempty"` // TODO: Implement filter on this } // GetPostPublicRooms implements GET and POST /publicRooms @@ -61,6 +64,13 @@ func GetPostPublicRooms( return *fillErr } + if request.IncludeAllNetworks && request.NetworkID != "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidParam("include_all_networks and third_party_instance_id can not be used together"), + } + } + serverName := gomatrixserverlib.ServerName(request.Server) if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( @@ -97,7 +107,7 @@ func publicRooms( response := gomatrixserverlib.RespPublicRooms{ Chunk: []gomatrixserverlib.PublicRoom{}, } - var limit int16 + var limit int64 var offset int64 limit = request.Limit if limit == 0 { @@ -114,7 +124,7 @@ func publicRooms( var rooms []gomatrixserverlib.PublicRoom if request.Since == "" { - rooms = refreshPublicRoomCache(ctx, rsAPI, extRoomsProvider) + rooms = refreshPublicRoomCache(ctx, rsAPI, extRoomsProvider, request) } else { rooms = getPublicRoomsFromCache() } @@ -176,7 +186,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO JSON: jsonerror.BadJSON("limit param is not a number"), } } - request.Limit = int16(limit) + request.Limit = int64(limit) request.Since = httpReq.FormValue("since") request.Server = httpReq.FormValue("server") } else { @@ -204,7 +214,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO // limit=3&since=6 => G (prev='3', next='') // // A value of '-1' for prev/next indicates no position. -func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int16) (subset []gomatrixserverlib.PublicRoom, prev, next int) { +func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int64) (subset []gomatrixserverlib.PublicRoom, prev, next int) { prev = -1 next = -1 @@ -230,6 +240,7 @@ func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int16) ( func refreshPublicRoomCache( ctx context.Context, rsAPI roomserverAPI.ClientRoomserverAPI, extRoomsProvider api.ExtraPublicRoomsProvider, + request PublicRoomReq, ) []gomatrixserverlib.PublicRoom { cacheMu.Lock() defer cacheMu.Unlock() @@ -238,8 +249,17 @@ func refreshPublicRoomCache( extraRooms = extRoomsProvider.Rooms() } + // TODO: this is only here to make Sytest happy, for now. + ns := strings.Split(request.NetworkID, "|") + if len(ns) == 2 { + request.NetworkID = ns[1] + } + var queryRes roomserverAPI.QueryPublishedRoomsResponse - err := rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{}, &queryRes) + err := rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{ + NetworkID: request.NetworkID, + IncludeAllNetworks: request.IncludeAllNetworks, + }, &queryRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed") return publicRoomsCache diff --git a/clientapi/routing/directory_public_test.go b/clientapi/routing/directory_public_test.go index bb3912b8c..65ad392c2 100644 --- a/clientapi/routing/directory_public_test.go +++ b/clientapi/routing/directory_public_test.go @@ -17,7 +17,7 @@ func TestSliceInto(t *testing.T) { slice := []gomatrixserverlib.PublicRoom{ pubRoom("a"), pubRoom("b"), pubRoom("c"), pubRoom("d"), pubRoom("e"), pubRoom("f"), pubRoom("g"), } - limit := int16(3) + limit := int64(3) testCases := []struct { since int64 wantPrev int diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index e0e3e33d4..22bc77a0b 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -480,7 +480,7 @@ func Setup( return GetVisibility(req, rsAPI, vars["roomID"]) }), ).Methods(http.MethodGet, http.MethodOptions) - // TODO: Add AS support + v3mux.Handle("/directory/list/room/{roomID}", httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -490,6 +490,27 @@ func Setup( return SetVisibility(req, rsAPI, device, vars["roomID"]) }), ).Methods(http.MethodPut, http.MethodOptions) + v3mux.Handle("/directory/list/appservice/{networkID}/{roomID}", + httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return SetVisibilityAS(req, rsAPI, device, vars["networkID"], vars["roomID"]) + }), + ).Methods(http.MethodPut, http.MethodOptions) + + // Undocumented endpoint + v3mux.Handle("/directory/list/appservice/{networkID}/{roomID}", + httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return SetVisibilityAS(req, rsAPI, device, vars["networkID"], vars["roomID"]) + }), + ).Methods(http.MethodDelete, http.MethodOptions) + v3mux.Handle("/publicRooms", httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg) diff --git a/federationapi/routing/publicrooms.go b/federationapi/routing/publicrooms.go index 1a54f5a7d..34025932a 100644 --- a/federationapi/routing/publicrooms.go +++ b/federationapi/routing/publicrooms.go @@ -2,24 +2,29 @@ package routing import ( "context" + "fmt" "net/http" "strconv" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) type PublicRoomReq struct { - Since string `json:"since,omitempty"` - Limit int16 `json:"limit,omitempty"` - Filter filter `json:"filter,omitempty"` + Since string `json:"since,omitempty"` + Limit int16 `json:"limit,omitempty"` + Filter filter `json:"filter,omitempty"` + IncludeAllNetworks bool `json:"include_all_networks,omitempty"` + NetworkID string `json:"third_party_instance_id,omitempty"` } type filter struct { - SearchTerms string `json:"generic_search_term,omitempty"` + SearchTerms string `json:"generic_search_term,omitempty"` + RoomTypes []string `json:"room_types,omitempty"` } // GetPostPublicRooms implements GET and POST /publicRooms @@ -57,8 +62,14 @@ func publicRooms( return nil, err } + if request.IncludeAllNetworks && request.NetworkID != "" { + return nil, fmt.Errorf("include_all_networks and third_party_instance_id can not be used together") + } + var queryRes roomserverAPI.QueryPublishedRoomsResponse - err = rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{}, &queryRes) + err = rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{ + NetworkID: request.NetworkID, + }, &queryRes) if err != nil { util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed") return nil, err diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 7a362f969..1442a4b09 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -168,8 +168,10 @@ type PerformBackfillResponse struct { } type PerformPublishRequest struct { - RoomID string - Visibility string + RoomID string + Visibility string + AppserviceID string + NetworkID string } type PerformPublishResponse struct { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index d63c24785..b62907f3c 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -21,8 +21,9 @@ import ( "fmt" "strings" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) // QueryLatestEventsAndStateRequest is a request to QueryLatestEventsAndState @@ -257,7 +258,9 @@ type QueryRoomVersionForRoomResponse struct { type QueryPublishedRoomsRequest struct { // Optional. If specified, returns whether this room is published or not. - RoomID string + RoomID string + NetworkID string + IncludeAllNetworks bool } type QueryPublishedRoomsResponse struct { diff --git a/roomserver/internal/perform/perform_publish.go b/roomserver/internal/perform/perform_publish.go index 1631fc657..fbbfc3219 100644 --- a/roomserver/internal/perform/perform_publish.go +++ b/roomserver/internal/perform/perform_publish.go @@ -30,7 +30,7 @@ func (r *Publisher) PerformPublish( req *api.PerformPublishRequest, res *api.PerformPublishResponse, ) error { - err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public") + err := r.DB.PublishRoom(ctx, req.RoomID, req.AppserviceID, req.NetworkID, req.Visibility == "public") if err != nil { res.Error = &api.PerformError{ Msg: err.Error(), diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index f41132403..0db046a86 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -702,7 +702,7 @@ func (r *Queryer) QueryPublishedRooms( } return err } - rooms, err := r.DB.GetPublishedRooms(ctx) + rooms, err := r.DB.GetPublishedRooms(ctx, req.NetworkID, req.IncludeAllNetworks) if err != nil { return err } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index ee0624b21..094537948 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -139,9 +139,9 @@ type Database interface { // Returns an error if the retrieval went wrong. EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) // Publish or unpublish a room from the room directory. - PublishRoom(ctx context.Context, roomID string, publish bool) error + PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error // Returns a list of room IDs for rooms which are published. - GetPublishedRooms(ctx context.Context) ([]string, error) + GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error) // Returns whether a given room is published or not. GetPublishedRoom(ctx context.Context, roomID string) (bool, error) diff --git a/roomserver/storage/postgres/deltas/20221027084407_published_appservice.go b/roomserver/storage/postgres/deltas/20221027084407_published_appservice.go new file mode 100644 index 000000000..be046545a --- /dev/null +++ b/roomserver/storage/postgres/deltas/20221027084407_published_appservice.go @@ -0,0 +1,45 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpPulishedAppservice(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_published ADD COLUMN IF NOT EXISTS appservice_id TEXT NOT NULL;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_published ADD COLUMN IF NOT EXISTS network_id TEXT NOT NULL;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownPublishedAppservice(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_published DROP COLUMN IF EXISTS appservice_id;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + _, err = tx.ExecContext(ctx, `ALTER TABLE roomserver_published DROP COLUMN IF EXISTS network_id;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 56fa02f7b..61caccb0e 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -27,31 +28,48 @@ const publishedSchema = ` -- Stores which rooms are published in the room directory CREATE TABLE IF NOT EXISTS roomserver_published ( -- The room ID of the room - room_id TEXT NOT NULL PRIMARY KEY, + room_id TEXT NOT NULL, + -- The appservice ID of the room + appservice_id TEXT NOT NULL, + -- The network_id of the room + network_id TEXT NOT NULL, -- Whether it is published or not - published BOOLEAN NOT NULL DEFAULT false + published BOOLEAN NOT NULL DEFAULT false, + PRIMARY KEY (room_id, appservice_id, network_id) ); ` const upsertPublishedSQL = "" + - "INSERT INTO roomserver_published (room_id, published) VALUES ($1, $2) " + - "ON CONFLICT (room_id) DO UPDATE SET published=$2" + "INSERT INTO roomserver_published (room_id, appservice_id, network_id, published) VALUES ($1, $2, $3, $4) " + + "ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published=$4" const selectAllPublishedSQL = "" + - "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" + "SELECT room_id FROM roomserver_published WHERE published = $1 AND CASE WHEN $2 THEN 1=1 ELSE network_id = '' END ORDER BY room_id ASC" + +const selectNetworkPublishedSQL = "" + + "SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC" const selectPublishedSQL = "" + "SELECT published FROM roomserver_published WHERE room_id = $1" type publishedStatements struct { - upsertPublishedStmt *sql.Stmt - selectAllPublishedStmt *sql.Stmt - selectPublishedStmt *sql.Stmt + upsertPublishedStmt *sql.Stmt + selectAllPublishedStmt *sql.Stmt + selectPublishedStmt *sql.Stmt + selectNetworkPublishedStmt *sql.Stmt } func CreatePublishedTable(db *sql.DB) error { _, err := db.Exec(publishedSchema) - return err + if err != nil { + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "roomserver: published appservice", + Up: deltas.UpPulishedAppservice, + }) + return m.Up(context.Background()) } func PreparePublishedTable(db *sql.DB) (tables.Published, error) { @@ -61,14 +79,15 @@ func PreparePublishedTable(db *sql.DB) (tables.Published, error) { {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL}, + {&s.selectNetworkPublishedStmt, selectNetworkPublishedSQL}, }.Prepare(db) } func (s *publishedStatements) UpsertRoomPublished( - ctx context.Context, txn *sql.Tx, roomID string, published bool, + ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool, ) (err error) { stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) - _, err = stmt.ExecContext(ctx, roomID, published) + _, err = stmt.ExecContext(ctx, roomID, appserviceID, networkID, published) return } @@ -84,10 +103,18 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, txn *sql.Tx, published bool, + ctx context.Context, txn *sql.Tx, networkID string, published, includeAllNetworks bool, ) ([]string, error) { - stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) - rows, err := stmt.QueryContext(ctx, published) + var rows *sql.Rows + var err error + if networkID != "" { + stmt := sqlutil.TxStmt(txn, s.selectNetworkPublishedStmt) + rows, err = stmt.QueryContext(ctx, published, networkID) + } else { + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err = stmt.QueryContext(ctx, published, includeAllNetworks) + + } if err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e401f17dc..ed86280bf 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -722,9 +722,9 @@ func (d *Database) storeEvent( }, redactionEvent, redactedEventID, err } -func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error { +func (d *Database) PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.PublishedTable.UpsertRoomPublished(ctx, txn, roomID, publish) + return d.PublishedTable.UpsertRoomPublished(ctx, txn, roomID, appserviceID, networkID, publish) }) } @@ -732,8 +732,8 @@ func (d *Database) GetPublishedRoom(ctx context.Context, roomID string) (bool, e return d.PublishedTable.SelectPublishedFromRoomID(ctx, nil, roomID) } -func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { - return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) +func (d *Database) GetPublishedRooms(ctx context.Context, networkID string, includeAllNetworks bool) ([]string, error) { + return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, networkID, true, includeAllNetworks) } func (d *Database) MissingAuthPrevEvents( diff --git a/roomserver/storage/sqlite3/deltas/20221027084407_published_appservice.go b/roomserver/storage/sqlite3/deltas/20221027084407_published_appservice.go new file mode 100644 index 000000000..cd923b1c1 --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/20221027084407_published_appservice.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpPulishedAppservice(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_published RENAME TO roomserver_published_tmp; +CREATE TABLE IF NOT EXISTS roomserver_published ( + room_id TEXT NOT NULL, + appservice_id TEXT NOT NULL, + network_id TEXT NOT NULL, + published BOOLEAN NOT NULL DEFAULT false, + CONSTRAINT unique_published_idx PRIMARY KEY (room_id, appservice_id, network_id) +); +INSERT + INTO roomserver_published ( + room_id, published + ) SELECT + room_id, published + FROM roomserver_published_tmp +; +DROP TABLE roomserver_published_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownPublishedAppservice(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_published RENAME TO roomserver_published_tmp; +CREATE TABLE IF NOT EXISTS roomserver_published ( + room_id TEXT NOT NULL PRIMARY KEY, + published BOOLEAN NOT NULL DEFAULT false +); +INSERT + INTO roomserver_published ( + room_id, published + ) SELECT + room_id, published + FROM roomserver_published_tmp +; +DROP TABLE roomserver_published_tmp;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index 50dfa5492..34666552e 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -27,31 +28,49 @@ const publishedSchema = ` -- Stores which rooms are published in the room directory CREATE TABLE IF NOT EXISTS roomserver_published ( -- The room ID of the room - room_id TEXT NOT NULL PRIMARY KEY, + room_id TEXT NOT NULL, + -- The appservice ID of the room + appservice_id TEXT NOT NULL, + -- The network_id of the room + network_id TEXT NOT NULL, -- Whether it is published or not - published BOOLEAN NOT NULL DEFAULT false + published BOOLEAN NOT NULL DEFAULT false, + PRIMARY KEY (room_id, appservice_id, network_id) ); ` const upsertPublishedSQL = "" + - "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" + "INSERT INTO roomserver_published (room_id, appservice_id, network_id, published) VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (room_id, appservice_id, network_id) DO UPDATE SET published = $4" const selectAllPublishedSQL = "" + - "SELECT room_id FROM roomserver_published WHERE published = $1 ORDER BY room_id ASC" + "SELECT room_id FROM roomserver_published WHERE published = $1 AND CASE WHEN $2 THEN 1=1 ELSE network_id = '' END ORDER BY room_id ASC" + +const selectNetworkPublishedSQL = "" + + "SELECT room_id FROM roomserver_published WHERE published = $1 AND network_id = $2 ORDER BY room_id ASC" const selectPublishedSQL = "" + "SELECT published FROM roomserver_published WHERE room_id = $1" type publishedStatements struct { - db *sql.DB - upsertPublishedStmt *sql.Stmt - selectAllPublishedStmt *sql.Stmt - selectPublishedStmt *sql.Stmt + db *sql.DB + upsertPublishedStmt *sql.Stmt + selectAllPublishedStmt *sql.Stmt + selectPublishedStmt *sql.Stmt + selectNetworkPublishedStmt *sql.Stmt } func CreatePublishedTable(db *sql.DB) error { _, err := db.Exec(publishedSchema) - return err + if err != nil { + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "roomserver: published appservice", + Up: deltas.UpPulishedAppservice, + }) + return m.Up(context.Background()) } func PreparePublishedTable(db *sql.DB) (tables.Published, error) { @@ -63,14 +82,15 @@ func PreparePublishedTable(db *sql.DB) (tables.Published, error) { {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, {&s.selectPublishedStmt, selectPublishedSQL}, + {&s.selectNetworkPublishedStmt, selectNetworkPublishedSQL}, }.Prepare(db) } func (s *publishedStatements) UpsertRoomPublished( - ctx context.Context, txn *sql.Tx, roomID string, published bool, + ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool, ) error { stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) - _, err := stmt.ExecContext(ctx, roomID, published) + _, err := stmt.ExecContext(ctx, roomID, appserviceID, networkID, published) return err } @@ -86,10 +106,17 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, txn *sql.Tx, published bool, + ctx context.Context, txn *sql.Tx, networkID string, published, includeAllNetworks bool, ) ([]string, error) { - stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) - rows, err := stmt.QueryContext(ctx, published) + var rows *sql.Rows + var err error + if networkID != "" { + stmt := sqlutil.TxStmt(txn, s.selectNetworkPublishedStmt) + rows, err = stmt.QueryContext(ctx, published, networkID) + } else { + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err = stmt.QueryContext(ctx, published, includeAllNetworks) + } if err != nil { return nil, err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 8be47855f..8d6ca324c 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -146,9 +146,9 @@ type Membership interface { } type Published interface { - UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) + UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID, appserviceID, networkID string, published bool) (err error) SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error) - SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error) + SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, networkdID string, published, includeAllNetworks bool) ([]string, error) } type RedactionInfo struct { diff --git a/roomserver/storage/tables/published_table_test.go b/roomserver/storage/tables/published_table_test.go index fff6dc186..e6289e9b1 100644 --- a/roomserver/storage/tables/published_table_test.go +++ b/roomserver/storage/tables/published_table_test.go @@ -2,16 +2,18 @@ package tables_test import ( "context" + "fmt" "sort" "testing" + "github.com/stretchr/testify/assert" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/stretchr/testify/assert" ) func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Published, close func()) { @@ -46,10 +48,12 @@ func TestPublishedTable(t *testing.T) { // Publish some rooms publishedRooms := []string{} + asID := "" + nwID := "" for i := 0; i < 10; i++ { room := test.NewRoom(t, alice) published := i%2 == 0 - err := tab.UpsertRoomPublished(ctx, nil, room.ID, published) + err := tab.UpsertRoomPublished(ctx, nil, room.ID, asID, nwID, published) assert.NoError(t, err) if published { publishedRooms = append(publishedRooms, room.ID) @@ -61,19 +65,36 @@ func TestPublishedTable(t *testing.T) { sort.Strings(publishedRooms) // check that we get the expected published rooms - roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, true) + roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, "", true, true) assert.NoError(t, err) assert.Equal(t, publishedRooms, roomIDs) // test an actual upsert room := test.NewRoom(t, alice) - err = tab.UpsertRoomPublished(ctx, nil, room.ID, true) + err = tab.UpsertRoomPublished(ctx, nil, room.ID, asID, nwID, true) assert.NoError(t, err) - err = tab.UpsertRoomPublished(ctx, nil, room.ID, false) + err = tab.UpsertRoomPublished(ctx, nil, room.ID, asID, nwID, false) assert.NoError(t, err) // should now be false, due to the upsert publishedRes, err := tab.SelectPublishedFromRoomID(ctx, nil, room.ID) assert.NoError(t, err) - assert.False(t, publishedRes) + assert.False(t, publishedRes, fmt.Sprintf("expected room %s to be unpublished", room.ID)) + + // network specific test + nwID = "irc" + room = test.NewRoom(t, alice) + err = tab.UpsertRoomPublished(ctx, nil, room.ID, asID, nwID, true) + assert.NoError(t, err) + publishedRooms = append(publishedRooms, room.ID) + sort.Strings(publishedRooms) + // should only return the room for network "irc" + allNWPublished, err := tab.SelectAllPublishedRooms(ctx, nil, nwID, true, true) + assert.NoError(t, err) + assert.Equal(t, []string{room.ID}, allNWPublished) + + // check that we still get all published rooms regardless networkID + roomIDs, err = tab.SelectAllPublishedRooms(ctx, nil, "", true, true) + assert.NoError(t, err) + assert.Equal(t, publishedRooms, roomIDs) }) } diff --git a/sytest-whitelist b/sytest-whitelist index 60610929a..6e4500d06 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -757,4 +757,6 @@ Can get rooms/{roomId}/messages for a departed room (SPEC-216) Local device key changes appear in /keys/changes Can get rooms/{roomId}/members at a given point Can filter rooms/{roomId}/members -Current state appears in timeline in private history with many messages after \ No newline at end of file +Current state appears in timeline in private history with many messages after +AS can publish rooms in their own list +AS and main public room lists are separate \ No newline at end of file