Merge branch 'matrix-org:main' into main

This commit is contained in:
Emanuele Aliberti 2022-04-29 08:58:33 +02:00 committed by GitHub
commit f6009f568a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
40 changed files with 500 additions and 55 deletions

View file

@ -1,4 +1,4 @@
FROM docker.io/golang:1.17-alpine AS base FROM docker.io/golang:1.18-alpine AS base
RUN apk --update --no-cache add bash build-base RUN apk --update --no-cache add bash build-base
@ -23,4 +23,4 @@ COPY --from=base /build/bin/* /usr/bin/
VOLUME /etc/dendrite VOLUME /etc/dendrite
WORKDIR /etc/dendrite WORKDIR /etc/dendrite
ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] ENTRYPOINT ["/usr/bin/dendrite-monolith-server"]

View file

@ -1,4 +1,4 @@
FROM docker.io/golang:1.17-alpine AS base FROM docker.io/golang:1.18-alpine AS base
RUN apk --update --no-cache add bash build-base RUN apk --update --no-cache add bash build-base
@ -23,4 +23,4 @@ COPY --from=base /build/bin/* /usr/bin/
VOLUME /etc/dendrite VOLUME /etc/dendrite
WORKDIR /etc/dendrite WORKDIR /etc/dendrite
ENTRYPOINT ["/usr/bin/dendrite-polylith-multi"] ENTRYPOINT ["/usr/bin/dendrite-polylith-multi"]

View file

@ -314,6 +314,7 @@ func (m *DendriteMonolith) Start() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()

View file

@ -152,6 +152,7 @@ func (m *DendriteMonolith) Start() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
httpRouter := mux.NewRouter() httpRouter := mux.NewRouter()

View file

@ -36,6 +36,7 @@ func AddPublicRoutes(
process *process.ProcessContext, process *process.ProcessContext,
router *mux.Router, router *mux.Router,
synapseAdminRouter *mux.Router, synapseAdminRouter *mux.Router,
dendriteAdminRouter *mux.Router,
cfg *config.ClientAPI, cfg *config.ClientAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
@ -62,7 +63,8 @@ func AddPublicRoutes(
} }
routing.Setup( routing.Setup(
router, synapseAdminRouter, cfg, rsAPI, asAPI, router, synapseAdminRouter, dendriteAdminRouter,
cfg, rsAPI, asAPI,
userAPI, userDirectoryProvider, federation, userAPI, userDirectoryProvider, federation,
syncProducer, transactionsCache, fsAPI, keyAPI, syncProducer, transactionsCache, fsAPI, keyAPI,
extRoomsProvider, mscCfg, natsClient, extRoomsProvider, mscCfg, natsClient,

View file

@ -48,7 +48,8 @@ import (
// applied: // applied:
// nolint: gocyclo // nolint: gocyclo
func Setup( func Setup(
publicAPIMux, synapseAdminRouter *mux.Router, cfg *config.ClientAPI, publicAPIMux, synapseAdminRouter, dendriteAdminRouter *mux.Router,
cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
@ -119,6 +120,45 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
} }
dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}",
httputil.MakeAuthAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if device.AccountType != userapi.AccountTypeAdmin {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("This API can only be used by admin users."),
}
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
roomID, ok := vars["roomID"]
if !ok {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Expecting room ID."),
}
}
res := &roomserverAPI.PerformAdminEvacuateRoomResponse{}
rsAPI.PerformAdminEvacuateRoom(
req.Context(),
&roomserverAPI.PerformAdminEvacuateRoomRequest{
RoomID: roomID,
},
res,
)
if err := res.Error; err != nil {
return err.JSONResponse()
}
return util.JSONResponse{
Code: 200,
JSON: map[string]interface{}{
"affected": res.Affected,
},
}
}),
).Methods(http.MethodGet, http.MethodOptions)
// server notifications // server notifications
if cfg.Matrix.ServerNotices.Enabled { if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")

View file

@ -193,6 +193,7 @@ func main() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
wsUpgrader := websocket.Upgrader{ wsUpgrader := websocket.Upgrader{

View file

@ -150,6 +150,7 @@ func main() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
if err := mscs.Enable(base, &monolith); err != nil { if err := mscs.Enable(base, &monolith); err != nil {
logrus.WithError(err).Fatalf("Failed to enable MSCs") logrus.WithError(err).Fatalf("Failed to enable MSCs")

View file

@ -153,6 +153,7 @@ func main() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
if len(base.Cfg.MSCs.MSCs) > 0 { if len(base.Cfg.MSCs.MSCs) > 0 {

View file

@ -31,8 +31,10 @@ func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
keyAPI := base.KeyServerHTTPClient() keyAPI := base.KeyServerHTTPClient()
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
base.ProcessContext, base.PublicClientAPIMux, base.SynapseAdminMux, &base.Cfg.ClientAPI, base.ProcessContext, base.PublicClientAPIMux,
federation, rsAPI, asQuery, transactions.New(), fsAPI, userAPI, userAPI, base.SynapseAdminMux, base.DendriteAdminMux,
&base.Cfg.ClientAPI, federation, rsAPI, asQuery,
transactions.New(), fsAPI, userAPI, userAPI,
keyAPI, nil, &cfg.MSCs, keyAPI, nil, &cfg.MSCs,
) )

View file

@ -220,6 +220,7 @@ func startup() {
base.PublicWellKnownAPIMux, base.PublicWellKnownAPIMux,
base.PublicMediaAPIMux, base.PublicMediaAPIMux,
base.SynapseAdminMux, base.SynapseAdminMux,
base.DendriteAdminMux,
) )
httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()

View file

@ -51,6 +51,12 @@ func Backfill(
} }
} }
// If we don't think we belong to this room then don't waste the effort
// responding to expensive requests for it.
if err := ErrorIfLocalServerNotInRoom(httpReq.Context(), rsAPI, roomID); err != nil {
return *err
}
// Check if all of the required parameters are there. // Check if all of the required parameters are there.
eIDs, exists = httpReq.URL.Query()["v"] eIDs, exists = httpReq.URL.Query()["v"]
if !exists { if !exists {

View file

@ -30,6 +30,12 @@ func GetEventAuth(
roomID string, roomID string,
eventID string, eventID string,
) util.JSONResponse { ) util.JSONResponse {
// If we don't think we belong to this room then don't waste the effort
// responding to expensive requests for it.
if err := ErrorIfLocalServerNotInRoom(ctx, rsAPI, roomID); err != nil {
return *err
}
event, resErr := fetchEvent(ctx, rsAPI, eventID) event, resErr := fetchEvent(ctx, rsAPI, eventID)
if resErr != nil { if resErr != nil {
return *resErr return *resErr

View file

@ -45,6 +45,12 @@ func GetMissingEvents(
} }
} }
// If we don't think we belong to this room then don't waste the effort
// responding to expensive requests for it.
if err := ErrorIfLocalServerNotInRoom(httpReq.Context(), rsAPI, roomID); err != nil {
return *err
}
var eventsResponse api.QueryMissingEventsResponse var eventsResponse api.QueryMissingEventsResponse
if err := rsAPI.QueryMissingEvents( if err := rsAPI.QueryMissingEvents(
httpReq.Context(), &api.QueryMissingEventsRequest{ httpReq.Context(), &api.QueryMissingEventsRequest{

View file

@ -15,6 +15,8 @@
package routing package routing
import ( import (
"context"
"fmt"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -24,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -491,3 +494,27 @@ func Setup(
}), }),
).Methods(http.MethodGet) ).Methods(http.MethodGet)
} }
func ErrorIfLocalServerNotInRoom(
ctx context.Context,
rsAPI api.RoomserverInternalAPI,
roomID string,
) *util.JSONResponse {
// Check if we think we're in this room. If we aren't then
// we won't waste CPU cycles serving this request.
joinedReq := &api.QueryServerJoinedToRoomRequest{
RoomID: roomID,
}
joinedRes := &api.QueryServerJoinedToRoomResponse{}
if err := rsAPI.QueryServerJoinedToRoom(ctx, joinedReq, joinedRes); err != nil {
res := util.ErrorResponse(err)
return &res
}
if !joinedRes.IsInRoom {
return &util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound(fmt.Sprintf("This server is not joined to room %s", roomID)),
}
}
return nil
}

View file

@ -101,6 +101,12 @@ func getState(
roomID string, roomID string,
eventID string, eventID string,
) (stateEvents, authEvents []*gomatrixserverlib.HeaderedEvent, errRes *util.JSONResponse) { ) (stateEvents, authEvents []*gomatrixserverlib.HeaderedEvent, errRes *util.JSONResponse) {
// If we don't think we belong to this room then don't waste the effort
// responding to expensive requests for it.
if err := ErrorIfLocalServerNotInRoom(ctx, rsAPI, roomID); err != nil {
return nil, nil, err
}
event, resErr := fetchEvent(ctx, rsAPI, eventID) event, resErr := fetchEvent(ctx, rsAPI, eventID)
if resErr != nil { if resErr != nil {
return nil, nil, resErr return nil, nil, resErr

View file

@ -319,6 +319,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
// JSON, add the signatures and marshal it again, for some reason? // JSON, add the signatures and marshal it again, for some reason?
for targetUserID, masterKey := range res.MasterKeys { for targetUserID, masterKey := range res.MasterKeys {
if masterKey.Signatures == nil {
masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
for targetKeyID := range masterKey.Keys { for targetKeyID := range masterKey.Keys {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
if err != nil { if err != nil {

View file

@ -66,6 +66,12 @@ type RoomserverInternalAPI interface {
res *PerformInboundPeekResponse, res *PerformInboundPeekResponse,
) error ) error
PerformAdminEvacuateRoom(
ctx context.Context,
req *PerformAdminEvacuateRoomRequest,
res *PerformAdminEvacuateRoomResponse,
)
QueryPublishedRooms( QueryPublishedRooms(
ctx context.Context, ctx context.Context,
req *QueryPublishedRoomsRequest, req *QueryPublishedRoomsRequest,

View file

@ -104,6 +104,15 @@ func (t *RoomserverInternalAPITrace) PerformPublish(
util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res))
} }
func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom(
ctx context.Context,
req *PerformAdminEvacuateRoomRequest,
res *PerformAdminEvacuateRoomResponse,
) {
t.Impl.PerformAdminEvacuateRoom(ctx, req, res)
util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res))
}
func (t *RoomserverInternalAPITrace) PerformInboundPeek( func (t *RoomserverInternalAPITrace) PerformInboundPeek(
ctx context.Context, ctx context.Context,
req *PerformInboundPeekRequest, req *PerformInboundPeekRequest,

View file

@ -214,3 +214,12 @@ type PerformRoomUpgradeResponse struct {
NewRoomID string NewRoomID string
Error *PerformError Error *PerformError
} }
type PerformAdminEvacuateRoomRequest struct {
RoomID string `json:"room_id"`
}
type PerformAdminEvacuateRoomResponse struct {
Affected []string `json:"affected"`
Error *PerformError
}

View file

@ -35,6 +35,7 @@ type RoomserverInternalAPI struct {
*perform.Backfiller *perform.Backfiller
*perform.Forgetter *perform.Forgetter
*perform.Upgrader *perform.Upgrader
*perform.Admin
ProcessContext *process.ProcessContext ProcessContext *process.ProcessContext
DB storage.Database DB storage.Database
Cfg *config.RoomServer Cfg *config.RoomServer
@ -164,6 +165,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA
Cfg: r.Cfg, Cfg: r.Cfg,
URSAPI: r, URSAPI: r,
} }
r.Admin = &perform.Admin{
DB: r.DB,
Cfg: r.Cfg,
Inputer: r.Inputer,
Queryer: r.Queryer,
}
if err := r.Inputer.Start(); err != nil { if err := r.Inputer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start roomserver input API") logrus.WithError(err).Panic("failed to start roomserver input API")

View file

@ -0,0 +1,162 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package perform
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/internal/query"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
type Admin struct {
DB storage.Database
Cfg *config.RoomServer
Queryer *query.Queryer
Inputer *input.Inputer
}
// PerformEvacuateRoom will remove all local users from the given room.
func (r *Admin) PerformAdminEvacuateRoom(
ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse,
) {
roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err),
}
return
}
if roomInfo == nil || roomInfo.IsStub {
res.Error = &api.PerformError{
Code: api.PerformErrorNoRoom,
Msg: fmt.Sprintf("Room %s not found", req.RoomID),
}
return
}
memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err),
}
return
}
memberEvents, err := r.DB.Events(ctx, memberNIDs)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.DB.Events: %s", err),
}
return
}
inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents))
res.Affected = make([]string, 0, len(memberEvents))
latestReq := &api.QueryLatestEventsAndStateRequest{
RoomID: req.RoomID,
}
latestRes := &api.QueryLatestEventsAndStateResponse{}
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err),
}
return
}
prevEvents := latestRes.LatestEvents
for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil {
continue
}
var memberContent gomatrixserverlib.MemberContent
if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Unmarshal: %s", err),
}
return
}
memberContent.Membership = gomatrixserverlib.Leave
stateKey := *memberEvent.StateKey()
fledglingEvent := &gomatrixserverlib.EventBuilder{
RoomID: req.RoomID,
Type: gomatrixserverlib.MRoomMember,
StateKey: &stateKey,
Sender: stateKey,
PrevEvents: prevEvents,
}
if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("json.Marshal: %s", err),
}
return
}
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err),
}
return
}
event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err),
}
return
}
inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindNew,
Event: event,
Origin: r.Cfg.Matrix.ServerName,
SendAsServer: string(r.Cfg.Matrix.ServerName),
})
res.Affected = append(res.Affected, stateKey)
prevEvents = []gomatrixserverlib.EventReference{
event.EventReference(),
}
}
inputReq := &api.InputRoomEventsRequest{
InputRoomEvents: inputEvents,
Asynchronous: true,
}
inputRes := &api.InputRoomEventsResponse{}
r.Inputer.InputRoomEvents(ctx, inputReq, inputRes)
}

View file

@ -29,16 +29,17 @@ const (
RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents"
// Perform operations // Perform operations
RoomserverPerformInvitePath = "/roomserver/performInvite" RoomserverPerformInvitePath = "/roomserver/performInvite"
RoomserverPerformPeekPath = "/roomserver/performPeek" RoomserverPerformPeekPath = "/roomserver/performPeek"
RoomserverPerformUnpeekPath = "/roomserver/performUnpeek" RoomserverPerformUnpeekPath = "/roomserver/performUnpeek"
RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade" RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade"
RoomserverPerformJoinPath = "/roomserver/performJoin" RoomserverPerformJoinPath = "/roomserver/performJoin"
RoomserverPerformLeavePath = "/roomserver/performLeave" RoomserverPerformLeavePath = "/roomserver/performLeave"
RoomserverPerformBackfillPath = "/roomserver/performBackfill" RoomserverPerformBackfillPath = "/roomserver/performBackfill"
RoomserverPerformPublishPath = "/roomserver/performPublish" RoomserverPerformPublishPath = "/roomserver/performPublish"
RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek"
RoomserverPerformForgetPath = "/roomserver/performForget" RoomserverPerformForgetPath = "/roomserver/performForget"
RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom"
// Query operations // Query operations
RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState"
@ -299,6 +300,23 @@ func (h *httpRoomserverInternalAPI) PerformPublish(
} }
} }
func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom(
ctx context.Context,
req *api.PerformAdminEvacuateRoomRequest,
res *api.PerformAdminEvacuateRoomResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
if err != nil {
res.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
}
// QueryLatestEventsAndState implements RoomserverQueryAPI // QueryLatestEventsAndState implements RoomserverQueryAPI
func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState(
ctx context.Context, ctx context.Context,

View file

@ -118,6 +118,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath,
httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse {
var request api.PerformAdminEvacuateRoomRequest
var response api.PerformAdminEvacuateRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformAdminEvacuateRoom(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle( internalAPIMux.Handle(
RoomserverQueryPublishedRoomsPath, RoomserverQueryPublishedRoomsPath,
httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse {

View file

@ -54,13 +54,13 @@ type Monolith struct {
} }
// AddAllPublicRoutes attaches all public paths to the given router // AddAllPublicRoutes attaches all public paths to the given router
func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux *mux.Router) { func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux, dendriteMux *mux.Router) {
userDirectoryProvider := m.ExtUserDirectoryProvider userDirectoryProvider := m.ExtUserDirectoryProvider
if userDirectoryProvider == nil { if userDirectoryProvider == nil {
userDirectoryProvider = m.UserAPI userDirectoryProvider = m.UserAPI
} }
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
process, csMux, synapseMux, &m.Config.ClientAPI, process, csMux, synapseMux, dendriteMux, &m.Config.ClientAPI,
m.FedClient, m.RoomserverAPI, m.FedClient, m.RoomserverAPI,
m.AppserviceAPI, transactions.New(), m.AppserviceAPI, transactions.New(),
m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI, m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI,

View file

@ -333,6 +333,20 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
return nil return nil
} }
// LoadRooms loads the membership states required to notify users correctly.
func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs []string) error {
n.lock.Lock()
defer n.lock.Unlock()
roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs)
if err != nil {
return err
}
n.setUsersJoinedToRooms(roomToUsers)
return nil
}
// CurrentPosition returns the current sync position // CurrentPosition returns the current sync position
func (n *Notifier) CurrentPosition() types.StreamingToken { func (n *Notifier) CurrentPosition() types.StreamingToken {
n.lock.RLock() n.lock.RLock()

View file

@ -52,6 +52,9 @@ type Database interface {
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
// AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room.
AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)
// Events lookups a list of event by their event ID. // Events lookups a list of event by their event ID.
@ -159,6 +162,6 @@ type Database interface {
type Presence interface { type Presence interface {
UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
} }

View file

@ -93,6 +93,9 @@ const selectCurrentStateSQL = "" +
const selectJoinedUsersSQL = "" + const selectJoinedUsersSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
const selectJoinedUsersInRoomSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)"
const selectStateEventSQL = "" + const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
@ -112,6 +115,7 @@ type currentRoomStateStatements struct {
selectRoomIDsWithAnyMembershipStmt *sql.Stmt selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectCurrentStateStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
selectJoinedUsersInRoomStmt *sql.Stmt
selectEventsWithEventIDsStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
} }
@ -143,6 +147,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err return nil, err
} }
if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
return nil, err
}
if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
return nil, err return nil, err
} }
@ -163,9 +170,32 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed")
result := make(map[string][]string) result := make(map[string][]string)
var roomID string
var userID string
for rows.Next() {
if err := rows.Scan(&roomID, &userID); err != nil {
return nil, err
}
users := result[roomID]
users = append(users, userID)
result[roomID] = users
}
return result, rows.Err()
}
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
ctx context.Context, roomIDs []string,
) (map[string][]string, error) {
rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed")
result := make(map[string][]string)
var userID, roomID string
for rows.Next() { for rows.Next() {
var roomID string
var userID string
if err := rows.Scan(&roomID, &userID); err != nil { if err := rows.Scan(&roomID, &userID); err != nil {
return nil, err return nil, err
} }

View file

@ -64,7 +64,7 @@ const selectMembershipCountSQL = "" +
") t WHERE t.membership = $3" ") t WHERE t.membership = $3"
const selectHeroesSQL = "" + const selectHeroesSQL = "" +
"SELECT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5"
type membershipsStatements struct { type membershipsStatements struct {
upsertMembershipStmt *sql.Stmt upsertMembershipStmt *sql.Stmt

View file

@ -17,6 +17,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -72,7 +73,8 @@ const selectMaxPresenceSQL = "" +
const selectPresenceAfter = "" + const selectPresenceAfter = "" +
" SELECT id, user_id, presence, status_msg, last_active_ts" + " SELECT id, user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" + " FROM syncapi_presence" +
" WHERE id > $1" " WHERE id > $1 AND last_active_ts >= $2" +
" ORDER BY id ASC LIMIT $3"
type presenceStatements struct { type presenceStatements struct {
upsertPresenceStmt *sql.Stmt upsertPresenceStmt *sql.Stmt
@ -144,11 +146,12 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx)
func (p *presenceStatements) GetPresenceAfter( func (p *presenceStatements) GetPresenceAfter(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
after types.StreamPosition, after types.StreamPosition,
filter gomatrixserverlib.EventFilter,
) (presences map[string]*types.PresenceInternal, err error) { ) (presences map[string]*types.PresenceInternal, err error) {
presences = make(map[string]*types.PresenceInternal) presences = make(map[string]*types.PresenceInternal)
stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt)
afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5))
rows, err := stmt.QueryContext(ctx, after) rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -168,6 +168,10 @@ func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]stri
return d.CurrentRoomState.SelectJoinedUsers(ctx) return d.CurrentRoomState.SelectJoinedUsers(ctx)
} }
func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, roomIDs)
}
func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
return d.Peeks.SelectPeekingDevices(ctx) return d.Peeks.SelectPeekingDevices(ctx)
} }
@ -1056,8 +1060,8 @@ func (s *Database) GetPresence(ctx context.Context, userID string) (*types.Prese
return s.Presence.GetPresenceForUser(ctx, nil, userID) return s.Presence.GetPresenceForUser(ctx, nil, userID)
} }
func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) { func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return s.Presence.GetPresenceAfter(ctx, nil, after) return s.Presence.GetPresenceAfter(ctx, nil, after, filter)
} }
func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {

View file

@ -77,6 +77,9 @@ const selectCurrentStateSQL = "" +
const selectJoinedUsersSQL = "" + const selectJoinedUsersSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
const selectJoinedUsersInRoomSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)"
const selectStateEventSQL = "" + const selectStateEventSQL = "" +
"SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
@ -97,7 +100,8 @@ type currentRoomStateStatements struct {
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectRoomIDsWithAnyMembershipStmt *sql.Stmt selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
selectStateEventStmt *sql.Stmt //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
selectStateEventStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
@ -127,13 +131,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err return nil, err
} }
//if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil {
// return nil, err
//}
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err return nil, err
} }
return s, nil return s, nil
} }
// JoinedMemberLists returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
func (s *currentRoomStateStatements) SelectJoinedUsers( func (s *currentRoomStateStatements) SelectJoinedUsers(
ctx context.Context, ctx context.Context,
) (map[string][]string, error) { ) (map[string][]string, error) {
@ -144,9 +151,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed")
result := make(map[string][]string) result := make(map[string][]string)
var roomID string
var userID string
for rows.Next() { for rows.Next() {
var roomID string
var userID string
if err := rows.Scan(&roomID, &userID); err != nil { if err := rows.Scan(&roomID, &userID); err != nil {
return nil, err return nil, err
} }
@ -157,6 +164,40 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
return result, nil return result, nil
} }
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
ctx context.Context, roomIDs []string,
) (map[string][]string, error) {
query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
params := make([]interface{}, 0, len(roomIDs))
for _, roomID := range roomIDs {
params = append(params, roomID)
}
stmt, err := s.db.Prepare(query)
if err != nil {
return nil, fmt.Errorf("SelectJoinedUsersInRoom s.db.Prepare: %w", err)
}
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed")
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsersInRoom: rows.close() failed")
result := make(map[string][]string)
var userID, roomID string
for rows.Next() {
if err := rows.Scan(&roomID, &userID); err != nil {
return nil, err
}
users := result[roomID]
users = append(users, userID)
result[roomID] = users
}
return result, rows.Err()
}
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
ctx context.Context, ctx context.Context,

View file

@ -17,6 +17,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -71,7 +72,8 @@ const selectMaxPresenceSQL = "" +
const selectPresenceAfter = "" + const selectPresenceAfter = "" +
" SELECT id, user_id, presence, status_msg, last_active_ts" + " SELECT id, user_id, presence, status_msg, last_active_ts" +
" FROM syncapi_presence" + " FROM syncapi_presence" +
" WHERE id > $1" " WHERE id > $1 AND last_active_ts >= $2" +
" ORDER BY id ASC LIMIT $3"
type presenceStatements struct { type presenceStatements struct {
db *sql.DB db *sql.DB
@ -158,12 +160,12 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx)
// GetPresenceAfter returns the changes presences after a given stream id // GetPresenceAfter returns the changes presences after a given stream id
func (p *presenceStatements) GetPresenceAfter( func (p *presenceStatements) GetPresenceAfter(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
after types.StreamPosition, after types.StreamPosition, filter gomatrixserverlib.EventFilter,
) (presences map[string]*types.PresenceInternal, err error) { ) (presences map[string]*types.PresenceInternal, err error) {
presences = make(map[string]*types.PresenceInternal) presences = make(map[string]*types.PresenceInternal)
stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt)
afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5))
rows, err := stmt.QueryContext(ctx, after) rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -102,6 +102,8 @@ type CurrentRoomState interface {
SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error)
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. // SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
SelectJoinedUsers(ctx context.Context) (map[string][]string, error) SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
} }
// BackwardsExtremities keeps track of backwards extremities for a room. // BackwardsExtremities keeps track of backwards extremities for a room.
@ -188,5 +190,5 @@ type Presence interface {
UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error)
GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error)
GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error)
GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition) (presences map[string]*types.PresenceInternal, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error)
} }

View file

@ -53,7 +53,8 @@ func (p *PresenceStreamProvider) IncrementalSync(
req *types.SyncRequest, req *types.SyncRequest,
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
presences, err := p.DB.PresenceAfter(ctx, from) // We pull out a larger number than the filter asks for, since we're filtering out events later
presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000})
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.DB.PresenceAfter failed") req.Log.WithError(err).Error("p.DB.PresenceAfter failed")
return from return from
@ -66,12 +67,12 @@ func (p *PresenceStreamProvider) IncrementalSync(
// add newly joined rooms user presences // add newly joined rooms user presences
newlyJoined := joinedRooms(req.Response, req.Device.UserID) newlyJoined := joinedRooms(req.Response, req.Device.UserID)
if len(newlyJoined) > 0 { if len(newlyJoined) > 0 {
// TODO: This refreshes all lists and is quite expensive // TODO: Check if this is working better than before.
// The notifier should update the lists itself if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil {
if err = p.notifier.Load(ctx, p.DB); err != nil {
req.Log.WithError(err).Error("unable to refresh notifier lists") req.Log.WithError(err).Error("unable to refresh notifier lists")
return from return from
} }
NewlyJoinedLoop:
for _, roomID := range newlyJoined { for _, roomID := range newlyJoined {
roomUsers := p.notifier.JoinedUsers(roomID) roomUsers := p.notifier.JoinedUsers(roomID)
for i := range roomUsers { for i := range roomUsers {
@ -86,11 +87,14 @@ func (p *PresenceStreamProvider) IncrementalSync(
req.Log.WithError(err).Error("unable to query presence for user") req.Log.WithError(err).Error("unable to query presence for user")
return from return from
} }
if len(presences) > req.Filter.Presence.Limit {
break NewlyJoinedLoop
}
} }
} }
} }
lastPos := to lastPos := from
for _, presence := range presences { for _, presence := range presences {
if presence == nil { if presence == nil {
continue continue
@ -135,6 +139,9 @@ func (p *PresenceStreamProvider) IncrementalSync(
if presence.StreamPos > lastPos { if presence.StreamPos > lastPos {
lastPos = presence.StreamPos lastPos = presence.StreamPos
} }
if len(req.Response.Presence.Events) == req.Filter.Presence.Limit {
break
}
p.cache.Store(cacheKey, presence) p.cache.Store(cacheKey, presence)
} }

View file

@ -30,7 +30,7 @@ func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.Presenc
return &types.PresenceInternal{}, nil return &types.PresenceInternal{}, nil
} }
func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) { func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
return map[string]*types.PresenceInternal{}, nil return map[string]*types.PresenceInternal{}, nil
} }

View file

@ -48,4 +48,3 @@ Notifications can be viewed with GET /notifications
# More flakey # More flakey
If remote user leaves room we no longer receive device updates If remote user leaves room we no longer receive device updates
Local device key changes get to remote servers

View file

@ -75,7 +75,7 @@ const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
@ -215,15 +215,22 @@ func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName sql.NullString var displayName, ip sql.NullString
var lastseenTS sql.NullInt64
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid { if displayName.Valid {
dev.DisplayName = displayName.String dev.DisplayName = displayName.String
} }
if lastseenTS.Valid {
dev.LastSeenTS = lastseenTS.Int64
}
if ip.Valid {
dev.LastSeenIP = ip.String
}
} }
return &dev, err return &dev, err
} }

View file

@ -60,7 +60,7 @@ const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
@ -212,15 +212,22 @@ func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName sql.NullString var displayName, ip sql.NullString
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) var lastseenTS sql.NullInt64
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid { if displayName.Valid {
dev.DisplayName = displayName.String dev.DisplayName = displayName.String
} }
if lastseenTS.Valid {
dev.LastSeenTS = lastseenTS.Int64
}
if ip.Valid {
dev.LastSeenIP = ip.String
}
} }
return &dev, err return &dev, err
} }

View file

@ -180,12 +180,12 @@ func Test_Devices(t *testing.T) {
deviceWithID.DisplayName = newName deviceWithID.DisplayName = newName
deviceWithID.LastSeenIP = "127.0.0.1" deviceWithID.LastSeenIP = "127.0.0.1"
deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second))) deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second)))
devices, err = db.GetDevicesByLocalpart(ctx, localpart) gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
assert.NoError(t, err, "unable to get device by id") assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 2, len(devices)) assert.Equal(t, 2, len(devices))
assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName) assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP) assert.Equal(t, deviceWithID.LastSeenIP, gotDevice.LastSeenIP)
truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second) truncatedTime := gomatrixserverlib.Timestamp(gotDevice.LastSeenTS).Time().Truncate(time.Second)
assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime)) assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime))
// create one more device and remove the devices step by step // create one more device and remove the devices step by step