diff --git a/build/docker/Dockerfile.monolith b/build/docker/Dockerfile.monolith index 0d2a141ad..891a3a9e0 100644 --- a/build/docker/Dockerfile.monolith +++ b/build/docker/Dockerfile.monolith @@ -1,4 +1,4 @@ -FROM docker.io/golang:1.17-alpine AS base +FROM docker.io/golang:1.18-alpine AS base RUN apk --update --no-cache add bash build-base @@ -23,4 +23,4 @@ COPY --from=base /build/bin/* /usr/bin/ VOLUME /etc/dendrite WORKDIR /etc/dendrite -ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] \ No newline at end of file +ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] diff --git a/build/docker/Dockerfile.polylith b/build/docker/Dockerfile.polylith index c266fd480..ffdc35586 100644 --- a/build/docker/Dockerfile.polylith +++ b/build/docker/Dockerfile.polylith @@ -1,4 +1,4 @@ -FROM docker.io/golang:1.17-alpine AS base +FROM docker.io/golang:1.18-alpine AS base RUN apk --update --no-cache add bash build-base @@ -23,4 +23,4 @@ COPY --from=base /build/bin/* /usr/bin/ VOLUME /etc/dendrite WORKDIR /etc/dendrite -ENTRYPOINT ["/usr/bin/dendrite-polylith-multi"] \ No newline at end of file +ENTRYPOINT ["/usr/bin/dendrite-polylith-multi"] diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 9cc94d650..6b2533491 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -314,6 +314,7 @@ func (m *DendriteMonolith) Start() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 87dcad2e8..b9c6c1b78 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -152,6 +152,7 @@ func (m *DendriteMonolith) Start() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) httpRouter := mux.NewRouter() diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index e2f8d3f32..ad277056c 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -36,6 +36,7 @@ func AddPublicRoutes( process *process.ProcessContext, router *mux.Router, synapseAdminRouter *mux.Router, + dendriteAdminRouter *mux.Router, cfg *config.ClientAPI, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, @@ -62,7 +63,8 @@ func AddPublicRoutes( } routing.Setup( - router, synapseAdminRouter, cfg, rsAPI, asAPI, + router, synapseAdminRouter, dendriteAdminRouter, + cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg, natsClient, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f370b4f8c..ec90b80db 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -48,7 +48,8 @@ import ( // applied: // nolint: gocyclo func Setup( - publicAPIMux, synapseAdminRouter *mux.Router, cfg *config.ClientAPI, + publicAPIMux, synapseAdminRouter, dendriteAdminRouter *mux.Router, + cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, userAPI userapi.UserInternalAPI, @@ -119,6 +120,45 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } + dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}", + httputil.MakeAuthAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if device.AccountType != userapi.AccountTypeAdmin { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("This API can only be used by admin users."), + } + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + roomID, ok := vars["roomID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Expecting room ID."), + } + } + res := &roomserverAPI.PerformAdminEvacuateRoomResponse{} + rsAPI.PerformAdminEvacuateRoom( + req.Context(), + &roomserverAPI.PerformAdminEvacuateRoomRequest{ + RoomID: roomID, + }, + res, + ) + if err := res.Error; err != nil { + return err.JSONResponse() + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "affected": res.Affected, + }, + } + }), + ).Methods(http.MethodGet, http.MethodOptions) + // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index dd1ab3697..7ec810c95 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -193,6 +193,7 @@ func main() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) wsUpgrader := websocket.Upgrader{ diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index b840eb2b8..54231f30c 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -150,6 +150,7 @@ func main() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) if err := mscs.Enable(base, &monolith); err != nil { logrus.WithError(err).Fatalf("Failed to enable MSCs") diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 1443ab5b1..5fd5c0b5d 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -153,6 +153,7 @@ func main() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) if len(base.Cfg.MSCs.MSCs) > 0 { diff --git a/cmd/dendrite-polylith-multi/personalities/clientapi.go b/cmd/dendrite-polylith-multi/personalities/clientapi.go index 1e509f88a..7ed2075aa 100644 --- a/cmd/dendrite-polylith-multi/personalities/clientapi.go +++ b/cmd/dendrite-polylith-multi/personalities/clientapi.go @@ -31,8 +31,10 @@ func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { keyAPI := base.KeyServerHTTPClient() clientapi.AddPublicRoutes( - base.ProcessContext, base.PublicClientAPIMux, base.SynapseAdminMux, &base.Cfg.ClientAPI, - federation, rsAPI, asQuery, transactions.New(), fsAPI, userAPI, userAPI, + base.ProcessContext, base.PublicClientAPIMux, + base.SynapseAdminMux, base.DendriteAdminMux, + &base.Cfg.ClientAPI, federation, rsAPI, asQuery, + transactions.New(), fsAPI, userAPI, userAPI, keyAPI, nil, &cfg.MSCs, ) diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index 211b3e131..0d4b2fbc5 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -220,6 +220,7 @@ func startup() { base.PublicWellKnownAPIMux, base.PublicMediaAPIMux, base.SynapseAdminMux, + base.DendriteAdminMux, ) httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 31005209f..82f6cbabf 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -51,6 +51,12 @@ func Backfill( } } + // If we don't think we belong to this room then don't waste the effort + // responding to expensive requests for it. + if err := ErrorIfLocalServerNotInRoom(httpReq.Context(), rsAPI, roomID); err != nil { + return *err + } + // Check if all of the required parameters are there. eIDs, exists = httpReq.URL.Query()["v"] if !exists { diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index 0a03a0cb4..e83cb8ad2 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -30,6 +30,12 @@ func GetEventAuth( roomID string, eventID string, ) util.JSONResponse { + // If we don't think we belong to this room then don't waste the effort + // responding to expensive requests for it. + if err := ErrorIfLocalServerNotInRoom(ctx, rsAPI, roomID); err != nil { + return *err + } + event, resErr := fetchEvent(ctx, rsAPI, eventID) if resErr != nil { return *resErr diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index dd3df7aa9..b826d69c4 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -45,6 +45,12 @@ func GetMissingEvents( } } + // If we don't think we belong to this room then don't waste the effort + // responding to expensive requests for it. + if err := ErrorIfLocalServerNotInRoom(httpReq.Context(), rsAPI, roomID); err != nil { + return *err + } + var eventsResponse api.QueryMissingEventsResponse if err := rsAPI.QueryMissingEvents( httpReq.Context(), &api.QueryMissingEventsRequest{ diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index a085ed780..6d24c8b40 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -15,6 +15,8 @@ package routing import ( + "context" + "fmt" "net/http" "github.com/gorilla/mux" @@ -24,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -491,3 +494,27 @@ func Setup( }), ).Methods(http.MethodGet) } + +func ErrorIfLocalServerNotInRoom( + ctx context.Context, + rsAPI api.RoomserverInternalAPI, + roomID string, +) *util.JSONResponse { + // Check if we think we're in this room. If we aren't then + // we won't waste CPU cycles serving this request. + joinedReq := &api.QueryServerJoinedToRoomRequest{ + RoomID: roomID, + } + joinedRes := &api.QueryServerJoinedToRoomResponse{} + if err := rsAPI.QueryServerJoinedToRoom(ctx, joinedReq, joinedRes); err != nil { + res := util.ErrorResponse(err) + return &res + } + if !joinedRes.IsInRoom { + return &util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound(fmt.Sprintf("This server is not joined to room %s", roomID)), + } + } + return nil +} diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index a202c92c2..e2b67776a 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -101,6 +101,12 @@ func getState( roomID string, eventID string, ) (stateEvents, authEvents []*gomatrixserverlib.HeaderedEvent, errRes *util.JSONResponse) { + // If we don't think we belong to this room then don't waste the effort + // responding to expensive requests for it. + if err := ErrorIfLocalServerNotInRoom(ctx, rsAPI, roomID); err != nil { + return nil, nil, err + } + event, resErr := fetchEvent(ctx, rsAPI, eventID) if resErr != nil { return nil, nil, resErr diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index e571c7e56..1677cf8e3 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -319,6 +319,9 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques // JSON, add the signatures and marshal it again, for some reason? for targetUserID, masterKey := range res.MasterKeys { + if masterKey.Signatures == nil { + masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } for targetKeyID := range masterKey.Keys { sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) if err != nil { diff --git a/roomserver/api/api.go b/roomserver/api/api.go index fb77423f8..f0ca8a615 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -66,6 +66,12 @@ type RoomserverInternalAPI interface { res *PerformInboundPeekResponse, ) error + PerformAdminEvacuateRoom( + ctx context.Context, + req *PerformAdminEvacuateRoomRequest, + res *PerformAdminEvacuateRoomResponse, + ) + QueryPublishedRooms( ctx context.Context, req *QueryPublishedRoomsRequest, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index ec7211ef8..61c06e886 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -104,6 +104,15 @@ func (t *RoomserverInternalAPITrace) PerformPublish( util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) } +func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom( + ctx context.Context, + req *PerformAdminEvacuateRoomRequest, + res *PerformAdminEvacuateRoomResponse, +) { + t.Impl.PerformAdminEvacuateRoom(ctx, req, res) + util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res)) +} + func (t *RoomserverInternalAPITrace) PerformInboundPeek( ctx context.Context, req *PerformInboundPeekRequest, diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index cda4b3ee4..30aa2cf1b 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -214,3 +214,12 @@ type PerformRoomUpgradeResponse struct { NewRoomID string Error *PerformError } + +type PerformAdminEvacuateRoomRequest struct { + RoomID string `json:"room_id"` +} + +type PerformAdminEvacuateRoomResponse struct { + Affected []string `json:"affected"` + Error *PerformError +} diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 59f485cf7..267cd4099 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -35,6 +35,7 @@ type RoomserverInternalAPI struct { *perform.Backfiller *perform.Forgetter *perform.Upgrader + *perform.Admin ProcessContext *process.ProcessContext DB storage.Database Cfg *config.RoomServer @@ -164,6 +165,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA Cfg: r.Cfg, URSAPI: r, } + r.Admin = &perform.Admin{ + DB: r.DB, + Cfg: r.Cfg, + Inputer: r.Inputer, + Queryer: r.Queryer, + } if err := r.Inputer.Start(); err != nil { logrus.WithError(err).Panic("failed to start roomserver input API") diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go new file mode 100644 index 000000000..2de6477cc --- /dev/null +++ b/roomserver/internal/perform/perform_admin.go @@ -0,0 +1,162 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/internal/query" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +type Admin struct { + DB storage.Database + Cfg *config.RoomServer + Queryer *query.Queryer + Inputer *input.Inputer +} + +// PerformEvacuateRoom will remove all local users from the given room. +func (r *Admin) PerformAdminEvacuateRoom( + ctx context.Context, + req *api.PerformAdminEvacuateRoomRequest, + res *api.PerformAdminEvacuateRoomResponse, +) { + roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err), + } + return + } + if roomInfo == nil || roomInfo.IsStub { + res.Error = &api.PerformError{ + Code: api.PerformErrorNoRoom, + Msg: fmt.Sprintf("Room %s not found", req.RoomID), + } + return + } + + memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err), + } + return + } + + memberEvents, err := r.DB.Events(ctx, memberNIDs) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.DB.Events: %s", err), + } + return + } + + inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) + res.Affected = make([]string, 0, len(memberEvents)) + latestReq := &api.QueryLatestEventsAndStateRequest{ + RoomID: req.RoomID, + } + latestRes := &api.QueryLatestEventsAndStateResponse{} + if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err), + } + return + } + + prevEvents := latestRes.LatestEvents + for _, memberEvent := range memberEvents { + if memberEvent.StateKey() == nil { + continue + } + + var memberContent gomatrixserverlib.MemberContent + if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("json.Unmarshal: %s", err), + } + return + } + memberContent.Membership = gomatrixserverlib.Leave + + stateKey := *memberEvent.StateKey() + fledglingEvent := &gomatrixserverlib.EventBuilder{ + RoomID: req.RoomID, + Type: gomatrixserverlib.MRoomMember, + StateKey: &stateKey, + Sender: stateKey, + PrevEvents: prevEvents, + } + + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("json.Marshal: %s", err), + } + return + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err), + } + return + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err), + } + return + } + + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: r.Cfg.Matrix.ServerName, + SendAsServer: string(r.Cfg.Matrix.ServerName), + }) + res.Affected = append(res.Affected, stateKey) + prevEvents = []gomatrixserverlib.EventReference{ + event.EventReference(), + } + } + + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: inputEvents, + Asynchronous: true, + } + inputRes := &api.InputRoomEventsResponse{} + r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) +} diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index d55805a91..3b29001e9 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -29,16 +29,17 @@ const ( RoomserverInputRoomEventsPath = "/roomserver/inputRoomEvents" // Perform operations - RoomserverPerformInvitePath = "/roomserver/performInvite" - RoomserverPerformPeekPath = "/roomserver/performPeek" - RoomserverPerformUnpeekPath = "/roomserver/performUnpeek" - RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade" - RoomserverPerformJoinPath = "/roomserver/performJoin" - RoomserverPerformLeavePath = "/roomserver/performLeave" - RoomserverPerformBackfillPath = "/roomserver/performBackfill" - RoomserverPerformPublishPath = "/roomserver/performPublish" - RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" - RoomserverPerformForgetPath = "/roomserver/performForget" + RoomserverPerformInvitePath = "/roomserver/performInvite" + RoomserverPerformPeekPath = "/roomserver/performPeek" + RoomserverPerformUnpeekPath = "/roomserver/performUnpeek" + RoomserverPerformRoomUpgradePath = "/roomserver/performRoomUpgrade" + RoomserverPerformJoinPath = "/roomserver/performJoin" + RoomserverPerformLeavePath = "/roomserver/performLeave" + RoomserverPerformBackfillPath = "/roomserver/performBackfill" + RoomserverPerformPublishPath = "/roomserver/performPublish" + RoomserverPerformInboundPeekPath = "/roomserver/performInboundPeek" + RoomserverPerformForgetPath = "/roomserver/performForget" + RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom" // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" @@ -299,6 +300,23 @@ func (h *httpRoomserverInternalAPI) PerformPublish( } } +func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom( + ctx context.Context, + req *api.PerformAdminEvacuateRoomRequest, + res *api.PerformAdminEvacuateRoomResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + if err != nil { + res.Error = &api.PerformError{ + Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), + } + } +} + // QueryLatestEventsAndState implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 0b27b5a8d..c5159a63c 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -118,6 +118,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath, + httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse { + var request api.PerformAdminEvacuateRoomRequest + var response api.PerformAdminEvacuateRoomResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + r.PerformAdminEvacuateRoom(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle( RoomserverQueryPublishedRoomsPath, httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse { diff --git a/setup/monolith.go b/setup/monolith.go index 32f1a6494..c86ec7b69 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -54,13 +54,13 @@ type Monolith struct { } // AddAllPublicRoutes attaches all public paths to the given router -func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux *mux.Router) { +func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux, dendriteMux *mux.Router) { userDirectoryProvider := m.ExtUserDirectoryProvider if userDirectoryProvider == nil { userDirectoryProvider = m.UserAPI } clientapi.AddPublicRoutes( - process, csMux, synapseMux, &m.Config.ClientAPI, + process, csMux, synapseMux, dendriteMux, &m.Config.ClientAPI, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(), m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI, diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 82834239b..87f0d86d7 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -333,6 +333,20 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error { return nil } +// LoadRooms loads the membership states required to notify users correctly. +func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs []string) error { + n.lock.Lock() + defer n.lock.Unlock() + + roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs) + if err != nil { + return err + } + n.setUsersJoinedToRooms(roomToUsers) + + return nil +} + // CurrentPosition returns the current sync position func (n *Notifier) CurrentPosition() types.StreamingToken { n.lock.RLock() diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 43aaa3588..5a036d889 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -52,6 +52,9 @@ type Database interface { // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) + // AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room. + AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) + // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) // Events lookups a list of event by their event ID. @@ -159,6 +162,6 @@ type Database interface { type Presence interface { UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) - PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) + PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) } diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index fe68788d1..8ee387b39 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -93,6 +93,9 @@ const selectCurrentStateSQL = "" + const selectJoinedUsersSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" +const selectJoinedUsersInRoomSQL = "" + + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id = ANY($1)" + const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" @@ -112,6 +115,7 @@ type currentRoomStateStatements struct { selectRoomIDsWithAnyMembershipStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt + selectJoinedUsersInRoomStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt } @@ -143,6 +147,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return nil, err } + if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { + return nil, err + } if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { return nil, err } @@ -163,9 +170,32 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") result := make(map[string][]string) + var roomID string + var userID string + for rows.Next() { + if err := rows.Scan(&roomID, &userID); err != nil { + return nil, err + } + users := result[roomID] + users = append(users, userID) + result[roomID] = users + } + return result, rows.Err() +} + +// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. +func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( + ctx context.Context, roomIDs []string, +) (map[string][]string, error) { + rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") + + result := make(map[string][]string) + var userID, roomID string for rows.Next() { - var roomID string - var userID string if err := rows.Scan(&roomID, &userID); err != nil { return nil, err } diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 8c049977f..00223c57a 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -64,7 +64,7 @@ const selectMembershipCountSQL = "" + ") t WHERE t.membership = $3" const selectHeroesSQL = "" + - "SELECT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" + "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" type membershipsStatements struct { upsertMembershipStmt *sql.Stmt diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index 9f1e37f79..7194afea6 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "time" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -72,7 +73,8 @@ const selectMaxPresenceSQL = "" + const selectPresenceAfter = "" + " SELECT id, user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE id > $1" + " WHERE id > $1 AND last_active_ts >= $2" + + " ORDER BY id ASC LIMIT $3" type presenceStatements struct { upsertPresenceStmt *sql.Stmt @@ -144,11 +146,12 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) func (p *presenceStatements) GetPresenceAfter( ctx context.Context, txn *sql.Tx, after types.StreamPosition, + filter gomatrixserverlib.EventFilter, ) (presences map[string]*types.PresenceInternal, err error) { presences = make(map[string]*types.PresenceInternal) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) - - rows, err := stmt.QueryContext(ctx, after) + afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5)) + rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit) if err != nil { return nil, err } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 25aca50ae..ec5edd355 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -168,6 +168,10 @@ func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]stri return d.CurrentRoomState.SelectJoinedUsers(ctx) } +func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { + return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, roomIDs) +} + func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { return d.Peeks.SelectPeekingDevices(ctx) } @@ -1056,8 +1060,8 @@ func (s *Database) GetPresence(ctx context.Context, userID string) (*types.Prese return s.Presence.GetPresenceForUser(ctx, nil, userID) } -func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) { - return s.Presence.GetPresenceAfter(ctx, nil, after) +func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { + return s.Presence.GetPresenceAfter(ctx, nil, after, filter) } func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index ccda005c1..f0a1c7bb7 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -77,6 +77,9 @@ const selectCurrentStateSQL = "" + const selectJoinedUsersSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" +const selectJoinedUsersInRoomSQL = "" + + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join' AND room_id IN ($1)" + const selectStateEventSQL = "" + "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" @@ -97,7 +100,8 @@ type currentRoomStateStatements struct { selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithAnyMembershipStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt - selectStateEventStmt *sql.Stmt + //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic + selectStateEventStmt *sql.Stmt } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { @@ -127,13 +131,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return nil, err } + //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { + // return nil, err + //} if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { return nil, err } return s, nil } -// JoinedMemberLists returns a map of room ID to a list of joined user IDs. +// SelectJoinedUsers returns a map of room ID to a list of joined user IDs. func (s *currentRoomStateStatements) SelectJoinedUsers( ctx context.Context, ) (map[string][]string, error) { @@ -144,9 +151,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") result := make(map[string][]string) + var roomID string + var userID string for rows.Next() { - var roomID string - var userID string if err := rows.Scan(&roomID, &userID); err != nil { return nil, err } @@ -157,6 +164,40 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( return result, nil } +// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. +func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( + ctx context.Context, roomIDs []string, +) (map[string][]string, error) { + query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) + params := make([]interface{}, 0, len(roomIDs)) + for _, roomID := range roomIDs { + params = append(params, roomID) + } + stmt, err := s.db.Prepare(query) + if err != nil { + return nil, fmt.Errorf("SelectJoinedUsersInRoom s.db.Prepare: %w", err) + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsersInRoom: rows.close() failed") + + result := make(map[string][]string) + var userID, roomID string + for rows.Next() { + if err := rows.Scan(&roomID, &userID); err != nil { + return nil, err + } + users := result[roomID] + users = append(users, userID) + result[roomID] = users + } + return result, rows.Err() +} + // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( ctx context.Context, diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index 177a01bf3..b61a825df 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "time" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -71,7 +72,8 @@ const selectMaxPresenceSQL = "" + const selectPresenceAfter = "" + " SELECT id, user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE id > $1" + " WHERE id > $1 AND last_active_ts >= $2" + + " ORDER BY id ASC LIMIT $3" type presenceStatements struct { db *sql.DB @@ -158,12 +160,12 @@ func (p *presenceStatements) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) // GetPresenceAfter returns the changes presences after a given stream id func (p *presenceStatements) GetPresenceAfter( ctx context.Context, txn *sql.Tx, - after types.StreamPosition, + after types.StreamPosition, filter gomatrixserverlib.EventFilter, ) (presences map[string]*types.PresenceInternal, err error) { presences = make(map[string]*types.PresenceInternal) stmt := sqlutil.TxStmt(txn, p.selectPresenceAfterStmt) - - rows, err := stmt.QueryContext(ctx, after) + afterTS := gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute * -5)) + rows, err := stmt.QueryContext(ctx, after, afterTS, filter.Limit) if err != nil { return nil, err } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 4ff4689ed..ccdebfdbd 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -102,6 +102,8 @@ type CurrentRoomState interface { SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. SelectJoinedUsers(ctx context.Context) (map[string][]string, error) + // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. + SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) } // BackwardsExtremities keeps track of backwards extremities for a room. @@ -188,5 +190,5 @@ type Presence interface { UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) - GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition) (presences map[string]*types.PresenceInternal, err error) + GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) } diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 614b88d48..a84d19878 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -53,7 +53,8 @@ func (p *PresenceStreamProvider) IncrementalSync( req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { - presences, err := p.DB.PresenceAfter(ctx, from) + // We pull out a larger number than the filter asks for, since we're filtering out events later + presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) if err != nil { req.Log.WithError(err).Error("p.DB.PresenceAfter failed") return from @@ -66,12 +67,12 @@ func (p *PresenceStreamProvider) IncrementalSync( // add newly joined rooms user presences newlyJoined := joinedRooms(req.Response, req.Device.UserID) if len(newlyJoined) > 0 { - // TODO: This refreshes all lists and is quite expensive - // The notifier should update the lists itself - if err = p.notifier.Load(ctx, p.DB); err != nil { + // TODO: Check if this is working better than before. + if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { req.Log.WithError(err).Error("unable to refresh notifier lists") return from } + NewlyJoinedLoop: for _, roomID := range newlyJoined { roomUsers := p.notifier.JoinedUsers(roomID) for i := range roomUsers { @@ -86,11 +87,14 @@ func (p *PresenceStreamProvider) IncrementalSync( req.Log.WithError(err).Error("unable to query presence for user") return from } + if len(presences) > req.Filter.Presence.Limit { + break NewlyJoinedLoop + } } } } - lastPos := to + lastPos := from for _, presence := range presences { if presence == nil { continue @@ -135,6 +139,9 @@ func (p *PresenceStreamProvider) IncrementalSync( if presence.StreamPos > lastPos { lastPos = presence.StreamPos } + if len(req.Response.Presence.Events) == req.Filter.Presence.Limit { + break + } p.cache.Store(cacheKey, presence) } diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index a80089945..5e52bc7c9 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -30,7 +30,7 @@ func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.Presenc return &types.PresenceInternal{}, nil } -func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition) (map[string]*types.PresenceInternal, error) { +func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { return map[string]*types.PresenceInternal{}, nil } diff --git a/sytest-blacklist b/sytest-blacklist index 713a5b631..be0826eee 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -48,4 +48,3 @@ Notifications can be viewed with GET /notifications # More flakey If remote user leaves room we no longer receive device updates -Local device key changes get to remote servers diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index fe8c54e04..6c777982f 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -75,7 +75,7 @@ const selectDeviceByTokenSQL = "" + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" @@ -215,15 +215,22 @@ func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device - var displayName sql.NullString + var displayName, ip sql.NullString + var lastseenTS sql.NullInt64 stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) if displayName.Valid { dev.DisplayName = displayName.String } + if lastseenTS.Valid { + dev.LastSeenTS = lastseenTS.Int64 + } + if ip.Valid { + dev.LastSeenIP = ip.String + } } return &dev, err } diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 7860bd6a2..b86ed1cc2 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -60,7 +60,7 @@ const selectDeviceByTokenSQL = "" + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" @@ -212,15 +212,22 @@ func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device - var displayName sql.NullString + var displayName, ip sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) + var lastseenTS sql.NullInt64 + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) if displayName.Valid { dev.DisplayName = displayName.String } + if lastseenTS.Valid { + dev.LastSeenTS = lastseenTS.Int64 + } + if ip.Valid { + dev.LastSeenIP = ip.String + } } return &dev, err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index e6c7d35fc..2eb57d0bc 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -180,12 +180,12 @@ func Test_Devices(t *testing.T) { deviceWithID.DisplayName = newName deviceWithID.LastSeenIP = "127.0.0.1" deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second))) - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 2, len(devices)) - assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName) - assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP) - truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second) + assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName) + assert.Equal(t, deviceWithID.LastSeenIP, gotDevice.LastSeenIP) + truncatedTime := gomatrixserverlib.Timestamp(gotDevice.LastSeenTS).Time().Truncate(time.Second) assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime)) // create one more device and remove the devices step by step