Merge branch 'main' into s7evink/expireedus

This commit is contained in:
Neil Alexander 2022-04-29 08:56:24 +01:00 committed by GitHub
commit e0c5940d6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
41 changed files with 495 additions and 60 deletions

View file

@ -140,7 +140,7 @@ client_api:
# Prevents new users from being able to register on this homeserver, except when # Prevents new users from being able to register on this homeserver, except when
# using the registration shared secret below. # using the registration shared secret below.
registration_disabled: false registration_disabled: true
# If set, allows registration by anyone who knows the shared secret, regardless of # If set, allows registration by anyone who knows the shared secret, regardless of
# whether registration is otherwise disabled. # whether registration is otherwise disabled.

View file

@ -259,6 +259,8 @@ func (m *DendriteMonolith) Start() {
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/media", m.CacheDirectory))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err := cfg.Derive(); err != nil { if err := cfg.Derive(); err != nil {
panic(err) panic(err)
} }
@ -314,6 +316,7 @@ func (m *DendriteMonolith) Start() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()

View file

@ -97,6 +97,8 @@ func (m *DendriteMonolith) Start() {
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-appservice.db", m.StorageDirectory)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-appservice.db", m.StorageDirectory))
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err = cfg.Derive(); err != nil { if err = cfg.Derive(); err != nil {
panic(err) panic(err)
} }
@ -152,6 +154,7 @@ func (m *DendriteMonolith) Start() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
httpRouter := mux.NewRouter() httpRouter := mux.NewRouter()

View file

@ -29,4 +29,4 @@ EXPOSE 8008 8448
CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \
./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}

View file

@ -32,7 +32,7 @@ RUN echo '\
./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\
./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\
./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\
' > run.sh && chmod +x run.sh ' > run.sh && chmod +x run.sh

View file

@ -51,4 +51,4 @@ CMD /build/run_postgres.sh && ./generate-keys --server $SERVER_NAME --tls-cert s
sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \ sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \
sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \ sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \
cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \
./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0}

View file

@ -25,7 +25,7 @@ echo "Installing golangci-lint..."
# Make a backup of go.{mod,sum} first # Make a backup of go.{mod,sum} first
cp go.mod go.mod.bak && cp go.sum go.sum.bak cp go.mod go.mod.bak && cp go.sum go.sum.bak
go get github.com/golangci/golangci-lint/cmd/golangci-lint@v1.41.1 go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.45.2
# Run linting # Run linting
echo "Looking for lint..." echo "Looking for lint..."
@ -33,7 +33,7 @@ echo "Looking for lint..."
# Capture exit code to ensure go.{mod,sum} is restored before exiting # Capture exit code to ensure go.{mod,sum} is restored before exiting
exit_code=0 exit_code=0
PATH="$PATH:${GOPATH:-~/go}/bin" golangci-lint run $args || exit_code=1 PATH="$PATH:$(go env GOPATH)/bin" golangci-lint run $args || exit_code=1
# Restore go.{mod,sum} # Restore go.{mod,sum}
mv go.mod.bak go.mod && mv go.sum.bak go.sum mv go.mod.bak go.mod && mv go.sum.bak go.sum

View file

@ -36,6 +36,7 @@ func AddPublicRoutes(
process *process.ProcessContext, process *process.ProcessContext,
router *mux.Router, router *mux.Router,
synapseAdminRouter *mux.Router, synapseAdminRouter *mux.Router,
dendriteAdminRouter *mux.Router,
cfg *config.ClientAPI, cfg *config.ClientAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
@ -62,7 +63,8 @@ func AddPublicRoutes(
} }
routing.Setup( routing.Setup(
router, synapseAdminRouter, cfg, rsAPI, asAPI, router, synapseAdminRouter, dendriteAdminRouter,
cfg, rsAPI, asAPI,
userAPI, userDirectoryProvider, federation, userAPI, userDirectoryProvider, federation,
syncProducer, transactionsCache, fsAPI, keyAPI, syncProducer, transactionsCache, fsAPI, keyAPI,
extRoomsProvider, mscCfg, natsClient, extRoomsProvider, mscCfg, natsClient,

View file

@ -48,7 +48,8 @@ import (
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup( func Setup(
publicAPIMux, synapseAdminRouter *mux.Router, cfg *config.ClientAPI, publicAPIMux, synapseAdminRouter, dendriteAdminRouter *mux.Router,
cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
@ -119,6 +120,45 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
} }
dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}",
httputil.MakeAuthAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID, ok := vars["roomID"]
if !ok {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting room ID."),
}
}
res := &roomserverAPI.PerformAdminEvacuateRoomResponse{}
rsAPI.PerformAdminEvacuateRoom(
req.Context(),
&roomserverAPI.PerformAdminEvacuateRoomRequest{
RoomID: roomID,
},
res,
)
if err := res.Error; err != nil {
return err.JSONResponse()
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
"affected": res.Affected,
},
}
}),
).Methods(http.MethodGet, http.MethodOptions)
// server notifications // server notifications
if cfg.Matrix.ServerNotices.Enabled { if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")

View file

@ -140,6 +140,8 @@ func main() {
cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName)) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", *instanceName))
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.MSCs = []string{"msc2836", "msc2946"}
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err := cfg.Derive(); err != nil { if err := cfg.Derive(); err != nil {
panic(err) panic(err)
} }
@ -193,6 +195,7 @@ func main() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
wsUpgrader := websocket.Upgrader{ wsUpgrader := websocket.Upgrader{

View file

@ -89,6 +89,8 @@ func main() {
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.MSCs.MSCs = []string{"msc2836"} cfg.MSCs.MSCs = []string{"msc2836"}
cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName)) cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", *instanceName))
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err = cfg.Derive(); err != nil { if err = cfg.Derive(); err != nil {
panic(err) panic(err)
} }
@ -150,6 +152,7 @@ func main() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
if err := mscs.Enable(base, &monolith); err != nil { if err := mscs.Enable(base, &monolith); err != nil {
logrus.WithError(err).Fatalf("Failed to enable MSCs") logrus.WithError(err).Fatalf("Failed to enable MSCs")

View file

@ -153,6 +153,7 @@ func main() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
if len(base.Cfg.MSCs.MSCs) > 0 { if len(base.Cfg.MSCs.MSCs) > 0 {

View file

@ -31,8 +31,10 @@ func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
keyAPI := base.KeyServerHTTPClient() keyAPI := base.KeyServerHTTPClient()
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
base.ProcessContext, base.PublicClientAPIMux, base.SynapseAdminMux, &base.Cfg.ClientAPI, base.ProcessContext, base.PublicClientAPIMux,
federation, rsAPI, asQuery, transactions.New(), fsAPI, userAPI, userAPI, base.SynapseAdminMux, base.DendriteAdminMux,
&base.Cfg.ClientAPI, federation, rsAPI, asQuery,
transactions.New(), fsAPI, userAPI, userAPI,
keyAPI, nil, &cfg.MSCs, keyAPI, nil, &cfg.MSCs,
) )

View file

@ -83,7 +83,8 @@ do \n\
done \n\ done \n\
\n\ \n\
sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\ sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\
./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\
./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\
' > run_dendrite.sh && chmod +x run_dendrite.sh ' > run_dendrite.sh && chmod +x run_dendrite.sh
ENV SERVER_NAME=localhost ENV SERVER_NAME=localhost

View file

@ -171,6 +171,8 @@ func startup() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.PrivateKey = sk cfg.Global.PrivateKey = sk
cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk))
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
if err := cfg.Derive(); err != nil { if err := cfg.Derive(); err != nil {
logrus.Fatalf("Failed to derive values from config: %s", err) logrus.Fatalf("Failed to derive values from config: %s", err)
@ -220,6 +222,7 @@ func startup() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()

View file

@ -90,6 +90,8 @@ func main() {
cfg.Logging[0].Type = "std" cfg.Logging[0].Type = "std"
cfg.UserAPI.BCryptCost = bcrypt.MinCost cfg.UserAPI.BCryptCost = bcrypt.MinCost
cfg.Global.JetStream.InMemory = true cfg.Global.JetStream.InMemory = true
cfg.ClientAPI.RegistrationDisabled = false
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
cfg.ClientAPI.RegistrationSharedSecret = "complement" cfg.ClientAPI.RegistrationSharedSecret = "complement"
cfg.Global.Presence = config.PresenceOptions{ cfg.Global.Presence = config.PresenceOptions{
EnableInbound: true, EnableInbound: true,

View file

@ -159,7 +159,7 @@ client_api:
# Prevents new users from being able to register on this homeserver, except when # Prevents new users from being able to register on this homeserver, except when
# using the registration shared secret below. # using the registration shared secret below.
registration_disabled: false registration_disabled: true
# Prevents new guest accounts from being created. Guest registration is also # Prevents new guest accounts from being created. Guest registration is also
# disabled implicitly by setting 'registration_disabled' above. # disabled implicitly by setting 'registration_disabled' above.

View file

@ -66,6 +66,12 @@ type RoomserverInternalAPI interface {
res *PerformInboundPeekResponse, res *PerformInboundPeekResponse,
) error ) error
PerformAdminEvacuateRoom(
ctx context.Context,
req *PerformAdminEvacuateRoomRequest,
res *PerformAdminEvacuateRoomResponse,
)
QueryPublishedRooms( QueryPublishedRooms(
ctx context.Context, ctx context.Context,
req *QueryPublishedRoomsRequest, req *QueryPublishedRoomsRequest,

View file

@ -104,6 +104,15 @@ func (t *RoomserverInternalAPITrace) PerformPublish(
util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res))
} }
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom(
ctx context.Context,
req *PerformAdminEvacuateRoomRequest,
res *PerformAdminEvacuateRoomResponse,
) {
t.Impl.PerformAdminEvacuateRoom(ctx, req, res)
util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res))
}
func (t *RoomserverInternalAPITrace) PerformInboundPeek( func (t *RoomserverInternalAPITrace) PerformInboundPeek(
ctx context.Context, ctx context.Context,
req *PerformInboundPeekRequest, req *PerformInboundPeekRequest,

View file

@ -214,3 +214,12 @@ type PerformRoomUpgradeResponse struct {
NewRoomID string NewRoomID string
Error *PerformError Error *PerformError
} }
type PerformAdminEvacuateRoomRequest struct {
RoomID string `json:"room_id"`
}
type PerformAdminEvacuateRoomResponse struct {
Affected []string `json:"affected"`
Error *PerformError
}

View file

@ -35,6 +35,7 @@ type RoomserverInternalAPI struct {
*perform.Backfiller *perform.Backfiller
*perform.Forgetter *perform.Forgetter
*perform.Upgrader *perform.Upgrader
*perform.Admin
ProcessContext *process.ProcessContext ProcessContext *process.ProcessContext
DB storage.Database DB storage.Database
Cfg *config.RoomServer Cfg *config.RoomServer
@ -164,6 +165,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
Cfg: r.Cfg, Cfg: r.Cfg,
URSAPI: r, URSAPI: r,
} }
r.Admin = &perform.Admin{
DB: r.DB,
Cfg: r.Cfg,
Inputer: r.Inputer,
Queryer: r.Queryer,
}
if err := r.Inputer.Start(); err != nil { if err := r.Inputer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start roomserver input API") logrus.WithError(err).Panic("failed to start roomserver input API")

View file

@ -0,0 +1,162 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package perform
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
type Admin struct {
DB storage.Database
Cfg *config.RoomServer
Queryer *query.Queryer
Inputer *input.Inputer
}
// PerformEvacuateRoom will remove all local users from the given room.
func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse,
) {
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
return
}
if roomInfo == nil || roomInfo.IsStub {
res.Error = &api.PerformError{
Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room %s not found", req.RoomID),
}
return
}
memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err),
}
return
}
memberEvents, err := r.DB.Events(ctx, memberNIDs)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.Events: %s", err),
}
return
}
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
res.Affected = make([]string, 0, len(memberEvents))
latestReq := &api.QueryLatestEventsAndStateRequest{
RoomID: req.RoomID,
}
latestRes := &api.QueryLatestEventsAndStateResponse{}
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err),
}
return
}
prevEvents := latestRes.LatestEvents
for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil {
continue
}
var memberContent gomatrixserverlib.MemberContent
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Unmarshal: %s", err),
}
return
}
memberContent.Membership = gomatrixserverlib.Leave
stateKey := *memberEvent.StateKey()
fledglingEvent := &gomatrixserverlib.EventBuilder{
RoomID: req.RoomID,
Type: gomatrixserverlib.MRoomMember,
StateKey: &stateKey,
Sender: stateKey,
PrevEvents: prevEvents,
}
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Marshal: %s", err),
}
return
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
return
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
return
}
inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindNew,
Event: event,
Origin: r.Cfg.Matrix.ServerName,
SendAsServer: string(r.Cfg.Matrix.ServerName),
})
res.Affected = append(res.Affected, stateKey)
prevEvents = []gomatrixserverlib.EventReference{
event.EventReference(),
}
}
inputReq := &api.InputRoomEventsRequest{
InputRoomEvents: inputEvents,
Asynchronous: true,
}
inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
}

View file

@ -29,16 +29,17 @@ const (
RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents"
// Perform operations // Perform operations
RoomserverPerformInvitePath = "/roomserver/performInvite" RoomserverPerformInvitePath = "/roomserver/performInvite"
RoomserverPerformPeekPath = "/roomserver/performPeek" RoomserverPerformPeekPath = "/roomserver/performPeek"
RoomserverPerformUnpeekPath = "/roomserver/performUnpeek" RoomserverPerformUnpeekPath = "/roomserver/performUnpeek"
RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade" RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade"
RoomserverPerformJoinPath = "/roomserver/performJoin" RoomserverPerformJoinPath = "/roomserver/performJoin"
RoomserverPerformLeavePath = "/roomserver/performLeave" RoomserverPerformLeavePath = "/roomserver/performLeave"
RoomserverPerformBackfillPath = "/roomserver/performBackfill" RoomserverPerformBackfillPath = "/roomserver/performBackfill"
RoomserverPerformPublishPath = "/roomserver/performPublish" RoomserverPerformPublishPath = "/roomserver/performPublish"
RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek"
RoomserverPerformForgetPath = "/roomserver/performForget" RoomserverPerformForgetPath = "/roomserver/performForget"
RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom"
// Query operations // Query operations
RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState"
@ -299,6 +300,23 @@ func (h *httpRoomserverInternalAPI) PerformPublish(
} }
} }
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom(
ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
}
// QueryLatestEventsAndState implements RoomserverQueryAPI // QueryLatestEventsAndState implements RoomserverQueryAPI
func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState(
ctx context.Context, ctx context.Context,

View file

@ -118,6 +118,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath,
httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse {
var request api.PerformAdminEvacuateRoomRequest
var response api.PerformAdminEvacuateRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformAdminEvacuateRoom(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryPublishedRoomsPath, RoomserverQueryPublishedRoomsPath,
httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse {

View file

@ -126,6 +126,10 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base
logrus.Infof("Dendrite version %s", internal.VersionString()) logrus.Infof("Dendrite version %s", internal.VersionString())
if !cfg.ClientAPI.RegistrationDisabled && cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled {
logrus.Warn("Open registration is enabled")
}
closer, err := cfg.SetupTracing("Dendrite" + componentName) closer, err := cfg.SetupTracing("Dendrite" + componentName)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to start opentracing") logrus.WithError(err).Panicf("failed to start opentracing")

View file

@ -15,6 +15,12 @@ type ClientAPI struct {
// If set disables new users from registering (except via shared // If set disables new users from registering (except via shared
// secrets) // secrets)
RegistrationDisabled bool `yaml:"registration_disabled"` RegistrationDisabled bool `yaml:"registration_disabled"`
// Enable registration without captcha verification or shared secret.
// This option is populated by the -really-enable-open-registration
// command line parameter as it is not recommended.
OpenRegistrationWithoutVerificationEnabled bool `yaml:"-"`
// If set, allows registration by anyone who also has the shared // If set, allows registration by anyone who also has the shared
// secret, even if registration is otherwise disabled. // secret, even if registration is otherwise disabled.
RegistrationSharedSecret string `yaml:"registration_shared_secret"` RegistrationSharedSecret string `yaml:"registration_shared_secret"`
@ -55,7 +61,8 @@ func (c *ClientAPI) Defaults(generate bool) {
c.RecaptchaEnabled = false c.RecaptchaEnabled = false
c.RecaptchaBypassSecret = "" c.RecaptchaBypassSecret = ""
c.RecaptchaSiteVerifyAPI = "" c.RecaptchaSiteVerifyAPI = ""
c.RegistrationDisabled = false c.RegistrationDisabled = true
c.OpenRegistrationWithoutVerificationEnabled = false
c.RateLimiting.Defaults() c.RateLimiting.Defaults()
} }
@ -72,6 +79,20 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
} }
c.TURN.Verify(configErrs) c.TURN.Verify(configErrs)
c.RateLimiting.Verify(configErrs) c.RateLimiting.Verify(configErrs)
// Ensure there is any spam counter measure when enabling registration
if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled {
if !c.RecaptchaEnabled {
configErrs.Add(
"You have tried to enable open registration without any secondary verification methods " +
"(such as reCAPTCHA). By enabling open registration, you are SIGNIFICANTLY " +
"increasing the risk that your server will be used to send spam or abuse, and may result in " +
"your server being banned from some rooms. If you are ABSOLUTELY CERTAIN you want to do this, " +
"start Dendrite with the -really-enable-open-registration command line flag. Otherwise, you " +
"should set the registration_disabled option in your Dendrite config.",
)
}
}
} }
type TURN struct { type TURN struct {

View file

@ -25,8 +25,9 @@ import (
) )
var ( var (
configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.") configPath = flag.String("config", "dendrite.yaml", "The path to the config file. For more information, see the config file in this repository.")
version = flag.Bool("version", false, "Shows the current version and exits immediately.") version = flag.Bool("version", false, "Shows the current version and exits immediately.")
enableRegistrationWithoutVerification = flag.Bool("really-enable-open-registration", false, "This allows open registration without secondary verification (reCAPTCHA). This is NOT RECOMMENDED and will SIGNIFICANTLY increase the risk that your server will be used to send spam or conduct attacks, which may result in your server being banned from rooms.")
) )
// ParseFlags parses the commandline flags and uses them to create a config. // ParseFlags parses the commandline flags and uses them to create a config.
@ -48,5 +49,9 @@ func ParseFlags(monolith bool) *config.Dendrite {
logrus.Fatalf("Invalid config file: %s", err) logrus.Fatalf("Invalid config file: %s", err)
} }
if *enableRegistrationWithoutVerification {
cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true
}
return cfg return cfg
} }

View file

@ -54,13 +54,13 @@ type Monolith struct {
} }
// AddAllPublicRoutes attaches all public paths to the given router // AddAllPublicRoutes attaches all public paths to the given router
func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux *mux.Router) { func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux, dendriteMux *mux.Router) {
userDirectoryProvider := m.ExtUserDirectoryProvider userDirectoryProvider := m.ExtUserDirectoryProvider
if userDirectoryProvider == nil { if userDirectoryProvider == nil {
userDirectoryProvider = m.UserAPI userDirectoryProvider = m.UserAPI
} }
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
process, csMux, synapseMux, &m.Config.ClientAPI, process, csMux, synapseMux, dendriteMux, &m.Config.ClientAPI,
m.FedClient, m.RoomserverAPI, m.FedClient, m.RoomserverAPI,
m.AppserviceAPI, transactions.New(), m.AppserviceAPI, transactions.New(),
m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI, m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI,

View file

@ -333,6 +333,20 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
return nil return nil
} }
// LoadRooms loads the membership states required to notify users correctly.
func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs []string) error {
n.lock.Lock()
defer n.lock.Unlock()
roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs)
if err != nil {
return err
}
n.setUsersJoinedToRooms(roomToUsers)
return nil
}
// CurrentPosition returns the current sync position // CurrentPosition returns the current sync position
func (n *Notifier) CurrentPosition() types.StreamingToken { func (n *Notifier) CurrentPosition() types.StreamingToken {
n.lock.RLock() n.lock.RLock()

View file

@ -52,6 +52,9 @@ type Database interface {
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
// AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room.
AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)
// Events lookups a list of event by their event ID. // Events lookups a list of event by their event ID.
@ -159,6 +162,6 @@ type Database interface {
type Presence interface { type Presence interface {
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
} }

View file

@ -93,6 +93,9 @@ const selectCurrentStateSQL = "" +
const selectJoinedUsersSQL = "" + const selectJoinedUsersSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
const selectJoinedUsersInRoomSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)"
const selectStateEventSQL = "" + const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
@ -112,6 +115,7 @@ type currentRoomStateStatements struct {
selectRoomIDsWithAnyMembershipStmt *sql.Stmt selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectCurrentStateStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
selectJoinedUsersInRoomStmt *sql.Stmt
selectEventsWithEventIDsStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
} }
@ -143,6 +147,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err return nil, err
} }
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
return nil, err
}
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
return nil, err return nil, err
} }
@ -163,9 +170,32 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed")
result := make(map[string][]string) result := make(map[string][]string)
var roomID string
var userID string
for rows.Next() {
if err := rows.Scan(&roomID, &userID); err != nil {
return nil, err
}
users := result[roomID]
users = append(users, userID)
result[roomID] = users
}
return result, rows.Err()
}
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
ctx context.Context, roomIDs []string,
) (map[string][]string, error) {
rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed")
result := make(map[string][]string)
var userID, roomID string
for rows.Next() { for rows.Next() {
var roomID string
var userID string
if err := rows.Scan(&roomID, &userID); err != nil { if err := rows.Scan(&roomID, &userID); err != nil {
return nil, err return nil, err
} }

View file

@ -17,6 +17,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -72,7 +73,8 @@ const selectMaxPresenceSQL = "" +
const selectPresenceAfter = "" + const selectPresenceAfter = "" +
" SELECT id, user_id, presence, status_msg, last_active_ts" + " SELECT id, user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" + " FROM syncapi_presence" +
" WHERE id > $1" " WHERE id > $1 AND last_active_ts >= $2" +
" ORDER BY id ASC LIMIT $3"
type presenceStatements struct { type presenceStatements struct {
upsertPresenceStmt *sql.Stmt upsertPresenceStmt *sql.Stmt
@ -144,11 +146,12 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx)
func (p *presenceStatements) GetPresenceAfter( func (p *presenceStatements) GetPresenceAfter(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
after types.StreamPosition, after types.StreamPosition,
filter gomatrixserverlib.EventFilter,
) (presences map[string]*types.PresenceInternal, err error) { ) (presences map[string]*types.PresenceInternal, err error) {
presences = make(map[string]*types.PresenceInternal) presences = make(map[string]*types.PresenceInternal)
stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt)
afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5))
rows, err := stmt.QueryContext(ctx, after) rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -168,6 +168,10 @@ func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]stri
return d.CurrentRoomState.SelectJoinedUsers(ctx) return d.CurrentRoomState.SelectJoinedUsers(ctx)
} }
func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, roomIDs)
}
func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
return d.Peeks.SelectPeekingDevices(ctx) return d.Peeks.SelectPeekingDevices(ctx)
} }
@ -1056,8 +1060,8 @@ func (s *Database) GetPresence(ctx context.Context, userID string) (*types.Prese
return s.Presence.GetPresenceForUser(ctx, nil, userID) return s.Presence.GetPresenceForUser(ctx, nil, userID)
} }
func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) { func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return s.Presence.GetPresenceAfter(ctx, nil, after) return s.Presence.GetPresenceAfter(ctx, nil, after, filter)
} }
func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {

View file

@ -77,6 +77,9 @@ const selectCurrentStateSQL = "" +
const selectJoinedUsersSQL = "" + const selectJoinedUsersSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
const selectJoinedUsersInRoomSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)"
const selectStateEventSQL = "" + const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
@ -97,7 +100,8 @@ type currentRoomStateStatements struct {
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectRoomIDsWithAnyMembershipStmt *sql.Stmt selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
selectStateEventStmt *sql.Stmt //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
selectStateEventStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
@ -127,13 +131,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err return nil, err
} }
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
// return nil, err
//}
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err return nil, err
} }
return s, nil return s, nil
} }
// JoinedMemberLists returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
func (s *currentRoomStateStatements) SelectJoinedUsers( func (s *currentRoomStateStatements) SelectJoinedUsers(
ctx context.Context, ctx context.Context,
) (map[string][]string, error) { ) (map[string][]string, error) {
@ -144,9 +151,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed")
result := make(map[string][]string) result := make(map[string][]string)
var roomID string
var userID string
for rows.Next() { for rows.Next() {
var roomID string
var userID string
if err := rows.Scan(&roomID, &userID); err != nil { if err := rows.Scan(&roomID, &userID); err != nil {
return nil, err return nil, err
} }
@ -157,6 +164,40 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
return result, nil return result, nil
} }
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
ctx context.Context, roomIDs []string,
) (map[string][]string, error) {
query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
params := make([]interface{}, 0, len(roomIDs))
for _, roomID := range roomIDs {
params = append(params, roomID)
}
stmt, err := s.db.Prepare(query)
if err != nil {
return nil, fmt.Errorf("SelectJoinedUsersInRoom s.db.Prepare: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsersInRoom: rows.close() failed")
result := make(map[string][]string)
var userID, roomID string
for rows.Next() {
if err := rows.Scan(&roomID, &userID); err != nil {
return nil, err
}
users := result[roomID]
users = append(users, userID)
result[roomID] = users
}
return result, rows.Err()
}
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
ctx context.Context, ctx context.Context,

View file

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -71,7 +72,8 @@ const selectMaxPresenceSQL = "" +
const selectPresenceAfter = "" + const selectPresenceAfter = "" +
" SELECT id, user_id, presence, status_msg, last_active_ts" + " SELECT id, user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" + " FROM syncapi_presence" +
" WHERE id > $1" " WHERE id > $1 AND last_active_ts >= $2" +
" ORDER BY id ASC LIMIT $3"
type presenceStatements struct { type presenceStatements struct {
db *sql.DB db *sql.DB
@ -158,12 +160,12 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx)
// GetPresenceAfter returns the changes presences after a given stream id // GetPresenceAfter returns the changes presences after a given stream id
func (p *presenceStatements) GetPresenceAfter( func (p *presenceStatements) GetPresenceAfter(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
after types.StreamPosition, after types.StreamPosition, filter gomatrixserverlib.EventFilter,
) (presences map[string]*types.PresenceInternal, err error) { ) (presences map[string]*types.PresenceInternal, err error) {
presences = make(map[string]*types.PresenceInternal) presences = make(map[string]*types.PresenceInternal)
stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt)
afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5))
rows, err := stmt.QueryContext(ctx, after) rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -102,6 +102,8 @@ type CurrentRoomState interface {
SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error)
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
SelectJoinedUsers(ctx context.Context) (map[string][]string, error) SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
} }
// BackwardsExtremities keeps track of backwards extremities for a room. // BackwardsExtremities keeps track of backwards extremities for a room.
@ -188,5 +190,5 @@ type Presence interface {
UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error)
GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error)
GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error)
GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition) (presences map[string]*types.PresenceInternal, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error)
} }

View file

@ -53,7 +53,8 @@ func (p *PresenceStreamProvider) IncrementalSync(
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
presences, err := p.DB.PresenceAfter(ctx, from) // We pull out a larger number than the filter asks for, since we're filtering out events later
presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000})
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.DB.PresenceAfter failed") req.Log.WithError(err).Error("p.DB.PresenceAfter failed")
return from return from
@ -66,12 +67,12 @@ func (p *PresenceStreamProvider) IncrementalSync(
// add newly joined rooms user presences // add newly joined rooms user presences
newlyJoined := joinedRooms(req.Response, req.Device.UserID) newlyJoined := joinedRooms(req.Response, req.Device.UserID)
if len(newlyJoined) > 0 { if len(newlyJoined) > 0 {
// TODO: This refreshes all lists and is quite expensive // TODO: Check if this is working better than before.
// The notifier should update the lists itself if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
if err = p.notifier.Load(ctx, p.DB); err != nil {
req.Log.WithError(err).Error("unable to refresh notifier lists") req.Log.WithError(err).Error("unable to refresh notifier lists")
return from return from
} }
NewlyJoinedLoop:
for _, roomID := range newlyJoined { for _, roomID := range newlyJoined {
roomUsers := p.notifier.JoinedUsers(roomID) roomUsers := p.notifier.JoinedUsers(roomID)
for i := range roomUsers { for i := range roomUsers {
@ -86,11 +87,14 @@ func (p *PresenceStreamProvider) IncrementalSync(
req.Log.WithError(err).Error("unable to query presence for user") req.Log.WithError(err).Error("unable to query presence for user")
return from return from
} }
if len(presences) > req.Filter.Presence.Limit {
break NewlyJoinedLoop
}
} }
} }
} }
lastPos := to lastPos := from
for _, presence := range presences { for _, presence := range presences {
if presence == nil { if presence == nil {
continue continue
@ -135,6 +139,9 @@ func (p *PresenceStreamProvider) IncrementalSync(
if presence.StreamPos > lastPos { if presence.StreamPos > lastPos {
lastPos = presence.StreamPos lastPos = presence.StreamPos
} }
if len(req.Response.Presence.Events) == req.Filter.Presence.Limit {
break
}
p.cache.Store(cacheKey, presence) p.cache.Store(cacheKey, presence)
} }

View file

@ -30,7 +30,7 @@ func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.Presenc
return &types.PresenceInternal{}, nil return &types.PresenceInternal{}, nil
} }
func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) { func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return map[string]*types.PresenceInternal{}, nil return map[string]*types.PresenceInternal{}, nil
} }

View file

@ -75,7 +75,7 @@ const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
@ -215,15 +215,22 @@ func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName sql.NullString var displayName, ip sql.NullString
var lastseenTS sql.NullInt64
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid { if displayName.Valid {
dev.DisplayName = displayName.String dev.DisplayName = displayName.String
} }
if lastseenTS.Valid {
dev.LastSeenTS = lastseenTS.Int64
}
if ip.Valid {
dev.LastSeenIP = ip.String
}
} }
return &dev, err return &dev, err
} }

View file

@ -60,7 +60,7 @@ const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
@ -212,15 +212,22 @@ func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName sql.NullString var displayName, ip sql.NullString
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) var lastseenTS sql.NullInt64
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid { if displayName.Valid {
dev.DisplayName = displayName.String dev.DisplayName = displayName.String
} }
if lastseenTS.Valid {
dev.LastSeenTS = lastseenTS.Int64
}
if ip.Valid {
dev.LastSeenIP = ip.String
}
} }
return &dev, err return &dev, err
} }

View file

@ -180,12 +180,12 @@ func Test_Devices(t *testing.T) {
deviceWithID.DisplayName = newName deviceWithID.DisplayName = newName
deviceWithID.LastSeenIP = "127.0.0.1" deviceWithID.LastSeenIP = "127.0.0.1"
deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second))) deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second)))
devices, err = db.GetDevicesByLocalpart(ctx, localpart) gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
assert.NoError(t, err, "unable to get device by id") assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 2, len(devices)) assert.Equal(t, 2, len(devices))
assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName) assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP) assert.Equal(t, deviceWithID.LastSeenIP, gotDevice.LastSeenIP)
truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second) truncatedTime := gomatrixserverlib.Timestamp(gotDevice.LastSeenTS).Time().Truncate(time.Second)
assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime)) assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime))
// create one more device and remove the devices step by step // create one more device and remove the devices step by step