diff --git a/build/docker/config/dendrite.yaml b/build/docker/config/dendrite.yaml index e3a0316dc..94dcf4558 100644 --- a/build/docker/config/dendrite.yaml +++ b/build/docker/config/dendrite.yaml @@ -140,7 +140,7 @@ client_api: # Prevents new users from being able to register on this homeserver, except when # using the registration shared secret below. - registration_disabled: false + registration_disabled: true # If set, allows registration by anyone who knows the shared secret, regardless of # whether registration is otherwise disabled. diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 9cc94d650..d047f3fff 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -259,6 +259,8 @@ func (m *DendriteMonolith) Start() { cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true if err := cfg.Derive(); err != nil { panic(err) } @@ -314,6 +316,7 @@ func (m *DendriteMonolith) Start() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 87dcad2e8..4e95e3972 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -97,6 +97,8 @@ func (m *DendriteMonolith) Start() { cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-appservice.db", m.StorageDirectory)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true if err = cfg.Derive(); err != nil { panic(err) } @@ -152,6 +154,7 @@ func (m *DendriteMonolith) Start() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) httpRouter := mux.NewRouter() diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 6b2942d97..63e3890ee 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -29,4 +29,4 @@ EXPOSE 8008 8448 CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} + ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile index 60b4d983a..a9feb4cd1 100644 --- a/build/scripts/ComplementLocal.Dockerfile +++ b/build/scripts/ComplementLocal.Dockerfile @@ -32,7 +32,7 @@ RUN echo '\ ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ -./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ +./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ ' > run.sh && chmod +x run.sh diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index b98f4671c..4e26faa58 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -51,4 +51,4 @@ CMD /build/run_postgres.sh && ./generate-keys --server $SERVER_NAME --tls-cert s sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \ sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} \ No newline at end of file + ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} \ No newline at end of file diff --git a/build/scripts/find-lint.sh b/build/scripts/find-lint.sh index e3564ae38..820b8cc46 100755 --- a/build/scripts/find-lint.sh +++ b/build/scripts/find-lint.sh @@ -25,7 +25,7 @@ echo "Installing golangci-lint..." # Make a backup of go.{mod,sum} first cp go.mod go.mod.bak && cp go.sum go.sum.bak -go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.41.1 +go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.45.2 # Run linting echo "Looking for lint..." @@ -33,7 +33,7 @@ echo "Looking for lint..." # Capture exit code to ensure go.{mod,sum} is restored before exiting exit_code=0 -PATH="$PATH:${GOPATH:-~/go}/bin" golangci-lint run $args || exit_code=1 +PATH="$PATH:$(go env GOPATH)/bin" golangci-lint run $args || exit_code=1 # Restore go.{mod,sum} mv go.mod.bak go.mod && mv go.sum.bak go.sum diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index e2f8d3f32..ad277056c 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -36,6 +36,7 @@ func AddPublicRoutes( process *process.ProcessContext, router *mux.Router, synapseAdminRouter *mux.Router, + dendriteAdminRouter *mux.Router, cfg *config.ClientAPI, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, @@ -62,7 +63,8 @@ func AddPublicRoutes( } routing.Setup( - router, synapseAdminRouter, cfg, rsAPI, asAPI, + router, synapseAdminRouter, dendriteAdminRouter, + cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg, natsClient, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f370b4f8c..ec90b80db 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -48,7 +48,8 @@ import ( // applied: // nolint: gocyclo func Setup( - publicAPIMux, synapseAdminRouter *mux.Router, cfg *config.ClientAPI, + publicAPIMux, synapseAdminRouter, dendriteAdminRouter *mux.Router, + cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, userAPI userapi.UserInternalAPI, @@ -119,6 +120,45 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } + dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", + httputil.MakeAuthAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if device.AccountType != userapi.AccountTypeAdmin { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("This API can only be used by admin users."), + } + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + roomID, ok := vars["roomID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Expecting room ID."), + } + } + res := &roomserverAPI.PerformAdminEvacuateRoomResponse{} + rsAPI.PerformAdminEvacuateRoom( + req.Context(), + &roomserverAPI.PerformAdminEvacuateRoomRequest{ + RoomID: roomID, + }, + res, + ) + if err := res.Error; err != nil { + return err.JSONResponse() + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "affected": res.Affected, + }, + } + }), + ).Methods(http.MethodGet, http.MethodOptions) + // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index dd1ab3697..785e7b460 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -140,6 +140,8 @@ func main() { cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true if err := cfg.Derive(); err != nil { panic(err) } @@ -193,6 +195,7 @@ func main() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) wsUpgrader := websocket.Upgrader{ diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index b840eb2b8..f9234319a 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -89,6 +89,8 @@ func main() { cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.MSCs.MSCs = []string{"msc2836"} cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true if err = cfg.Derive(); err != nil { panic(err) } @@ -150,6 +152,7 @@ func main() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) if err := mscs.Enable(base, &monolith); err != nil { logrus.WithError(err).Fatalf("Failed to enable MSCs") diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 1443ab5b1..5fd5c0b5d 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -153,6 +153,7 @@ func main() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) if len(base.Cfg.MSCs.MSCs) > 0 { diff --git a/cmd/dendrite-polylith-multi/personalities/clientapi.go b/cmd/dendrite-polylith-multi/personalities/clientapi.go index 1e509f88a..7ed2075aa 100644 --- a/cmd/dendrite-polylith-multi/personalities/clientapi.go +++ b/cmd/dendrite-polylith-multi/personalities/clientapi.go @@ -31,8 +31,10 @@ func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { keyAPI := base.KeyServerHTTPClient() clientapi.AddPublicRoutes( - base.ProcessContext, base.PublicClientAPIMux, base.SynapseAdminMux, &base.Cfg.ClientAPI, - federation, rsAPI, asQuery, transactions.New(), fsAPI, userAPI, userAPI, + base.ProcessContext, base.PublicClientAPIMux, + base.SynapseAdminMux, base.DendriteAdminMux, + &base.Cfg.ClientAPI, federation, rsAPI, asQuery, + transactions.New(), fsAPI, userAPI, userAPI, keyAPI, nil, &cfg.MSCs, ) diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 3241234ac..b7e7da07d 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -83,7 +83,8 @@ do \n\ done \n\ \n\ sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\ -./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ +PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\ +./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\ ' > run_dendrite.sh && chmod +x run_dendrite.sh ENV SERVER_NAME=localhost diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index 211b3e131..5ecf1f2fb 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -171,6 +171,8 @@ func startup() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.PrivateKey = sk cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true if err := cfg.Derive(); err != nil { logrus.Fatalf("Failed to derive values from config: %s", err) @@ -220,6 +222,7 @@ func startup() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 24085afaa..1c585d916 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -90,6 +90,8 @@ func main() { cfg.Logging[0].Type = "std" cfg.UserAPI.BCryptCost = bcrypt.MinCost cfg.Global.JetStream.InMemory = true + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true cfg.ClientAPI.RegistrationSharedSecret = "complement" cfg.Global.Presence = config.PresenceOptions{ EnableInbound: true, diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 47f08c4fd..1c11ef96d 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -159,7 +159,7 @@ client_api: # Prevents new users from being able to register on this homeserver, except when # using the registration shared secret below. - registration_disabled: false + registration_disabled: true # Prevents new guest accounts from being created. Guest registration is also # disabled implicitly by setting 'registration_disabled' above. diff --git a/roomserver/api/api.go b/roomserver/api/api.go index fb77423f8..f0ca8a615 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -66,6 +66,12 @@ type RoomserverInternalAPI interface { res *PerformInboundPeekResponse, ) error + PerformAdminEvacuateRoom( + ctx context.Context, + req *PerformAdminEvacuateRoomRequest, + res *PerformAdminEvacuateRoomResponse, + ) + QueryPublishedRooms( ctx context.Context, req *QueryPublishedRoomsRequest, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index ec7211ef8..61c06e886 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -104,6 +104,15 @@ func (t *RoomserverInternalAPITrace) PerformPublish( util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) } +func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom( + ctx context.Context, + req *PerformAdminEvacuateRoomRequest, + res *PerformAdminEvacuateRoomResponse, +) { + t.Impl.PerformAdminEvacuateRoom(ctx, req, res) + util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res)) +} + func (t *RoomserverInternalAPITrace) PerformInboundPeek( ctx context.Context, req *PerformInboundPeekRequest, diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index cda4b3ee4..30aa2cf1b 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -214,3 +214,12 @@ type PerformRoomUpgradeResponse struct { NewRoomID string Error *PerformError } + +type PerformAdminEvacuateRoomRequest struct { + RoomID string `json:"room_id"` +} + +type PerformAdminEvacuateRoomResponse struct { + Affected []string `json:"affected"` + Error *PerformError +} diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 59f485cf7..267cd4099 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -35,6 +35,7 @@ type RoomserverInternalAPI struct { *perform.Backfiller *perform.Forgetter *perform.Upgrader + *perform.Admin ProcessContext *process.ProcessContext DB storage.Database Cfg *config.RoomServer @@ -164,6 +165,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA Cfg: r.Cfg, URSAPI: r, } + r.Admin = &perform.Admin{ + DB: r.DB, + Cfg: r.Cfg, + Inputer: r.Inputer, + Queryer: r.Queryer, + } if err := r.Inputer.Start(); err != nil { logrus.WithError(err).Panic("failed to start roomserver input API") diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go new file mode 100644 index 000000000..2de6477cc --- /dev/null +++ b/roomserver/internal/perform/perform_admin.go @@ -0,0 +1,162 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/internal/query" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +type Admin struct { + DB storage.Database + Cfg *config.RoomServer + Queryer *query.Queryer + Inputer *input.Inputer +} + +// PerformEvacuateRoom will remove all local users from the given room. +func (r *Admin) PerformAdminEvacuateRoom( + ctx context.Context, + req *api.PerformAdminEvacuateRoomRequest, + res *api.PerformAdminEvacuateRoomResponse, +) { + roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err), + } + return + } + if roomInfo == nil || roomInfo.IsStub { + res.Error = &api.PerformError{ + Code: api.PerformErrorNoRoom, + Msg: fmt.Sprintf("Room %s not found", req.RoomID), + } + return + } + + memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err), + } + return + } + + memberEvents, err := r.DB.Events(ctx, memberNIDs) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.DB.Events: %s", err), + } + return + } + + inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) + res.Affected = make([]string, 0, len(memberEvents)) + latestReq := &api.QueryLatestEventsAndStateRequest{ + RoomID: req.RoomID, + } + latestRes := &api.QueryLatestEventsAndStateResponse{} + if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err), + } + return + } + + prevEvents := latestRes.LatestEvents + for _, memberEvent := range memberEvents { + if memberEvent.StateKey() == nil { + continue + } + + var memberContent gomatrixserverlib.MemberContent + if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("json.Unmarshal: %s", err), + } + return + } + memberContent.Membership = gomatrixserverlib.Leave + + stateKey := *memberEvent.StateKey() + fledglingEvent := &gomatrixserverlib.EventBuilder{ + RoomID: req.RoomID, + Type: gomatrixserverlib.MRoomMember, + StateKey: &stateKey, + Sender: stateKey, + PrevEvents: prevEvents, + } + + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("json.Marshal: %s", err), + } + return + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err), + } + return + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err), + } + return + } + + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: r.Cfg.Matrix.ServerName, + SendAsServer: string(r.Cfg.Matrix.ServerName), + }) + res.Affected = append(res.Affected, stateKey) + prevEvents = []gomatrixserverlib.EventReference{ + event.EventReference(), + } + } + + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: inputEvents, + Asynchronous: true, + } + inputRes := &api.InputRoomEventsResponse{} + r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) +} diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index d55805a91..3b29001e9 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -29,16 +29,17 @@ const ( RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" // Perform operations - RoomserverPerformInvitePath = "/roomserver/performInvite" - RoomserverPerformPeekPath = "/roomserver/performPeek" - RoomserverPerformUnpeekPath = "/roomserver/performUnpeek" - RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade" - RoomserverPerformJoinPath = "/roomserver/performJoin" - RoomserverPerformLeavePath = "/roomserver/performLeave" - RoomserverPerformBackfillPath = "/roomserver/performBackfill" - RoomserverPerformPublishPath = "/roomserver/performPublish" - RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" - RoomserverPerformForgetPath = "/roomserver/performForget" + RoomserverPerformInvitePath = "/roomserver/performInvite" + RoomserverPerformPeekPath = "/roomserver/performPeek" + RoomserverPerformUnpeekPath = "/roomserver/performUnpeek" + RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade" + RoomserverPerformJoinPath = "/roomserver/performJoin" + RoomserverPerformLeavePath = "/roomserver/performLeave" + RoomserverPerformBackfillPath = "/roomserver/performBackfill" + RoomserverPerformPublishPath = "/roomserver/performPublish" + RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" + RoomserverPerformForgetPath = "/roomserver/performForget" + RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom" // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" @@ -299,6 +300,23 @@ func (h *httpRoomserverInternalAPI) PerformPublish( } } +func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom( + ctx context.Context, + req *api.PerformAdminEvacuateRoomRequest, + res *api.PerformAdminEvacuateRoomResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = &api.PerformError{ + Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), + } + } +} + // QueryLatestEventsAndState implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 0b27b5a8d..c5159a63c 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -118,6 +118,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath, + httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse { + var request api.PerformAdminEvacuateRoomRequest + var response api.PerformAdminEvacuateRoomResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + r.PerformAdminEvacuateRoom(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle( RoomserverQueryPublishedRoomsPath, httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse { diff --git a/setup/base/base.go b/setup/base/base.go index 7091c6ba5..4b771aa36 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -126,6 +126,10 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base logrus.Infof("Dendrite version %s", internal.VersionString()) + if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled { + logrus.Warn("Open registration is enabled") + } + closer, err := cfg.SetupTracing("Dendrite" + componentName) if err != nil { logrus.WithError(err).Panicf("failed to start opentracing") diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 4590e752b..6104ed8b9 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -15,6 +15,12 @@ type ClientAPI struct { // If set disables new users from registering (except via shared // secrets) RegistrationDisabled bool `yaml:"registration_disabled"` + + // Enable registration without captcha verification or shared secret. + // This option is populated by the -really-enable-open-registration + // command line parameter as it is not recommended. + OpenRegistrationWithoutVerificationEnabled bool `yaml:"-"` + // If set, allows registration by anyone who also has the shared // secret, even if registration is otherwise disabled. RegistrationSharedSecret string `yaml:"registration_shared_secret"` @@ -55,7 +61,8 @@ func (c *ClientAPI) Defaults(generate bool) { c.RecaptchaEnabled = false c.RecaptchaBypassSecret = "" c.RecaptchaSiteVerifyAPI = "" - c.RegistrationDisabled = false + c.RegistrationDisabled = true + c.OpenRegistrationWithoutVerificationEnabled = false c.RateLimiting.Defaults() } @@ -72,6 +79,20 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { } c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) + + // Ensure there is any spam counter measure when enabling registration + if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled { + if !c.RecaptchaEnabled { + configErrs.Add( + "You have tried to enable open registration without any secondary verification methods " + + "(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " + + "increasing the risk that your server will be used to send spam or abuse, and may result in " + + "your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " + + "start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " + + "should set the registration_disabled option in your Dendrite config.", + ) + } + } } type TURN struct { diff --git a/setup/flags.go b/setup/flags.go index 281cf3392..a9dac61a1 100644 --- a/setup/flags.go +++ b/setup/flags.go @@ -25,8 +25,9 @@ import ( ) var ( - configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.") - version = flag.Bool("version", false, "Shows the current version and exits immediately.") + configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.") + version = flag.Bool("version", false, "Shows the current version and exits immediately.") + enableRegistrationWithoutVerification = flag.Bool("really-enable-open-registration", false, "This allows open registration without secondary verification (reCAPTCHA). This is NOT RECOMMENDED and will SIGNIFICANTLY increase the risk that your server will be used to send spam or conduct attacks, which may result in your server being banned from rooms.") ) // ParseFlags parses the commandline flags and uses them to create a config. @@ -48,5 +49,9 @@ func ParseFlags(monolith bool) *config.Dendrite { logrus.Fatalf("Invalid config file: %s", err) } + if *enableRegistrationWithoutVerification { + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true + } + return cfg } diff --git a/setup/monolith.go b/setup/monolith.go index 32f1a6494..c86ec7b69 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -54,13 +54,13 @@ type Monolith struct { } // AddAllPublicRoutes attaches all public paths to the given router -func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux *mux.Router) { +func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux, dendriteMux *mux.Router) { userDirectoryProvider := m.ExtUserDirectoryProvider if userDirectoryProvider == nil { userDirectoryProvider = m.UserAPI } clientapi.AddPublicRoutes( - process, csMux, synapseMux, &m.Config.ClientAPI, + process, csMux, synapseMux, dendriteMux, &m.Config.ClientAPI, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI, diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 82834239b..87f0d86d7 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -333,6 +333,20 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error { return nil } +// LoadRooms loads the membership states required to notify users correctly. +func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs []string) error { + n.lock.Lock() + defer n.lock.Unlock() + + roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs) + if err != nil { + return err + } + n.setUsersJoinedToRooms(roomToUsers) + + return nil +} + // CurrentPosition returns the current sync position func (n *Notifier) CurrentPosition() types.StreamingToken { n.lock.RLock() diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 43aaa3588..5a036d889 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -52,6 +52,9 @@ type Database interface { // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) + // AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room. + AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) + // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) // Events lookups a list of event by their event ID. @@ -159,6 +162,6 @@ type Database interface { type Presence interface { UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) - PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) + PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) } diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index fe68788d1..8ee387b39 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -93,6 +93,9 @@ const selectCurrentStateSQL = "" + const selectJoinedUsersSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" +const selectJoinedUsersInRoomSQL = "" + + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)" + const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" @@ -112,6 +115,7 @@ type currentRoomStateStatements struct { selectRoomIDsWithAnyMembershipStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt + selectJoinedUsersInRoomStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt } @@ -143,6 +147,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return nil, err } + if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { + return nil, err + } if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { return nil, err } @@ -163,9 +170,32 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") result := make(map[string][]string) + var roomID string + var userID string + for rows.Next() { + if err := rows.Scan(&roomID, &userID); err != nil { + return nil, err + } + users := result[roomID] + users = append(users, userID) + result[roomID] = users + } + return result, rows.Err() +} + +// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. +func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( + ctx context.Context, roomIDs []string, +) (map[string][]string, error) { + rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") + + result := make(map[string][]string) + var userID, roomID string for rows.Next() { - var roomID string - var userID string if err := rows.Scan(&roomID, &userID); err != nil { return nil, err } diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index 9f1e37f79..7194afea6 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "time" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -72,7 +73,8 @@ const selectMaxPresenceSQL = "" + const selectPresenceAfter = "" + " SELECT id, user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE id > $1" + " WHERE id > $1 AND last_active_ts >= $2" + + " ORDER BY id ASC LIMIT $3" type presenceStatements struct { upsertPresenceStmt *sql.Stmt @@ -144,11 +146,12 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) func (p *presenceStatements) GetPresenceAfter( ctx context.Context, txn *sql.Tx, after types.StreamPosition, + filter gomatrixserverlib.EventFilter, ) (presences map[string]*types.PresenceInternal, err error) { presences = make(map[string]*types.PresenceInternal) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) - - rows, err := stmt.QueryContext(ctx, after) + afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5)) + rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit) if err != nil { return nil, err } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 25aca50ae..ec5edd355 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -168,6 +168,10 @@ func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]stri return d.CurrentRoomState.SelectJoinedUsers(ctx) } +func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { + return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, roomIDs) +} + func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { return d.Peeks.SelectPeekingDevices(ctx) } @@ -1056,8 +1060,8 @@ func (s *Database) GetPresence(ctx context.Context, userID string) (*types.Prese return s.Presence.GetPresenceForUser(ctx, nil, userID) } -func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) { - return s.Presence.GetPresenceAfter(ctx, nil, after) +func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { + return s.Presence.GetPresenceAfter(ctx, nil, after, filter) } func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index ccda005c1..f0a1c7bb7 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -77,6 +77,9 @@ const selectCurrentStateSQL = "" + const selectJoinedUsersSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" +const selectJoinedUsersInRoomSQL = "" + + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)" + const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" @@ -97,7 +100,8 @@ type currentRoomStateStatements struct { selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithAnyMembershipStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt - selectStateEventStmt *sql.Stmt + //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic + selectStateEventStmt *sql.Stmt } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { @@ -127,13 +131,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return nil, err } + //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { + // return nil, err + //} if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { return nil, err } return s, nil } -// JoinedMemberLists returns a map of room ID to a list of joined user IDs. +// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. func (s *currentRoomStateStatements) SelectJoinedUsers( ctx context.Context, ) (map[string][]string, error) { @@ -144,9 +151,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") result := make(map[string][]string) + var roomID string + var userID string for rows.Next() { - var roomID string - var userID string if err := rows.Scan(&roomID, &userID); err != nil { return nil, err } @@ -157,6 +164,40 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( return result, nil } +// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. +func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( + ctx context.Context, roomIDs []string, +) (map[string][]string, error) { + query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) + params := make([]interface{}, 0, len(roomIDs)) + for _, roomID := range roomIDs { + params = append(params, roomID) + } + stmt, err := s.db.Prepare(query) + if err != nil { + return nil, fmt.Errorf("SelectJoinedUsersInRoom s.db.Prepare: %w", err) + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsersInRoom: rows.close() failed") + + result := make(map[string][]string) + var userID, roomID string + for rows.Next() { + if err := rows.Scan(&roomID, &userID); err != nil { + return nil, err + } + users := result[roomID] + users = append(users, userID) + result[roomID] = users + } + return result, rows.Err() +} + // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( ctx context.Context, diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index 177a01bf3..b61a825df 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "time" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -71,7 +72,8 @@ const selectMaxPresenceSQL = "" + const selectPresenceAfter = "" + " SELECT id, user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE id > $1" + " WHERE id > $1 AND last_active_ts >= $2" + + " ORDER BY id ASC LIMIT $3" type presenceStatements struct { db *sql.DB @@ -158,12 +160,12 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) // GetPresenceAfter returns the changes presences after a given stream id func (p *presenceStatements) GetPresenceAfter( ctx context.Context, txn *sql.Tx, - after types.StreamPosition, + after types.StreamPosition, filter gomatrixserverlib.EventFilter, ) (presences map[string]*types.PresenceInternal, err error) { presences = make(map[string]*types.PresenceInternal) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) - - rows, err := stmt.QueryContext(ctx, after) + afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5)) + rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit) if err != nil { return nil, err } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 4ff4689ed..ccdebfdbd 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -102,6 +102,8 @@ type CurrentRoomState interface { SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. SelectJoinedUsers(ctx context.Context) (map[string][]string, error) + // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. + SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) } // BackwardsExtremities keeps track of backwards extremities for a room. @@ -188,5 +190,5 @@ type Presence interface { UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) - GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition) (presences map[string]*types.PresenceInternal, err error) + GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) } diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 614b88d48..a84d19878 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -53,7 +53,8 @@ func (p *PresenceStreamProvider) IncrementalSync( req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { - presences, err := p.DB.PresenceAfter(ctx, from) + // We pull out a larger number than the filter asks for, since we're filtering out events later + presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) if err != nil { req.Log.WithError(err).Error("p.DB.PresenceAfter failed") return from @@ -66,12 +67,12 @@ func (p *PresenceStreamProvider) IncrementalSync( // add newly joined rooms user presences newlyJoined := joinedRooms(req.Response, req.Device.UserID) if len(newlyJoined) > 0 { - // TODO: This refreshes all lists and is quite expensive - // The notifier should update the lists itself - if err = p.notifier.Load(ctx, p.DB); err != nil { + // TODO: Check if this is working better than before. + if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { req.Log.WithError(err).Error("unable to refresh notifier lists") return from } + NewlyJoinedLoop: for _, roomID := range newlyJoined { roomUsers := p.notifier.JoinedUsers(roomID) for i := range roomUsers { @@ -86,11 +87,14 @@ func (p *PresenceStreamProvider) IncrementalSync( req.Log.WithError(err).Error("unable to query presence for user") return from } + if len(presences) > req.Filter.Presence.Limit { + break NewlyJoinedLoop + } } } } - lastPos := to + lastPos := from for _, presence := range presences { if presence == nil { continue @@ -135,6 +139,9 @@ func (p *PresenceStreamProvider) IncrementalSync( if presence.StreamPos > lastPos { lastPos = presence.StreamPos } + if len(req.Response.Presence.Events) == req.Filter.Presence.Limit { + break + } p.cache.Store(cacheKey, presence) } diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index a80089945..5e52bc7c9 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -30,7 +30,7 @@ func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.Presenc return &types.PresenceInternal{}, nil } -func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) { +func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { return map[string]*types.PresenceInternal{}, nil } diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index fe8c54e04..6c777982f 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -75,7 +75,7 @@ const selectDeviceByTokenSQL = "" + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" @@ -215,15 +215,22 @@ func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device - var displayName sql.NullString + var displayName, ip sql.NullString + var lastseenTS sql.NullInt64 stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) if displayName.Valid { dev.DisplayName = displayName.String } + if lastseenTS.Valid { + dev.LastSeenTS = lastseenTS.Int64 + } + if ip.Valid { + dev.LastSeenIP = ip.String + } } return &dev, err } diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 7860bd6a2..b86ed1cc2 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -60,7 +60,7 @@ const selectDeviceByTokenSQL = "" + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" @@ -212,15 +212,22 @@ func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device - var displayName sql.NullString + var displayName, ip sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) + var lastseenTS sql.NullInt64 + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) if displayName.Valid { dev.DisplayName = displayName.String } + if lastseenTS.Valid { + dev.LastSeenTS = lastseenTS.Int64 + } + if ip.Valid { + dev.LastSeenIP = ip.String + } } return &dev, err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index e6c7d35fc..2eb57d0bc 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -180,12 +180,12 @@ func Test_Devices(t *testing.T) { deviceWithID.DisplayName = newName deviceWithID.LastSeenIP = "127.0.0.1" deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second))) - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 2, len(devices)) - assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName) - assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP) - truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second) + assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName) + assert.Equal(t, deviceWithID.LastSeenIP, gotDevice.LastSeenIP) + truncatedTime := gomatrixserverlib.Timestamp(gotDevice.LastSeenTS).Time().Truncate(time.Second) assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime)) // create one more device and remove the devices step by step