diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 0dc4d5605..0f78f4a24 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,10 +20,12 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -232,3 +234,89 @@ func RemoveLocalAlias( JSON: struct{}{}, } } + +type roomVisibility struct { + Visibility string `json:"visibility"` +} + +// GetVisibility implements GET /directory/list/room/{roomID} +func GetVisibility( + req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, + roomID string, +) util.JSONResponse { + var res roomserverAPI.QueryPublishedRoomsResponse + err := rsAPI.QueryPublishedRooms(req.Context(), &roomserverAPI.QueryPublishedRoomsRequest{ + RoomID: roomID, + }, &res) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("QueryPublishedRooms failed") + return jsonerror.InternalServerError() + } + + var v roomVisibility + if len(res.RoomIDs) == 1 { + v.Visibility = gomatrixserverlib.Public + } else { + v.Visibility = "private" + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: v, + } +} + +// SetVisibility implements PUT /directory/list/room/{roomID} +// TODO: Allow admin users to edit the room visibility +func SetVisibility( + req *http.Request, stateAPI currentstateAPI.CurrentStateInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, dev *userapi.Device, + roomID string, +) util.JSONResponse { + resErr := checkMemberInRoom(req.Context(), stateAPI, dev.UserID, roomID) + if resErr != nil { + return *resErr + } + + queryEventsReq := roomserverAPI.QueryLatestEventsAndStateRequest{ + RoomID: roomID, + StateToFetch: []gomatrixserverlib.StateKeyTuple{{ + EventType: gomatrixserverlib.MRoomPowerLevels, + StateKey: "", + }}, + } + var queryEventsRes roomserverAPI.QueryLatestEventsAndStateResponse + err := rsAPI.QueryLatestEventsAndState(req.Context(), &queryEventsReq, &queryEventsRes) + if err != nil || len(queryEventsRes.StateEvents) == 0 { + util.GetLogger(req.Context()).WithError(err).Error("could not query events from room") + return jsonerror.InternalServerError() + } + + // NOTSPEC: Check if the user's power is greater than power required to change m.room.aliases event + power, _ := gomatrixserverlib.NewPowerLevelContentFromEvent(queryEventsRes.StateEvents[0].Event) + if power.UserLevel(dev.UserID) < power.EventLevel(gomatrixserverlib.MRoomAliases, true) { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("userID doesn't have power level to change visibility"), + } + } + + var v roomVisibility + if reqErr := httputil.UnmarshalJSONRequest(req, &v); reqErr != nil { + return *reqErr + } + + var publishRes roomserverAPI.PerformPublishResponse + rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ + RoomID: roomID, + Visibility: v.Visibility, + }, &publishRes) + if publishRes.Error != nil { + util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed") + return publishRes.Error.JSONResponse() + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go new file mode 100644 index 000000000..c0c4cd930 --- /dev/null +++ b/clientapi/routing/directory_public.go @@ -0,0 +1,332 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "math/rand" + "net/http" + "sort" + "strconv" + "sync" + "time" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" + "github.com/matrix-org/dendrite/publicroomsapi/types" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +type PublicRoomReq struct { + Since string `json:"since,omitempty"` + Limit int16 `json:"limit,omitempty"` + Filter filter `json:"filter,omitempty"` +} + +type filter struct { + SearchTerms string `json:"generic_search_term,omitempty"` +} + +// GetPostPublicRooms implements GET and POST /publicRooms +func GetPostPublicRooms( + req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, +) util.JSONResponse { + var request PublicRoomReq + if fillErr := fillPublicRoomsReq(req, &request); fillErr != nil { + return *fillErr + } + response, err := publicRooms(req.Context(), request, rsAPI, stateAPI) + if err != nil { + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: response, + } +} + +// GetPostPublicRoomsWithExternal is the same as GetPostPublicRooms but also mixes in public rooms from the provider supplied. +func GetPostPublicRoomsWithExternal( + req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, + fedClient *gomatrixserverlib.FederationClient, extRoomsProvider types.ExternalPublicRoomsProvider, +) util.JSONResponse { + var request PublicRoomReq + if fillErr := fillPublicRoomsReq(req, &request); fillErr != nil { + return *fillErr + } + response, err := publicRooms(req.Context(), request, rsAPI, stateAPI) + if err != nil { + return jsonerror.InternalServerError() + } + + if request.Since != "" { + // TODO: handle pagination tokens sensibly rather than ignoring them. + // ignore paginated requests since we don't handle them yet over federation. + // Only the initial request will contain federated rooms. + return util.JSONResponse{ + Code: http.StatusOK, + JSON: response, + } + } + + // If we have already hit the limit on the number of rooms, bail. + var limit int + if request.Limit > 0 { + limit = int(request.Limit) - len(response.Chunk) + if limit <= 0 { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: response, + } + } + } + + // downcasting `limit` is safe as we know it isn't bigger than request.Limit which is int16 + fedRooms := bulkFetchPublicRoomsFromServers(req.Context(), fedClient, extRoomsProvider.Homeservers(), int16(limit)) + response.Chunk = append(response.Chunk, fedRooms...) + + // de-duplicate rooms with the same room ID. We can join the room via any of these aliases as we know these servers + // are alive and well, so we arbitrarily pick one (purposefully shuffling them to spread the load a bit) + var publicRooms []gomatrixserverlib.PublicRoom + haveRoomIDs := make(map[string]bool) + rand.Shuffle(len(response.Chunk), func(i, j int) { + response.Chunk[i], response.Chunk[j] = response.Chunk[j], response.Chunk[i] + }) + for _, r := range response.Chunk { + if haveRoomIDs[r.RoomID] { + continue + } + haveRoomIDs[r.RoomID] = true + publicRooms = append(publicRooms, r) + } + // sort by member count + sort.SliceStable(publicRooms, func(i, j int) bool { + return publicRooms[i].JoinedMembersCount > publicRooms[j].JoinedMembersCount + }) + + response.Chunk = publicRooms + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: response, + } +} + +// bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. +// Returns a list of public rooms up to the limit specified. +func bulkFetchPublicRoomsFromServers( + ctx context.Context, fedClient *gomatrixserverlib.FederationClient, homeservers []string, limit int16, +) (publicRooms []gomatrixserverlib.PublicRoom) { + // follow pipeline semantics, see https://blog.golang.org/pipelines for more info. + // goroutines send rooms to this channel + roomCh := make(chan gomatrixserverlib.PublicRoom, int(limit)) + // signalling channel to tell goroutines to stop sending rooms and quit + done := make(chan bool) + // signalling to say when we can close the room channel + var wg sync.WaitGroup + wg.Add(len(homeservers)) + // concurrently query for public rooms + for _, hs := range homeservers { + go func(homeserverDomain string) { + defer wg.Done() + util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") + fres, err := fedClient.GetPublicRooms(ctx, gomatrixserverlib.ServerName(homeserverDomain), int(limit), "", false, "") + if err != nil { + util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn( + "bulkFetchPublicRoomsFromServers: failed to query hs", + ) + return + } + for _, room := range fres.Chunk { + // atomically send a room or stop + select { + case roomCh <- room: + case <-done: + util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Info("Interrupted whilst sending rooms") + return + } + } + }(hs) + } + + // Close the room channel when the goroutines have quit so we don't leak, but don't let it stop the in-flight request. + // This also allows the request to fail fast if all HSes experience errors as it will cause the room channel to be + // closed. + go func() { + wg.Wait() + util.GetLogger(ctx).Info("Cleaning up resources") + close(roomCh) + }() + + // fan-in results with timeout. We stop when we reach the limit. +FanIn: + for len(publicRooms) < int(limit) || limit == 0 { + // add a room or timeout + select { + case room, ok := <-roomCh: + if !ok { + util.GetLogger(ctx).Info("All homeservers have been queried, returning results.") + break FanIn + } + publicRooms = append(publicRooms, room) + case <-time.After(15 * time.Second): // we've waited long enough, let's tell the client what we got. + util.GetLogger(ctx).Info("Waited 15s for federated public rooms, returning early") + break FanIn + case <-ctx.Done(): // the client hung up on us, let's stop. + util.GetLogger(ctx).Info("Client hung up, returning early") + break FanIn + } + } + // tell goroutines to stop + close(done) + + return publicRooms +} + +func publicRooms(ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.RoomserverInternalAPI, + stateAPI currentstateAPI.CurrentStateInternalAPI) (*gomatrixserverlib.RespPublicRooms, error) { + + var response gomatrixserverlib.RespPublicRooms + var limit int16 + var offset int64 + limit = request.Limit + offset, err := strconv.ParseInt(request.Since, 10, 64) + // ParseInt returns 0 and an error when trying to parse an empty string + // In that case, we want to assign 0 so we ignore the error + if err != nil && len(request.Since) > 0 { + util.GetLogger(ctx).WithError(err).Error("strconv.ParseInt failed") + return nil, err + } + + var queryRes roomserverAPI.QueryPublishedRoomsResponse + err = rsAPI.QueryPublishedRooms(ctx, &roomserverAPI.QueryPublishedRoomsRequest{}, &queryRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed") + return nil, err + } + response.TotalRoomCountEstimate = len(queryRes.RoomIDs) + + if offset > 0 { + response.PrevBatch = strconv.Itoa(int(offset) - 1) + } + nextIndex := int(offset) + int(limit) + if response.TotalRoomCountEstimate > nextIndex { + response.NextBatch = strconv.Itoa(nextIndex) + } + + if offset < 0 { + offset = 0 + } + if nextIndex > len(queryRes.RoomIDs) { + nextIndex = len(queryRes.RoomIDs) + } + roomIDs := queryRes.RoomIDs[offset:nextIndex] + response.Chunk, err = fillInRooms(ctx, roomIDs, stateAPI) + return &response, err +} + +// fillPublicRoomsReq fills the Limit, Since and Filter attributes of a GET or POST request +// on /publicRooms by parsing the incoming HTTP request +// Filter is only filled for POST requests +func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSONResponse { + if httpReq.Method == http.MethodGet { + limit, err := strconv.Atoi(httpReq.FormValue("limit")) + // Atoi returns 0 and an error when trying to parse an empty string + // In that case, we want to assign 0 so we ignore the error + if err != nil && len(httpReq.FormValue("limit")) > 0 { + util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") + reqErr := jsonerror.InternalServerError() + return &reqErr + } + request.Limit = int16(limit) + request.Since = httpReq.FormValue("since") + return nil + } else if httpReq.Method == http.MethodPost { + return httputil.UnmarshalJSONRequest(httpReq, request) + } + + return &util.JSONResponse{ + Code: http.StatusMethodNotAllowed, + JSON: jsonerror.NotFound("Bad method"), + } +} + +// due to lots of switches +// nolint:gocyclo +func fillInRooms(ctx context.Context, roomIDs []string, stateAPI currentstateAPI.CurrentStateInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { + avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} + nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} + canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} + topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} + guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} + visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} + joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} + + var stateRes currentstateAPI.QueryBulkStateContentResponse + err := stateAPI.QueryBulkStateContent(ctx, ¤tstateAPI.QueryBulkStateContentRequest{ + RoomIDs: roomIDs, + AllowWildcards: true, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple, + {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"}, + }, + }, &stateRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed") + return nil, err + } + chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs)) + i := 0 + for roomID, data := range stateRes.Rooms { + pub := gomatrixserverlib.PublicRoom{ + RoomID: roomID, + } + joinCount := 0 + var joinRule, guestAccess string + for tuple, contentVal := range data { + if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" { + joinCount++ + continue + } + switch tuple { + case avatarTuple: + pub.AvatarURL = contentVal + case nameTuple: + pub.Name = contentVal + case topicTuple: + pub.Topic = contentVal + case canonicalTuple: + pub.CanonicalAlias = contentVal + case visibilityTuple: + pub.WorldReadable = contentVal == "world_readable" + // need both of these to determine whether guests can join + case joinRuleTuple: + joinRule = contentVal + case guestTuple: + guestAccess = contentVal + } + } + if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" { + pub.GuestCanJoin = true + } + pub.JoinedMembersCount = joinCount + chunk[i] = pub + i++ + } + return chunk, nil +} diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 1f316384b..c2145159a 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" + currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -358,3 +359,35 @@ func checkAndProcessThreepid( } return } + +func checkMemberInRoom(ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID, roomID string) *util.JSONResponse { + tuple := gomatrixserverlib.StateKeyTuple{ + EventType: gomatrixserverlib.MRoomMember, + StateKey: userID, + } + var membershipRes currentstateAPI.QueryCurrentStateResponse + err := stateAPI.QueryCurrentState(ctx, ¤tstateAPI.QueryCurrentStateRequest{ + RoomID: roomID, + StateTuples: []gomatrixserverlib.StateKeyTuple{tuple}, + }, &membershipRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryCurrentState: could not query membership for user") + e := jsonerror.InternalServerError() + return &e + } + ev, ok := membershipRes.StateEvents[tuple] + if !ok { + return &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("user does not belong to room"), + } + } + membership, err := ev.Membership() + if err != nil || membership != "join" { + return &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("user does not belong to room"), + } + } + return nil +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index deaa7b329..57bb921d9 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/internal/transactions" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" @@ -290,6 +291,34 @@ func Setup( return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI) }), ).Methods(http.MethodDelete, http.MethodOptions) + r0mux.Handle("/directory/list/room/{roomID}", + httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetVisibility(req, rsAPI, vars["roomID"]) + }), + ).Methods(http.MethodGet, http.MethodOptions) + // TODO: Add AS support + r0mux.Handle("/directory/list/room/{roomID}", + httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return SetVisibility(req, stateAPI, rsAPI, device, vars["roomID"]) + }), + ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/publicRooms", + httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { + /* TODO: + if extRoomsProvider != nil { + return GetPostPublicRoomsWithExternal(req, stateAPI, fedClient, extRoomsProvider) + } */ + return GetPostPublicRooms(req, rsAPI, stateAPI) + }), + ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) r0mux.Handle("/logout", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *api.Device) util.JSONResponse { diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go index b16306ab0..729a66baf 100644 --- a/currentstateserver/api/api.go +++ b/currentstateserver/api/api.go @@ -29,6 +29,8 @@ type CurrentStateInternalAPI interface { QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error // QueryRoomsForUser retrieves a list of room IDs matching the given query. QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error + // QueryBulkStateContent does a bulk query for state event content in the given rooms. + QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error } type QueryRoomsForUserRequest struct { @@ -41,6 +43,30 @@ type QueryRoomsForUserResponse struct { RoomIDs []string } +type QueryBulkStateContentRequest struct { + // Returns state events in these rooms + RoomIDs []string + // If true, treats the '*' StateKey as "all state events of this type" rather than a literal value of '*' + AllowWildcards bool + // The state events to return. Only a small subset of tuples are allowed in this request as only certain events + // have their content fields extracted. Specifically, the tuple Type must be one of: + // m.room.avatar + // m.room.create + // m.room.canonical_alias + // m.room.guest_access + // m.room.history_visibility + // m.room.join_rules + // m.room.member + // m.room.name + // m.room.topic + // Any other tuple type will result in the query failing. + StateTuples []gomatrixserverlib.StateKeyTuple +} +type QueryBulkStateContentResponse struct { + // map of room ID -> tuple -> content_value + Rooms map[string]map[gomatrixserverlib.StateKeyTuple]string +} + type QueryCurrentStateRequest struct { RoomID string StateTuples []gomatrixserverlib.StateKeyTuple diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go index 85fbf51ef..c28760477 100644 --- a/currentstateserver/internal/api.go +++ b/currentstateserver/internal/api.go @@ -48,3 +48,23 @@ func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *ap res.RoomIDs = roomIDs return nil } + +func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { + events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) + if err != nil { + return err + } + res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) + for _, ev := range events { + if res.Rooms[ev.RoomID] == nil { + res.Rooms[ev.RoomID] = make(map[gomatrixserverlib.StateKeyTuple]string) + } + room := res.Rooms[ev.RoomID] + room[gomatrixserverlib.StateKeyTuple{ + EventType: ev.EventType, + StateKey: ev.StateKey, + }] = ev.ContentValue + res.Rooms[ev.RoomID] = room + } + return nil +} diff --git a/currentstateserver/inthttp/client.go b/currentstateserver/inthttp/client.go index 6fd9907bd..b8c6a1198 100644 --- a/currentstateserver/inthttp/client.go +++ b/currentstateserver/inthttp/client.go @@ -26,8 +26,9 @@ import ( // HTTP paths for the internal HTTP APIs const ( - QueryCurrentStatePath = "/currentstateserver/queryCurrentState" - QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser" + QueryCurrentStatePath = "/currentstateserver/queryCurrentState" + QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser" + QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent" ) // NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API. @@ -73,3 +74,15 @@ func (h *httpCurrentStateInternalAPI) QueryRoomsForUser( apiURL := h.apiURL + QueryRoomsForUserPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +func (h *httpCurrentStateInternalAPI) QueryBulkStateContent( + ctx context.Context, + request *api.QueryBulkStateContentRequest, + response *api.QueryBulkStateContentResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent") + defer span.Finish() + + apiURL := h.apiURL + QueryBulkStateContentPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/currentstateserver/inthttp/server.go b/currentstateserver/inthttp/server.go index fa7ecb22e..dafb9f643 100644 --- a/currentstateserver/inthttp/server.go +++ b/currentstateserver/inthttp/server.go @@ -51,4 +51,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryBulkStateContentPath, + httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { + request := api.QueryBulkStateContentRequest{} + response := api.QueryBulkStateContentResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.QueryBulkStateContent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/currentstateserver/storage/interface.go b/currentstateserver/storage/interface.go index dbf223f33..04636bafb 100644 --- a/currentstateserver/storage/interface.go +++ b/currentstateserver/storage/interface.go @@ -17,6 +17,7 @@ package storage import ( "context" + "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) @@ -31,4 +32,7 @@ type Database interface { GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) + // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. + // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. + GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) } diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go index f423cc194..79a8d40f8 100644 --- a/currentstateserver/storage/postgres/current_room_state_table.go +++ b/currentstateserver/storage/postgres/current_room_state_table.go @@ -203,3 +203,9 @@ func (s *currentRoomStateStatements) SelectStateEvent( } return &ev, err } + +func (s *currentRoomStateStatements) SelectBulkStateContent( + ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool, +) ([]tables.StrippedEvent, error) { + return nil, nil +} diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go index cc5e7cda6..cd59ac129 100644 --- a/currentstateserver/storage/shared/storage.go +++ b/currentstateserver/storage/shared/storage.go @@ -32,6 +32,10 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s return d.CurrentRoomState.SelectStateEvent(ctx, roomID, evType, stateKey) } +func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { + return d.CurrentRoomState.SelectBulkStateContent(ctx, roomIDs, tuples, allowWildcards) +} + func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatrixserverlib.HeaderedEvent, removeStateEventIDs []string) error { return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index 7030fd099..672aaff26 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -192,3 +192,9 @@ func (s *currentRoomStateStatements) SelectStateEvent( } return &ev, err } + +func (s *currentRoomStateStatements) SelectBulkStateContent( + ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool, +) ([]tables.StrippedEvent, error) { + return nil, nil +} diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go index 42dc58ccc..4cd547ff7 100644 --- a/currentstateserver/storage/tables/interface.go +++ b/currentstateserver/storage/tables/interface.go @@ -33,6 +33,15 @@ type CurrentRoomState interface { DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) + SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error) +} + +// StrippedEvent represents a stripped event for returning extracted content values. +type StrippedEvent struct { + RoomID string + EventType string + StateKey string + ContentValue string } // ExtractContentValue from the given state event. For example, given an m.room.name event with: diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 3f5d5f4e0..bfbdaa5ff 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -111,6 +111,13 @@ func (t *testRoomserverAPI) PerformJoin( ) { } +func (t *testRoomserverAPI) PerformPublish( + ctx context.Context, + req *api.PerformPublishRequest, + res *api.PerformPublishResponse, +) { +} + func (t *testRoomserverAPI) PerformLeave( ctx context.Context, req *api.PerformLeaveRequest, @@ -168,6 +175,14 @@ func (t *testRoomserverAPI) QueryMembershipForUser( return fmt.Errorf("not implemented") } +func (t *testRoomserverAPI) QueryPublishedRooms( + ctx context.Context, + request *api.QueryPublishedRoomsRequest, + response *api.QueryPublishedRoomsResponse, +) error { + return fmt.Errorf("not implemented") +} + // Query a list of membership events for a room func (t *testRoomserverAPI) QueryMembershipsForRoom( ctx context.Context, diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 26ec8ca1d..0a5845dd6 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -36,6 +36,18 @@ type RoomserverInternalAPI interface { res *PerformLeaveResponse, ) error + PerformPublish( + ctx context.Context, + req *PerformPublishRequest, + res *PerformPublishResponse, + ) + + QueryPublishedRooms( + ctx context.Context, + req *QueryPublishedRoomsRequest, + res *QueryPublishedRoomsResponse, + ) error + // Query the latest events and state for a room from the room server. QueryLatestEventsAndState( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 8645b6f28..bdebc57b0 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -57,6 +57,25 @@ func (t *RoomserverInternalAPITrace) PerformLeave( return err } +func (t *RoomserverInternalAPITrace) PerformPublish( + ctx context.Context, + req *PerformPublishRequest, + res *PerformPublishResponse, +) { + t.Impl.PerformPublish(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) +} + +func (t *RoomserverInternalAPITrace) QueryPublishedRooms( + ctx context.Context, + req *QueryPublishedRoomsRequest, + res *QueryPublishedRoomsResponse, +) error { + err := t.Impl.QueryPublishedRooms(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryPublishedRooms req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) QueryLatestEventsAndState( ctx context.Context, req *QueryLatestEventsAndStateRequest, diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 5d8d88a5a..9e8447339 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -136,3 +136,13 @@ type PerformBackfillResponse struct { // Missing events, arbritrary order. Events []gomatrixserverlib.HeaderedEvent `json:"events"` } + +type PerformPublishRequest struct { + RoomID string + Visibility string +} + +type PerformPublishResponse struct { + // If non-nil, the publish request failed. Contains more information why it failed. + Error *PerformError +} diff --git a/roomserver/api/query.go b/roomserver/api/query.go index f0cb9374b..4e1d09c30 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -215,3 +215,13 @@ type QueryRoomVersionForRoomRequest struct { type QueryRoomVersionForRoomResponse struct { RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` } + +type QueryPublishedRoomsRequest struct { + // Optional. If specified, returns whether this room is published or not. + RoomID string +} + +type QueryPublishedRoomsResponse struct { + // The list of published rooms. + RoomIDs []string +} diff --git a/roomserver/internal/perform_publish.go b/roomserver/internal/perform_publish.go new file mode 100644 index 000000000..d7863620a --- /dev/null +++ b/roomserver/internal/perform_publish.go @@ -0,0 +1,20 @@ +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/api" +) + +func (r *RoomserverInternalAPI) PerformPublish( + ctx context.Context, + req *api.PerformPublishRequest, + res *api.PerformPublishResponse, +) { + err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public") + if err != nil { + res.Error = &api.PerformError{ + Msg: err.Error(), + } + } +} diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go index 19236bfbd..7fa3247a6 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query.go @@ -930,3 +930,16 @@ func (r *RoomserverInternalAPI) QueryRoomVersionForRoom( r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) return nil } + +func (r *RoomserverInternalAPI) QueryPublishedRooms( + ctx context.Context, + req *api.QueryPublishedRoomsRequest, + res *api.QueryPublishedRoomsResponse, +) error { + rooms, err := r.DB.GetPublishedRooms(ctx) + if err != nil { + return err + } + res.RoomIDs = rooms + return nil +} diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 8a2b1204c..ad24af4ad 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -29,6 +29,7 @@ const ( RoomserverPerformJoinPath = "/roomserver/performJoin" RoomserverPerformLeavePath = "/roomserver/performLeave" RoomserverPerformBackfillPath = "/roomserver/performBackfill" + RoomserverPerformPublishPath = "/roomserver/performPublish" // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" @@ -41,6 +42,7 @@ const ( RoomserverQueryStateAndAuthChainPath = "/roomserver/queryStateAndAuthChain" RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities" RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom" + RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms" ) type httpRoomserverInternalAPI struct { @@ -194,6 +196,23 @@ func (h *httpRoomserverInternalAPI) PerformLeave( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpRoomserverInternalAPI) PerformPublish( + ctx context.Context, + req *api.PerformPublishRequest, + res *api.PerformPublishResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPublish") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformPublishPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = &api.PerformError{ + Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), + } + } +} + // QueryLatestEventsAndState implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( ctx context.Context, @@ -233,6 +252,18 @@ func (h *httpRoomserverInternalAPI) QueryEventsByID( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpRoomserverInternalAPI) QueryPublishedRooms( + ctx context.Context, + request *api.QueryPublishedRoomsRequest, + response *api.QueryPublishedRoomsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublishedRooms") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryPublishedRoomsPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + // QueryMembershipForUser implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryMembershipForUser( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 1c47e87e2..bb54abf9c 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -61,6 +61,31 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverPerformPublishPath, + httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse { + var request api.PerformPublishRequest + var response api.PerformPublishResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + r.PerformPublish(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle( + RoomserverQueryPublishedRoomsPath, + httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse { + var request api.QueryPublishedRoomsRequest + var response api.QueryPublishedRoomsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryPublishedRooms(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle( RoomserverQueryLatestEventsAndStatePath, httputil.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 0c4e2e0b5..5c916f294 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -139,4 +139,8 @@ type Database interface { EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) // Look up the room version for a given room. GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) + // Publish or unpublish a room from the room directory. + PublishRoom(ctx context.Context, roomID string, publish bool) error + // Returns a list of room IDs for rooms which are published. + GetPublishedRooms(ctx context.Context) ([]string, error) } diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go new file mode 100644 index 000000000..55df2ef28 --- /dev/null +++ b/roomserver/storage/postgres/published_table.go @@ -0,0 +1,101 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" +) + +const publishedSchema = ` +-- Stores which rooms are published in the room directory +CREATE TABLE IF NOT EXISTS roomserver_published ( + -- The room ID of the room + room_id TEXT NOT NULL PRIMARY KEY, + -- Whether it is published or not + published BOOLEAN NOT NULL DEFAULT false +); +` + +const upsertPublishedSQL = "" + + "INSERT INTO roomserver_published (room_id, published) VALUES ($1, $2) " + + "ON CONFLICT room_id DO UPDATE SET published=$2" + +const selectAllPublishedSQL = "" + + "SELECT room_id FROM roomserver_published WHERE published = $1" + +const selectPublishedSQL = "" + + "SELECT published FROM roomserver_published WHERE room_id = $1" + +type publishedStatements struct { + upsertPublishedStmt *sql.Stmt + selectAllPublishedStmt *sql.Stmt + selectPublishedStmt *sql.Stmt +} + +func NewPostgresPublishedTable(db *sql.DB) (tables.Published, error) { + s := &publishedStatements{} + _, err := db.Exec(publishedSchema) + if err != nil { + return nil, err + } + return s, shared.StatementList{ + {&s.upsertPublishedStmt, upsertPublishedSQL}, + {&s.selectAllPublishedStmt, selectAllPublishedSQL}, + {&s.selectPublishedStmt, selectPublishedSQL}, + }.Prepare(db) +} + +func (s *publishedStatements) UpsertRoomPublished( + ctx context.Context, roomID string, published bool, +) (err error) { + _, err = s.upsertPublishedStmt.ExecContext(ctx, roomID, published) + return +} + +func (s *publishedStatements) SelectPublishedFromRoomID( + ctx context.Context, roomID string, +) (published bool, err error) { + err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + if err == sql.ErrNoRows { + return false, nil + } + return +} + +func (s *publishedStatements) SelectAllPublishedRooms( + ctx context.Context, published bool, +) ([]string, error) { + rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed") + + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + + roomIDs = append(roomIDs, roomID) + } + return roomIDs, rows.Err() +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index d76ee0a92..23d078e4a 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -87,6 +87,10 @@ func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, if err != nil { return nil, err } + published, err := NewPostgresPublishedTable(db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: db, EventTypesTable: eventTypes, @@ -101,6 +105,7 @@ func Open(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, RoomAliasesTable: roomAliases, InvitesTable: invites, MembershipTable: membership, + PublishedTable: published, } return &d, nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e6d0e34e2..166822d0c 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -26,6 +26,7 @@ type Database struct { PrevEventsTable tables.PreviousEvents InvitesTable tables.Invites MembershipTable tables.Membership + PublishedTable tables.Published } func (d *Database) EventTypeNIDs( @@ -420,6 +421,14 @@ func (d *Database) StoreEvent( }, nil } +func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error { + return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish) +} + +func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { + return d.PublishedTable.SelectAllPublishedRooms(ctx, true) +} + func (d *Database) assignRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go new file mode 100644 index 000000000..a94e60719 --- /dev/null +++ b/roomserver/storage/sqlite3/published_table.go @@ -0,0 +1,100 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" +) + +const publishedSchema = ` +-- Stores which rooms are published in the room directory +CREATE TABLE IF NOT EXISTS roomserver_published ( + -- The room ID of the room + room_id TEXT NOT NULL PRIMARY KEY, + -- Whether it is published or not + published BOOLEAN NOT NULL DEFAULT false +); +` + +const upsertPublishedSQL = "" + + "INSERT OR REPLACE INTO roomserver_published (room_id, published) VALUES ($1, $2)" + +const selectAllPublishedSQL = "" + + "SELECT room_id FROM roomserver_published WHERE published = $1" + +const selectPublishedSQL = "" + + "SELECT published FROM roomserver_published WHERE room_id = $1" + +type publishedStatements struct { + upsertPublishedStmt *sql.Stmt + selectAllPublishedStmt *sql.Stmt + selectPublishedStmt *sql.Stmt +} + +func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { + s := &publishedStatements{} + _, err := db.Exec(publishedSchema) + if err != nil { + return nil, err + } + return s, shared.StatementList{ + {&s.upsertPublishedStmt, upsertPublishedSQL}, + {&s.selectAllPublishedStmt, selectAllPublishedSQL}, + {&s.selectPublishedStmt, selectPublishedSQL}, + }.Prepare(db) +} + +func (s *publishedStatements) UpsertRoomPublished( + ctx context.Context, roomID string, published bool, +) (err error) { + _, err = s.upsertPublishedStmt.ExecContext(ctx, roomID, published) + return +} + +func (s *publishedStatements) SelectPublishedFromRoomID( + ctx context.Context, roomID string, +) (published bool, err error) { + err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + if err == sql.ErrNoRows { + return false, nil + } + return +} + +func (s *publishedStatements) SelectAllPublishedRooms( + ctx context.Context, published bool, +) ([]string, error) { + rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed") + + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + + roomIDs = append(roomIDs, roomID) + } + return roomIDs, rows.Err() +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 8e9352192..767b13ce0 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -110,6 +110,10 @@ func Open(dataSourceName string) (*Database, error) { if err != nil { return nil, err } + published, err := NewSqlitePublishedTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, EventsTable: d.events, @@ -124,6 +128,7 @@ func Open(dataSourceName string) (*Database, error) { RoomAliasesTable: roomAliases, InvitesTable: d.invites, MembershipTable: d.membership, + PublishedTable: published, } return &d, nil } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 3aa8c538c..7499089ca 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -120,3 +120,9 @@ type Membership interface { SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error } + +type Published interface { + UpsertRoomPublished(ctx context.Context, roomID string, published bool) (err error) + SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error) + SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error) +}