diff --git a/README.md b/README.md index 864a24e99..ecb3cdb9b 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,18 @@ look for [Good First Issues](https://github.com/matrix-org/dendrite/labels/good% familiar with the project, look for [Help Wanted](https://github.com/matrix-org/dendrite/labels/help-wanted) issues. +# Hardware requirements + +Dendrite in Monolith + SQLite works in a range of environments including iOS and in-browser via WASM. + +For small homeserver installations joined on ~10s rooms on matrix.org with ~100s of users in those rooms, including some +encrypted rooms: + - Memory: uses around 100MB of RAM, with peaks at around 200MB. + - Disk space: After a few months of usage, the database grew to around 2GB (in Monolith mode). + - CPU: Brief spikes when processing events, typically idles at 1% CPU. + +This means Dendrite should comfortably work on things like Raspberry Pis. + # Discussion For questions about Dendrite we have a dedicated room on Matrix diff --git a/build/docker/config/dendrite-config.yaml b/build/docker/config/dendrite-config.yaml index 8cc9934d4..7ebeeb6e7 100644 --- a/build/docker/config/dendrite-config.yaml +++ b/build/docker/config/dendrite-config.yaml @@ -133,17 +133,6 @@ client_api: turn_username: "" turn_password: "" -# Configuration for the Current State Server. -current_state_server: - internal_api: - listen: http://0.0.0.0:7782 - connect: http://current_state_server:7782 - database: - connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_currentstate?sslmode=disable - max_open_conns: 100 - max_idle_conns: 2 - conn_max_lifetime: -1 - # Configuration for the EDU server. edu_server: internal_api: diff --git a/build/docker/docker-compose.polylith.yml b/build/docker/docker-compose.polylith.yml index 1f84e58d1..6dd743141 100644 --- a/build/docker/docker-compose.polylith.yml +++ b/build/docker/docker-compose.polylith.yml @@ -43,17 +43,6 @@ services: networks: - internal - current_state_server: - hostname: current_state_server - image: matrixdotorg/dendrite:currentstateserver - command: [ - "--config=dendrite.yaml" - ] - volumes: - - ./config:/etc/dendrite - networks: - - internal - sync_api: hostname: sync_api image: matrixdotorg/dendrite:syncapi diff --git a/build/docker/images-build.sh b/build/docker/images-build.sh index 443f30920..fdff51320 100755 --- a/build/docker/images-build.sh +++ b/build/docker/images-build.sh @@ -15,7 +15,6 @@ docker build -t matrixdotorg/dendrite:federationsender --build-arg component=de docker build -t matrixdotorg/dendrite:federationproxy --build-arg component=federation-api-proxy -f build/docker/Dockerfile.component . docker build -t matrixdotorg/dendrite:keyserver --build-arg component=dendrite-key-server -f build/docker/Dockerfile.component . docker build -t matrixdotorg/dendrite:mediaapi --build-arg component=dendrite-media-api-server -f build/docker/Dockerfile.component . -docker build -t matrixdotorg/dendrite:currentstateserver --build-arg component=dendrite-current-state-server -f build/docker/Dockerfile.component . docker build -t matrixdotorg/dendrite:roomserver --build-arg component=dendrite-room-server -f build/docker/Dockerfile.component . docker build -t matrixdotorg/dendrite:syncapi --build-arg component=dendrite-sync-api-server -f build/docker/Dockerfile.component . docker build -t matrixdotorg/dendrite:serverkeyapi --build-arg component=dendrite-server-key-api-server -f build/docker/Dockerfile.component . diff --git a/build/docker/images-pull.sh b/build/docker/images-pull.sh index b4a4b2fce..c6b09b6a4 100755 --- a/build/docker/images-pull.sh +++ b/build/docker/images-pull.sh @@ -11,7 +11,6 @@ docker pull matrixdotorg/dendrite:federationsender docker pull matrixdotorg/dendrite:federationproxy docker pull matrixdotorg/dendrite:keyserver docker pull matrixdotorg/dendrite:mediaapi -docker pull matrixdotorg/dendrite:currentstateserver docker pull matrixdotorg/dendrite:roomserver docker pull matrixdotorg/dendrite:syncapi docker pull matrixdotorg/dendrite:userapi diff --git a/build/docker/images-push.sh b/build/docker/images-push.sh index ec1e860f0..4838c76f6 100755 --- a/build/docker/images-push.sh +++ b/build/docker/images-push.sh @@ -11,7 +11,6 @@ docker push matrixdotorg/dendrite:federationsender docker push matrixdotorg/dendrite:federationproxy docker push matrixdotorg/dendrite:keyserver docker push matrixdotorg/dendrite:mediaapi -docker push matrixdotorg/dendrite:currentstateserver docker push matrixdotorg/dendrite:roomserver docker push matrixdotorg/dendrite:syncapi docker push matrixdotorg/dendrite:serverkeyapi diff --git a/build/docker/postgres/create_db.sh b/build/docker/postgres/create_db.sh index 222675f6e..70d6743e4 100644 --- a/build/docker/postgres/create_db.sh +++ b/build/docker/postgres/create_db.sh @@ -1,5 +1,5 @@ #!/bin/bash -for db in account device mediaapi syncapi roomserver serverkey keyserver federationsender currentstate appservice e2ekey naffka; do +for db in account device mediaapi syncapi roomserver serverkey keyserver federationsender appservice e2ekey naffka; do createdb -U dendrite -O dendrite dendrite_$db done diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index 725c9c074..27b116487 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -13,7 +13,6 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/yggconn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/yggrooms" - "github.com/matrix-org/dendrite/currentstateserver" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationsender" @@ -99,7 +98,6 @@ func (m *DendriteMonolith) Start() { cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-keyserver.db", m.StorageDirectory)) cfg.FederationSender.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-federationsender.db", m.StorageDirectory)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-appservice.db", m.StorageDirectory)) - cfg.CurrentStateServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-currentstate.db", m.StorageDirectory)) cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory)) cfg.FederationSender.FederationMaxRetries = 8 @@ -128,9 +126,8 @@ func (m *DendriteMonolith) Start() { ) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, stateAPI, keyRing, + base, federation, rsAPI, keyRing, ) ygg.SetSessionFunc(func(address string) { @@ -163,7 +160,6 @@ func (m *DendriteMonolith) Start() { FederationSenderAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - StateAPI: stateAPI, KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index 087e45043..da0324251 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -5,6 +5,7 @@ type LoginType string // The relevant login types implemented in Dendrite const ( + LoginTypePassword = "m.login.password" LoginTypeDummy = "m.login.dummy" LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeRecaptcha = "m.login.recaptcha" diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index fe6789fcb..2ab92ed4e 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/routing" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" @@ -43,7 +42,6 @@ func AddPublicRoutes( rsAPI roomserverAPI.RoomserverInternalAPI, eduInputAPI eduServerAPI.EDUServerInputAPI, asAPI appserviceAPI.AppServiceQueryAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, transactionsCache *transactions.Cache, fsAPI federationSenderAPI.FederationSenderInternalAPI, userAPI userapi.UserInternalAPI, @@ -58,6 +56,6 @@ func AddPublicRoutes( routing.Setup( router, cfg, eduInputAPI, rsAPI, asAPI, accountsDB, userAPI, federation, - syncProducer, transactionsCache, fsAPI, stateAPI, keyAPI, extRoomsProvider, + syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, ) } diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 62f295fef..e64d6b233 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" @@ -270,10 +269,10 @@ func GetVisibility( // SetVisibility implements PUT /directory/list/room/{roomID} // TODO: Allow admin users to edit the room visibility func SetVisibility( - req *http.Request, stateAPI currentstateAPI.CurrentStateInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, dev *userapi.Device, + req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, dev *userapi.Device, roomID string, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), stateAPI, dev.UserID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, dev.UserID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index bae7e49bb..fd7bc1e86 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -26,7 +26,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" @@ -51,7 +50,7 @@ type filter struct { // GetPostPublicRooms implements GET and POST /publicRooms func GetPostPublicRooms( - req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, + req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, extRoomsProvider api.ExtraPublicRoomsProvider, federation *gomatrixserverlib.FederationClient, cfg *config.ClientAPI, @@ -75,7 +74,7 @@ func GetPostPublicRooms( } } - response, err := publicRooms(req.Context(), request, rsAPI, stateAPI, extRoomsProvider) + response, err := publicRooms(req.Context(), request, rsAPI, extRoomsProvider) if err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to work out public rooms") return jsonerror.InternalServerError() @@ -86,8 +85,8 @@ func GetPostPublicRooms( } } -func publicRooms(ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.RoomserverInternalAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, extRoomsProvider api.ExtraPublicRoomsProvider, +func publicRooms( + ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.RoomserverInternalAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) (*gomatrixserverlib.RespPublicRooms, error) { response := gomatrixserverlib.RespPublicRooms{ @@ -110,7 +109,7 @@ func publicRooms(ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI var rooms []gomatrixserverlib.PublicRoom if request.Since == "" { - rooms = refreshPublicRoomCache(ctx, rsAPI, extRoomsProvider, stateAPI) + rooms = refreshPublicRoomCache(ctx, rsAPI, extRoomsProvider) } else { rooms = getPublicRoomsFromCache() } @@ -226,7 +225,6 @@ func sliceInto(slice []gomatrixserverlib.PublicRoom, since int64, limit int16) ( func refreshPublicRoomCache( ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, extRoomsProvider api.ExtraPublicRoomsProvider, - stateAPI currentstateAPI.CurrentStateInternalAPI, ) []gomatrixserverlib.PublicRoom { cacheMu.Lock() defer cacheMu.Unlock() @@ -241,7 +239,7 @@ func refreshPublicRoomCache( util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed") return publicRoomsCache } - pubRooms, err := currentstateAPI.PopulatePublicRooms(ctx, queryRes.RoomIDs, stateAPI) + pubRooms, err := roomserverAPI.PopulatePublicRooms(ctx, queryRes.RoomIDs, rsAPI) if err != nil { util.GetLogger(ctx).WithError(err).Error("PopulatePublicRooms failed") return publicRoomsCache diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 202662ab6..88cb23647 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -25,7 +25,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/threepid" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -95,7 +94,6 @@ func SendKick( req *http.Request, accountDB accounts.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, ) util.JSONResponse { body, evTime, roomVer, reqErr := extractRequestData(req, roomID, rsAPI) if reqErr != nil { @@ -108,7 +106,7 @@ func SendKick( } } - errRes := checkMemberInRoom(req.Context(), stateAPI, device.UserID, roomID) + errRes := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) if errRes != nil { return *errRes } @@ -372,13 +370,13 @@ func checkAndProcessThreepid( return } -func checkMemberInRoom(ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID, roomID string) *util.JSONResponse { +func checkMemberInRoom(ctx context.Context, rsAPI api.RoomserverInternalAPI, userID, roomID string) *util.JSONResponse { tuple := gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomMember, StateKey: userID, } - var membershipRes currentstateAPI.QueryCurrentStateResponse - err := stateAPI.QueryCurrentState(ctx, ¤tstateAPI.QueryCurrentStateRequest{ + var membershipRes api.QueryCurrentStateResponse + err := rsAPI.QueryCurrentState(ctx, &api.QueryCurrentStateRequest{ RoomID: roomID, StateTuples: []gomatrixserverlib.StateKeyTuple{tuple}, }, &membershipRes) diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index 560593500..613484875 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -19,7 +19,6 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -94,10 +93,10 @@ func GetMemberships( func GetJoinedRooms( req *http.Request, device *userapi.Device, - stateAPI currentstateAPI.CurrentStateInternalAPI, + rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { - var res currentstateAPI.QueryRoomsForUserResponse - err := stateAPI.QueryRoomsForUser(req.Context(), ¤tstateAPI.QueryRoomsForUserRequest{ + var res api.QueryRoomsForUserResponse + err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go new file mode 100644 index 000000000..8b81b9f02 --- /dev/null +++ b/clientapi/routing/password.go @@ -0,0 +1,127 @@ +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +type newPasswordRequest struct { + NewPassword string `json:"new_password"` + LogoutDevices bool `json:"logout_devices"` + Auth newPasswordAuth `json:"auth"` +} + +type newPasswordAuth struct { + Type string `json:"type"` + Session string `json:"session"` + auth.PasswordRequest +} + +func Password( + req *http.Request, + userAPI userapi.UserInternalAPI, + accountDB accounts.Database, + device *api.Device, + cfg *config.ClientAPI, +) util.JSONResponse { + // Check that the existing password is right. + var r newPasswordRequest + r.LogoutDevices = true + + // Unmarshal the request. + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + // Retrieve or generate the sessionID + sessionID := r.Auth.Session + if sessionID == "" { + // Generate a new, random session ID + sessionID = util.RandomString(sessionIDLength) + } + + // Require password auth to change the password. + if r.Auth.Type != authtypes.LoginTypePassword { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + sessionID, + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, + }, + }, + nil, + ), + } + } + + // Check if the existing password is correct. + typePassword := auth.LoginTypePassword{ + GetAccountByPassword: accountDB.GetAccountByPassword, + Config: cfg, + } + if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { + return *authErr + } + AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + + // Check the new password strength. + if resErr = validatePassword(r.NewPassword); resErr != nil { + return *resErr + } + + // Get the local part. + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + + // Ask the user API to perform the password change. + passwordReq := &userapi.PerformPasswordUpdateRequest{ + Localpart: localpart, + Password: r.NewPassword, + } + passwordRes := &userapi.PerformPasswordUpdateResponse{} + if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed") + return jsonerror.InternalServerError() + } + if !passwordRes.PasswordUpdated { + util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't") + return jsonerror.InternalServerError() + } + + // If the request asks us to log out all other devices then + // ask the user API to do that. + if r.LogoutDevices { + logoutReq := &userapi.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: nil, + ExceptDeviceID: device.ID, + } + logoutRes := &userapi.PerformDeviceDeletionResponse{} + if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") + return jsonerror.InternalServerError() + } + } + + // Return a success code. + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index d86ccd6b5..d96f91d0f 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -40,7 +40,7 @@ func PeekRoomByIDOrAlias( peekReq := roomserverAPI.PerformPeekRequest{ RoomIDOrAlias: roomIDOrAlias, UserID: device.UserID, - DeviceID: device.ID, + DeviceID: device.ID, } peekRes := roomserverAPI.PerformPeekResponse{} @@ -64,7 +64,6 @@ func PeekRoomByIDOrAlias( // if this user is already joined to the room, we let them peek anyway // (given they might be about to part the room, and it makes things less fiddly) - // Peeking stops if none of the devices who started peeking have been // /syncing for a while, or if everyone who was peeking calls /leave // (or /unpeek with a server_name param? or DELETE /peek?) diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index bc51b0b51..60669a0c8 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -94,7 +93,7 @@ func GetAvatarURL( // SetAvatarURL implements PUT /profile/{userID}/avatar_url // nolint:gocyclo func SetAvatarURL( - req *http.Request, accountDB accounts.Database, stateAPI currentstateAPI.CurrentStateInternalAPI, + req *http.Request, accountDB accounts.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -140,8 +139,8 @@ func SetAvatarURL( return jsonerror.InternalServerError() } - var res currentstateAPI.QueryRoomsForUserResponse - err = stateAPI.QueryRoomsForUser(req.Context(), ¤tstateAPI.QueryRoomsForUserRequest{ + var res api.QueryRoomsForUserResponse + err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) @@ -212,7 +211,7 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname // nolint:gocyclo func SetDisplayName( - req *http.Request, accountDB accounts.Database, stateAPI currentstateAPI.CurrentStateInternalAPI, + req *http.Request, accountDB accounts.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -258,8 +257,8 @@ func SetDisplayName( return jsonerror.InternalServerError() } - var res currentstateAPI.QueryRoomsForUserResponse - err = stateAPI.QueryRoomsForUser(req.Context(), ¤tstateAPI.QueryRoomsForUserRequest{ + var res api.QueryRoomsForUserResponse + err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 178bfafc9..9701685e0 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -41,9 +40,9 @@ type redactionResponse struct { func SendRedaction( req *http.Request, device *userapi.Device, roomID, eventID string, cfg *config.ClientAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { - resErr := checkMemberInRoom(req.Context(), stateAPI, device.UserID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) if resErr != nil { return *resErr } @@ -67,7 +66,7 @@ func SendRedaction( // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid allowedToRedact := ev.Sender() == device.UserID if !allowedToRedact { - plEvent := currentstateAPI.GetEvent(req.Context(), stateAPI, roomID, gomatrixserverlib.StateKeyTuple{ + plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: "", }) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index cd717e2b1..999946a65 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -25,7 +25,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" @@ -56,7 +55,6 @@ func Setup( syncProducer *producers.SyncAPIProducer, transactionsCache *transactions.Cache, federationSender federationSenderAPI.FederationSenderInternalAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, keyAPI keyserverAPI.KeyInternalAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) { @@ -118,7 +116,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/joined_rooms", httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return GetJoinedRooms(req, device, stateAPI) + return GetJoinedRooms(req, device, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/join", @@ -176,7 +174,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI, stateAPI) + return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/unban", @@ -342,12 +340,12 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SetVisibility(req, stateAPI, rsAPI, device, vars["roomID"]) + return SetVisibility(req, rsAPI, device, vars["roomID"]) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/publicRooms", httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { - return GetPostPublicRooms(req, rsAPI, stateAPI, extRoomsProvider, federation, cfg) + return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) @@ -372,7 +370,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, stateAPI) + return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/redact/{eventID}", @@ -381,7 +379,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, stateAPI) + return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", @@ -390,7 +388,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, stateAPI) + return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -428,6 +426,15 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/account/password", + httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + return Password(req, userAPI, accountDB, device, cfg) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Stub endpoints required by Riot r0mux.Handle("/login", @@ -496,7 +503,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SetAvatarURL(req, accountDB, stateAPI, device, vars["userID"], cfg, rsAPI) + return SetAvatarURL(req, accountDB, device, vars["userID"], cfg, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows @@ -521,7 +528,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SetDisplayName(req, accountDB, stateAPI, device, vars["userID"], cfg, rsAPI) + return SetDisplayName(req, accountDB, device, vars["userID"], cfg, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows @@ -650,7 +657,7 @@ func Setup( req.Context(), device, userAPI, - stateAPI, + rsAPI, cfg.Matrix.ServerName, postContent.SearchString, postContent.Limit, diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index e4b5b7a3a..3abf3db27 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -17,8 +17,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/eduserver/api" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/util" @@ -35,7 +35,7 @@ func SendTyping( req *http.Request, device *userapi.Device, roomID string, userID string, accountDB accounts.Database, eduAPI api.EDUServerInputAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { if device.UserID != userID { return util.JSONResponse{ @@ -45,7 +45,7 @@ func SendTyping( } // Verify that the user is a member of this room - resErr := checkMemberInRoom(req.Context(), stateAPI, userID, roomID) + resErr := checkMemberInRoom(req.Context(), rsAPI, userID, roomID) if resErr != nil { return *resErr } diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index db81ffeae..2659bc9cc 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -19,7 +19,7 @@ import ( "fmt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" + "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -34,7 +34,7 @@ func SearchUserDirectory( ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, + rsAPI api.RoomserverInternalAPI, serverName gomatrixserverlib.ServerName, searchString string, limit int, @@ -81,14 +81,14 @@ func SearchUserDirectory( // start searching for known users from joined rooms. if len(results) <= limit { - stateReq := ¤tstateAPI.QueryKnownUsersRequest{ + stateReq := &api.QueryKnownUsersRequest{ UserID: device.UserID, SearchString: searchString, Limit: limit - len(results), } - stateRes := ¤tstateAPI.QueryKnownUsersResponse{} - if err := stateAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil { - errRes := util.ErrorResponse(fmt.Errorf("stateAPI.QueryKnownUsers: %w", err)) + stateRes := &api.QueryKnownUsersResponse{} + if err := rsAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil { + errRes := util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err)) return &errRes } diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go index 35dbb7745..0fdc6679f 100644 --- a/cmd/dendrite-client-api-server/main.go +++ b/cmd/dendrite-client-api-server/main.go @@ -34,12 +34,11 @@ func main() { fsAPI := base.FederationSenderHTTPClient() eduInputAPI := base.EDUServerClient() userAPI := base.UserAPIClient() - stateAPI := base.CurrentStateAPIClient() keyAPI := base.KeyServerHTTPClient() clientapi.AddPublicRoutes( base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, accountDB, federation, - rsAPI, eduInputAPI, asQuery, stateAPI, transactions.New(), fsAPI, userAPI, keyAPI, nil, + rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, userAPI, keyAPI, nil, ) base.SetupAndServeHTTP( diff --git a/cmd/dendrite-current-state-server/main.go b/cmd/dendrite-current-state-server/main.go deleted file mode 100644 index 594bfcf9d..000000000 --- a/cmd/dendrite-current-state-server/main.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "github.com/matrix-org/dendrite/currentstateserver" - "github.com/matrix-org/dendrite/internal/setup" -) - -func main() { - cfg := setup.ParseFlags(false) - base := setup.NewBaseDendrite(cfg, "CurrentStateServer", true) - defer base.Close() // nolint: errcheck - - stateAPI := currentstateserver.NewInternalAPI(&cfg.CurrentStateServer, base.KafkaConsumer) - - currentstateserver.AddInternalRoutes(base.InternalAPIMux, stateAPI) - - base.SetupAndServeHTTP( - base.Cfg.CurrentStateServer.InternalAPI.Listen, - setup.NoExternalListener, - nil, nil, - ) -} diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index d4f0cee04..1f6748865 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -29,7 +29,6 @@ import ( p2pdisc "github.com/libp2p/go-libp2p/p2p/discovery" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/embed" - "github.com/matrix-org/dendrite/currentstateserver" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/federationsender" "github.com/matrix-org/dendrite/internal/config" @@ -129,7 +128,6 @@ func main() { cfg.ServerKeyAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName)) cfg.FederationSender.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) - cfg.CurrentStateServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-currentstate.db", *instanceName)) cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-e2ekey.db", *instanceName)) if err = cfg.Derive(); err != nil { @@ -153,7 +151,6 @@ func main() { base, serverKeyAPI, ) - stateAPI := currentstateserver.NewInternalAPI(&base.Base.Cfg.CurrentStateServer, base.Base.KafkaConsumer) rsAPI := roomserver.NewInternalAPI( &base.Base, keyRing, ) @@ -162,10 +159,10 @@ func main() { ) asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI) fsAPI := federationsender.NewInternalAPI( - &base.Base, federation, rsAPI, stateAPI, keyRing, + &base.Base, federation, rsAPI, keyRing, ) rsAPI.SetFederationSenderAPI(fsAPI) - provider := newPublicRoomsProvider(base.LibP2PPubsub, rsAPI, stateAPI) + provider := newPublicRoomsProvider(base.LibP2PPubsub, rsAPI) err = provider.Start() if err != nil { panic("failed to create new public rooms provider: " + err.Error()) @@ -185,7 +182,6 @@ func main() { FederationSenderAPI: fsAPI, RoomserverAPI: rsAPI, ServerKeyAPI: serverKeyAPI, - StateAPI: stateAPI, UserAPI: userAPI, KeyAPI: keyAPI, ExtPublicRoomsProvider: provider, diff --git a/cmd/dendrite-demo-libp2p/publicrooms.go b/cmd/dendrite-demo-libp2p/publicrooms.go index 2160ddefd..96e8ab5e1 100644 --- a/cmd/dendrite-demo-libp2p/publicrooms.go +++ b/cmd/dendrite-demo-libp2p/publicrooms.go @@ -22,7 +22,6 @@ import ( "sync/atomic" "time" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" pubsub "github.com/libp2p/go-libp2p-pubsub" @@ -46,15 +45,13 @@ type publicRoomsProvider struct { maintenanceTimer *time.Timer // roomsAdvertised atomic.Value // stores int rsAPI roomserverAPI.RoomserverInternalAPI - stateAPI currentstateAPI.CurrentStateInternalAPI } -func newPublicRoomsProvider(ps *pubsub.PubSub, rsAPI roomserverAPI.RoomserverInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI) *publicRoomsProvider { +func newPublicRoomsProvider(ps *pubsub.PubSub, rsAPI roomserverAPI.RoomserverInternalAPI) *publicRoomsProvider { return &publicRoomsProvider{ foundRooms: make(map[string]discoveredRoom), pubsub: ps, rsAPI: rsAPI, - stateAPI: stateAPI, } } @@ -106,7 +103,7 @@ func (p *publicRoomsProvider) AdvertiseRooms() error { util.GetLogger(ctx).WithError(err).Error("QueryPublishedRooms failed") return err } - ourRooms, err := currentstateAPI.PopulatePublicRooms(ctx, queryRes.RoomIDs, p.stateAPI) + ourRooms, err := roomserverAPI.PopulatePublicRooms(ctx, queryRes.RoomIDs, p.rsAPI) if err != nil { util.GetLogger(ctx).WithError(err).Error("PopulatePublicRooms failed") return err diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index fcf3d4c56..7a370bda5 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -29,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/yggconn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/yggrooms" - "github.com/matrix-org/dendrite/currentstateserver" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationsender" @@ -84,7 +83,6 @@ func main() { cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName)) cfg.FederationSender.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName)) cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) - cfg.CurrentStateServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-currentstate.db", *instanceName)) cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) if err = cfg.Derive(); err != nil { panic(err) @@ -113,9 +111,8 @@ func main() { ) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, stateAPI, keyRing, + base, federation, rsAPI, keyRing, ) ygg.SetSessionFunc(func(address string) { @@ -146,7 +143,6 @@ func main() { FederationSenderAPI: fsAPI, RoomserverAPI: rsAPI, UserAPI: userAPI, - StateAPI: stateAPI, KeyAPI: keyAPI, ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider( ygg, fsAPI, federation, diff --git a/cmd/dendrite-federation-api-server/main.go b/cmd/dendrite-federation-api-server/main.go index 4181ee0c2..cab304e6b 100644 --- a/cmd/dendrite-federation-api-server/main.go +++ b/cmd/dendrite-federation-api-server/main.go @@ -35,7 +35,7 @@ func main() { federationapi.AddPublicRoutes( base.PublicFederationAPIMux, base.PublicKeyAPIMux, &base.Cfg.FederationAPI, userAPI, federation, keyRing, - rsAPI, fsAPI, base.EDUServerClient(), base.CurrentStateAPIClient(), keyAPI, + rsAPI, fsAPI, base.EDUServerClient(), keyAPI, ) base.SetupAndServeHTTP( diff --git a/cmd/dendrite-federation-sender-server/main.go b/cmd/dendrite-federation-sender-server/main.go index 369060196..4d918f6b1 100644 --- a/cmd/dendrite-federation-sender-server/main.go +++ b/cmd/dendrite-federation-sender-server/main.go @@ -31,7 +31,7 @@ func main() { rsAPI := base.RoomserverHTTPClient() fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, base.CurrentStateAPIClient(), keyRing, + base, federation, rsAPI, keyRing, ) federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 717b21a9f..759f1c9ff 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -19,7 +19,6 @@ import ( "os" "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/currentstateserver" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationsender" @@ -54,7 +53,6 @@ func main() { // itself. cfg.AppServiceAPI.InternalAPI.Connect = httpAddr cfg.ClientAPI.InternalAPI.Connect = httpAddr - cfg.CurrentStateServer.InternalAPI.Connect = httpAddr cfg.EDUServer.InternalAPI.Connect = httpAddr cfg.FederationAPI.InternalAPI.Connect = httpAddr cfg.FederationSender.InternalAPI.Connect = httpAddr @@ -95,10 +93,8 @@ func main() { } } - stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) - fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, stateAPI, keyRing, + base, federation, rsAPI, keyRing, ) if base.UseHTTPAPIs { federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) @@ -140,7 +136,6 @@ func main() { FederationSenderAPI: fsAPI, RoomserverAPI: rsAPI, ServerKeyAPI: serverKeyAPI, - StateAPI: stateAPI, UserAPI: userAPI, KeyAPI: keyAPI, } diff --git a/cmd/dendrite-sync-api-server/main.go b/cmd/dendrite-sync-api-server/main.go index 8a73cd371..b879f842f 100644 --- a/cmd/dendrite-sync-api-server/main.go +++ b/cmd/dendrite-sync-api-server/main.go @@ -31,7 +31,7 @@ func main() { syncapi.AddPublicRoutes( base.PublicClientAPIMux, base.KafkaConsumer, userAPI, rsAPI, - base.KeyServerHTTPClient(), base.CurrentStateAPIClient(), + base.KeyServerHTTPClient(), federation, &cfg.SyncAPI, ) diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index aeca70946..12dc2d7cc 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -23,7 +23,6 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/currentstateserver" "github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/federationsender" @@ -171,7 +170,6 @@ func main() { cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" cfg.ServerKeyAPI.Database.ConnectionString = "file:/idb/dendritejs_serverkey.db" cfg.SyncAPI.Database.ConnectionString = "file:/idb/dendritejs_syncapi.db" - cfg.CurrentStateServer.Database.ConnectionString = "file:/idb/dendritejs_currentstate.db" cfg.KeyServer.Database.ConnectionString = "file:/idb/dendritejs_e2ekey.db" cfg.Global.Kafka.UseNaffka = true cfg.Global.Kafka.Database.ConnectionString = "file:/idb/dendritejs_naffka.db" @@ -204,13 +202,12 @@ func main() { KeyDatabase: fetcher, } - stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) rsAPI := roomserver.NewInternalAPI(base, keyRing) eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, ) - fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, stateAPI, &keyRing) + fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, &keyRing) rsAPI.SetFederationSenderAPI(fedSenderAPI) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) @@ -227,7 +224,6 @@ func main() { EDUInternalAPI: eduInputAPI, FederationSenderAPI: fedSenderAPI, RoomserverAPI: rsAPI, - StateAPI: stateAPI, UserAPI: userAPI, KeyAPI: keyAPI, //ServerKeyAPI: serverKeyAPI, diff --git a/cmd/goose/README.md b/cmd/goose/README.md new file mode 100644 index 000000000..c7f085d80 --- /dev/null +++ b/cmd/goose/README.md @@ -0,0 +1,107 @@ +## Database migrations + +We use [goose](https://github.com/pressly/goose) to handle database migrations. This allows us to execute +both SQL deltas (e.g `ALTER TABLE ...`) as well as manipulate data in the database in Go using Go functions. + +To run a migration, the `goose` binary in this directory needs to be built: +``` +$ go build ./cmd/goose +``` + +This binary allows Dendrite databases to be upgraded and downgraded. Sample usage for upgrading the roomserver database: + +``` +# for sqlite +$ ./goose -dir roomserver/storage/sqlite3/deltas sqlite3 ./roomserver.db up + +# for postgres +$ ./goose -dir roomserver/storage/postgres/deltas postgres "user=dendrite dbname=dendrite sslmode=disable" up +``` + +For a full list of options, including rollbacks, see https://github.com/pressly/goose or use `goose` with no args. + + +### Rationale + +Dendrite creates tables on startup using `CREATE TABLE IF NOT EXISTS`, so you might think that we should also +apply version upgrades on startup as well. This is convenient and doesn't involve an additional binary to run +which complicates upgrades. However, combining the upgrade mechanism and the server binary makes it difficult +to handle rollbacks. Firstly, how do you specify you wish to rollback? We would have to add additional flags +to the main server binary to say "rollback to version X". Secondly, if you roll back the server binary from +version 5 to version 4, the version 4 binary doesn't know how to rollback the database from version 5 to +version 4! For these reasons, we prefer to have a separate "upgrade" binary which is run for database upgrades. +Rather than roll-our-own migration tool, we decided to use [goose](https://github.com/pressly/goose) as it supports +complex migrations in Go code in addition to just executing SQL deltas. Other alternatives like +`github.com/golang-migrate/migrate` [do not support](https://github.com/golang-migrate/migrate/issues/15) these +kinds of complex migrations. + +### Adding new deltas + +You can add `.sql` or `.go` files manually or you can use goose to create them for you. + +If you only want to add a SQL delta then run: + +``` +$ ./goose -dir serverkeyapi/storage/sqlite3/deltas sqlite3 ./foo.db create new_col sql +2020/09/09 14:37:43 Created new file: serverkeyapi/storage/sqlite3/deltas/20200909143743_new_col.sql +``` + +In this case, the version number is `20200909143743`. The important thing is that it is always increasing. + +Then add up/downgrade SQL commands to the created file which looks like: +```sql +-- +goose Up +-- +goose StatementBegin +SELECT 'up SQL query'; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +SELECT 'down SQL query'; +-- +goose StatementEnd + +``` +You __must__ keep the `+goose` annotations. You'll need to repeat this process for Postgres. + +For complex Go migrations: + +``` +$ ./goose -dir serverkeyapi/storage/sqlite3/deltas sqlite3 ./foo.db create complex_update go +2020/09/09 14:40:38 Created new file: serverkeyapi/storage/sqlite3/deltas/20200909144038_complex_update.go +``` + +Then modify the created `.go` file which looks like: + +```go +package migrations + +import ( + "database/sql" + "fmt" + + "github.com/pressly/goose" +) + +func init() { + goose.AddMigration(upComplexUpdate, downComplexUpdate) +} + +func upComplexUpdate(tx *sql.Tx) error { + // This code is executed when the migration is applied. + return nil +} + +func downComplexUpdate(tx *sql.Tx) error { + // This code is executed when the migration is rolled back. + return nil +} + +``` + +You __must__ import the package in `/cmd/goose/main.go` so `func init()` gets called. + + +#### Database limitations + +- SQLite3 does NOT support `ALTER TABLE table_name DROP COLUMN` - you would have to rename the column or drop the table + entirely and recreate it. diff --git a/cmd/goose/main.go b/cmd/goose/main.go new file mode 100644 index 000000000..ef3942d90 --- /dev/null +++ b/cmd/goose/main.go @@ -0,0 +1,98 @@ +// This is custom goose binary + +package main + +import ( + "flag" + "fmt" + "log" + "os" + + // Example complex Go migration import: + // _ "github.com/matrix-org/dendrite/serverkeyapi/storage/postgres/deltas" + "github.com/pressly/goose" + + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" +) + +var ( + flags = flag.NewFlagSet("goose", flag.ExitOnError) + dir = flags.String("dir", ".", "directory with migration files") +) + +func main() { + err := flags.Parse(os.Args[1:]) + if err != nil { + panic(err.Error()) + } + args := flags.Args() + + if len(args) < 3 { + fmt.Println( + `Usage: goose [OPTIONS] DRIVER DBSTRING COMMAND + +Drivers: + postgres + sqlite3 + +Examples: + goose -d roomserver/storage/sqlite3/deltas sqlite3 ./roomserver.db status + goose -d roomserver/storage/sqlite3/deltas sqlite3 ./roomserver.db up + + goose -d roomserver/storage/postgres/deltas postgres "user=dendrite dbname=dendrite sslmode=disable" status + +Options: + + -dir string + directory with migration files (default ".") + -table string + migrations table name (default "goose_db_version") + -h print help + -v enable verbose mode + -version + print version + +Commands: + up Migrate the DB to the most recent version available + up-by-one Migrate the DB up by 1 + up-to VERSION Migrate the DB to a specific VERSION + down Roll back the version by 1 + down-to VERSION Roll back to a specific VERSION + redo Re-run the latest migration + reset Roll back all migrations + status Dump the migration status for the current DB + version Print the current version of the database + create NAME [sql|go] Creates new migration file with the current timestamp + fix Apply sequential ordering to migrations`, + ) + return + } + + engine := args[0] + if engine != "sqlite3" && engine != "postgres" { + fmt.Println("engine must be one of 'sqlite3' or 'postgres'") + return + } + dbstring, command := args[1], args[2] + + db, err := goose.OpenDBWithDriver(engine, dbstring) + if err != nil { + log.Fatalf("goose: failed to open DB: %v\n", err) + } + + defer func() { + if err := db.Close(); err != nil { + log.Fatalf("goose: failed to close DB: %v\n", err) + } + }() + + arguments := []string{} + if len(args) > 3 { + arguments = append(arguments, args[3:]...) + } + + if err := goose.Run(command, db, *dir, arguments...); err != nil { + log.Fatalf("goose %v: %v", command, err) + } +} diff --git a/currentstateserver/acls/acls.go b/currentstateserver/acls/acls.go deleted file mode 100644 index 775b6c73a..000000000 --- a/currentstateserver/acls/acls.go +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package acls - -import ( - "context" - "encoding/json" - "fmt" - "net" - "regexp" - "strings" - "sync" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" -) - -type ServerACLDatabase interface { - // GetKnownRooms returns a list of all rooms we know about. - GetKnownRooms(ctx context.Context) ([]string, error) - // GetStateEvent returns the state event of a given type for a given room with a given state key - // If no event could be found, returns nil - // If there was an issue during the retrieval, returns an error - GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) -} - -type ServerACLs struct { - acls map[string]*serverACL // room ID -> ACL - aclsMutex sync.RWMutex // protects the above -} - -func NewServerACLs(db ServerACLDatabase) *ServerACLs { - ctx := context.TODO() - acls := &ServerACLs{ - acls: make(map[string]*serverACL), - } - // Look up all of the rooms that the current state server knows about. - rooms, err := db.GetKnownRooms(ctx) - if err != nil { - logrus.WithError(err).Fatalf("Failed to get known rooms") - } - // For each room, let's see if we have a server ACL state event. If we - // do then we'll process it into memory so that we have the regexes to - // hand. - for _, room := range rooms { - state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "") - if err != nil { - logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room) - continue - } - if state != nil { - acls.OnServerACLUpdate(&state.Event) - } - } - return acls -} - -type ServerACL struct { - Allowed []string `json:"allow"` - Denied []string `json:"deny"` - AllowIPLiterals bool `json:"allow_ip_literals"` -} - -type serverACL struct { - ServerACL - allowedRegexes []*regexp.Regexp - deniedRegexes []*regexp.Regexp -} - -func compileACLRegex(orig string) (*regexp.Regexp, error) { - escaped := regexp.QuoteMeta(orig) - escaped = strings.Replace(escaped, "\\?", ".", -1) - escaped = strings.Replace(escaped, "\\*", ".*", -1) - return regexp.Compile(escaped) -} - -func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) { - acls := &serverACL{} - if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil { - logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs") - return - } - // The spec calls only for * (zero or more chars) and ? (exactly one char) - // to be supported as wildcard components, so we will escape all of the regex - // special characters and then replace * and ? with their regex counterparts. - // https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl - for _, orig := range acls.Allowed { - if expr, err := compileACLRegex(orig); err != nil { - logrus.WithError(err).Errorf("Failed to compile allowed regex") - } else { - acls.allowedRegexes = append(acls.allowedRegexes, expr) - } - } - for _, orig := range acls.Denied { - if expr, err := compileACLRegex(orig); err != nil { - logrus.WithError(err).Errorf("Failed to compile denied regex") - } else { - acls.deniedRegexes = append(acls.deniedRegexes, expr) - } - } - logrus.WithFields(logrus.Fields{ - "allow_ip_literals": acls.AllowIPLiterals, - "num_allowed": len(acls.allowedRegexes), - "num_denied": len(acls.deniedRegexes), - }).Debugf("Updating server ACLs for %q", state.RoomID()) - s.aclsMutex.Lock() - defer s.aclsMutex.Unlock() - s.acls[state.RoomID()] = acls -} - -func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool { - s.aclsMutex.RLock() - // First of all check if we have an ACL for this room. If we don't then - // no servers are banned from the room. - acls, ok := s.acls[roomID] - if !ok { - s.aclsMutex.RUnlock() - return false - } - s.aclsMutex.RUnlock() - // Split the host and port apart. This is because the spec calls on us to - // validate the hostname only in cases where the port is also present. - if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil { - serverName = gomatrixserverlib.ServerName(serverNameOnly) - } - // Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding - // a /0 prefix length just to trick ParseCIDR into working. If we find that - // the server is an IP literal and we don't allow those then stop straight - // away. - if _, _, err := net.ParseCIDR(fmt.Sprintf("%s/0", serverName)); err == nil { - if !acls.AllowIPLiterals { - return true - } - } - // Check if the hostname matches one of the denied regexes. If it does then - // the server is banned from the room. - for _, expr := range acls.deniedRegexes { - if expr.MatchString(string(serverName)) { - return true - } - } - // Check if the hostname matches one of the allowed regexes. If it does then - // the server is NOT banned from the room. - for _, expr := range acls.allowedRegexes { - if expr.MatchString(string(serverName)) { - return false - } - } - // If we've got to this point then we haven't matched any regexes or an IP - // hostname if disallowed. The spec calls for default-deny here. - return true -} diff --git a/currentstateserver/acls/acls_test.go b/currentstateserver/acls/acls_test.go deleted file mode 100644 index 9fb6a5581..000000000 --- a/currentstateserver/acls/acls_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package acls - -import ( - "regexp" - "testing" -) - -func TestOpenACLsWithBlacklist(t *testing.T) { - roomID := "!test:test.com" - allowRegex, err := compileACLRegex("*") - if err != nil { - t.Fatalf(err.Error()) - } - denyRegex, err := compileACLRegex("foo.com") - if err != nil { - t.Fatalf(err.Error()) - } - - acls := ServerACLs{ - acls: make(map[string]*serverACL), - } - - acls.acls[roomID] = &serverACL{ - ServerACL: ServerACL{ - AllowIPLiterals: true, - }, - allowedRegexes: []*regexp.Regexp{allowRegex}, - deniedRegexes: []*regexp.Regexp{denyRegex}, - } - - if acls.IsServerBannedFromRoom("1.2.3.4", roomID) { - t.Fatal("Expected 1.2.3.4 to be allowed but wasn't") - } - if acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) { - t.Fatal("Expected 1.2.3.4:2345 to be allowed but wasn't") - } - if !acls.IsServerBannedFromRoom("foo.com", roomID) { - t.Fatal("Expected foo.com to be banned but wasn't") - } - if !acls.IsServerBannedFromRoom("foo.com:3456", roomID) { - t.Fatal("Expected foo.com:3456 to be banned but wasn't") - } - if acls.IsServerBannedFromRoom("bar.com", roomID) { - t.Fatal("Expected bar.com to be allowed but wasn't") - } - if acls.IsServerBannedFromRoom("bar.com:4567", roomID) { - t.Fatal("Expected bar.com:4567 to be allowed but wasn't") - } -} - -func TestDefaultACLsWithWhitelist(t *testing.T) { - roomID := "!test:test.com" - allowRegex, err := compileACLRegex("foo.com") - if err != nil { - t.Fatalf(err.Error()) - } - - acls := ServerACLs{ - acls: make(map[string]*serverACL), - } - - acls.acls[roomID] = &serverACL{ - ServerACL: ServerACL{ - AllowIPLiterals: false, - }, - allowedRegexes: []*regexp.Regexp{allowRegex}, - deniedRegexes: []*regexp.Regexp{}, - } - - if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) { - t.Fatal("Expected 1.2.3.4 to be banned but wasn't") - } - if !acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) { - t.Fatal("Expected 1.2.3.4:2345 to be banned but wasn't") - } - if acls.IsServerBannedFromRoom("foo.com", roomID) { - t.Fatal("Expected foo.com to be allowed but wasn't") - } - if acls.IsServerBannedFromRoom("foo.com:3456", roomID) { - t.Fatal("Expected foo.com:3456 to be allowed but wasn't") - } - if !acls.IsServerBannedFromRoom("bar.com", roomID) { - t.Fatal("Expected bar.com to be allowed but wasn't") - } - if !acls.IsServerBannedFromRoom("baz.com", roomID) { - t.Fatal("Expected baz.com to be allowed but wasn't") - } - if !acls.IsServerBannedFromRoom("qux.com:4567", roomID) { - t.Fatal("Expected qux.com:4567 to be allowed but wasn't") - } -} diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go deleted file mode 100644 index 5ae57bb9a..000000000 --- a/currentstateserver/api/api.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "context" - "encoding/json" - "fmt" - "strings" - - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/gomatrixserverlib" -) - -type CurrentStateInternalAPI interface { - // QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from - // the response. - QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error - // QueryRoomsForUser retrieves a list of room IDs matching the given query. - QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error - // QueryBulkStateContent does a bulk query for state event content in the given rooms. - QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error - // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. - QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error - // QueryKnownUsers returns a list of users that we know about from our joined rooms. - QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error - // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. - QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error -} - -type QuerySharedUsersRequest struct { - UserID string - ExcludeRoomIDs []string - IncludeRoomIDs []string -} - -type QuerySharedUsersResponse struct { - UserIDsToCount map[string]int -} - -type QueryRoomsForUserRequest struct { - UserID string - // The desired membership of the user. If this is the empty string then no rooms are returned. - WantMembership string -} - -type QueryRoomsForUserResponse struct { - RoomIDs []string -} - -type QueryBulkStateContentRequest struct { - // Returns state events in these rooms - RoomIDs []string - // If true, treats the '*' StateKey as "all state events of this type" rather than a literal value of '*' - AllowWildcards bool - // The state events to return. Only a small subset of tuples are allowed in this request as only certain events - // have their content fields extracted. Specifically, the tuple Type must be one of: - // m.room.avatar - // m.room.create - // m.room.canonical_alias - // m.room.guest_access - // m.room.history_visibility - // m.room.join_rules - // m.room.member - // m.room.name - // m.room.topic - // Any other tuple type will result in the query failing. - StateTuples []gomatrixserverlib.StateKeyTuple -} -type QueryBulkStateContentResponse struct { - // map of room ID -> tuple -> content_value - Rooms map[string]map[gomatrixserverlib.StateKeyTuple]string -} - -type QueryCurrentStateRequest struct { - RoomID string - StateTuples []gomatrixserverlib.StateKeyTuple -} - -type QueryCurrentStateResponse struct { - StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent -} - -type QueryKnownUsersRequest struct { - UserID string `json:"user_id"` - SearchString string `json:"search_string"` - Limit int `json:"limit"` -} - -type QueryKnownUsersResponse struct { - Users []authtypes.FullyQualifiedProfile `json:"profiles"` -} - -type QueryServerBannedFromRoomRequest struct { - ServerName gomatrixserverlib.ServerName `json:"server_name"` - RoomID string `json:"room_id"` -} - -type QueryServerBannedFromRoomResponse struct { - Banned bool `json:"banned"` -} - -// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. -func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { - se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) - for k, v := range r.StateEvents { - // use 0x1F (unit separator) as the delimiter between type/state key, - se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v - } - return json.Marshal(se) -} - -func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { - res := make(map[string]*gomatrixserverlib.HeaderedEvent) - err := json.Unmarshal(data, &res) - if err != nil { - return err - } - r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res)) - for k, v := range res { - fields := strings.Split(k, "\x1F") - r.StateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: fields[0], - StateKey: fields[1], - }] = v - } - return nil -} diff --git a/currentstateserver/api/wrapper.go b/currentstateserver/api/wrapper.go deleted file mode 100644 index 317fea431..000000000 --- a/currentstateserver/api/wrapper.go +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "context" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" -) - -// GetEvent returns the current state event in the room or nil. -func GetEvent(ctx context.Context, stateAPI CurrentStateInternalAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent { - var res QueryCurrentStateResponse - err := stateAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{ - RoomID: roomID, - StateTuples: []gomatrixserverlib.StateKeyTuple{tuple}, - }, &res) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to QueryCurrentState") - return nil - } - ev, ok := res.StateEvents[tuple] - if ok { - return ev - } - return nil -} - -// IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs. -func IsServerBannedFromRoom(ctx context.Context, stateAPI CurrentStateInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool { - req := &QueryServerBannedFromRoomRequest{ - ServerName: serverName, - RoomID: roomID, - } - res := &QueryServerBannedFromRoomResponse{} - if err := stateAPI.QueryServerBannedFromRoom(ctx, req, res); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to QueryServerBannedFromRoom") - return true - } - return res.Banned -} - -// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the -// published room directory. -// due to lots of switches -// nolint:gocyclo -func PopulatePublicRooms(ctx context.Context, roomIDs []string, stateAPI CurrentStateInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { - avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} - nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} - canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} - topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} - guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} - visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} - joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} - - var stateRes QueryBulkStateContentResponse - err := stateAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{ - RoomIDs: roomIDs, - AllowWildcards: true, - StateTuples: []gomatrixserverlib.StateKeyTuple{ - nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple, - {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"}, - }, - }, &stateRes) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed") - return nil, err - } - chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs)) - i := 0 - for roomID, data := range stateRes.Rooms { - pub := gomatrixserverlib.PublicRoom{ - RoomID: roomID, - } - joinCount := 0 - var joinRule, guestAccess string - for tuple, contentVal := range data { - if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" { - joinCount++ - continue - } - switch tuple { - case avatarTuple: - pub.AvatarURL = contentVal - case nameTuple: - pub.Name = contentVal - case topicTuple: - pub.Topic = contentVal - case canonicalTuple: - pub.CanonicalAlias = contentVal - case visibilityTuple: - pub.WorldReadable = contentVal == "world_readable" - // need both of these to determine whether guests can join - case joinRuleTuple: - joinRule = contentVal - case guestTuple: - guestAccess = contentVal - } - } - if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" { - pub.GuestCanJoin = true - } - pub.JoinedMembersCount = joinCount - chunk[i] = pub - i++ - } - return chunk, nil -} diff --git a/currentstateserver/consumers/roomserver.go b/currentstateserver/consumers/roomserver.go deleted file mode 100644 index cb054481c..000000000 --- a/currentstateserver/consumers/roomserver.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package consumers - -import ( - "context" - "encoding/json" - - "github.com/Shopify/sarama" - "github.com/matrix-org/dendrite/currentstateserver/acls" - "github.com/matrix-org/dendrite/currentstateserver/storage" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" -) - -type OutputRoomEventConsumer struct { - rsConsumer *internal.ContinualConsumer - db storage.Database - acls *acls.ServerACLs -} - -func NewOutputRoomEventConsumer(topicName string, kafkaConsumer sarama.Consumer, store storage.Database, acls *acls.ServerACLs) *OutputRoomEventConsumer { - consumer := &internal.ContinualConsumer{ - ComponentName: "currentstateserver/roomserver", - Topic: topicName, - Consumer: kafkaConsumer, - PartitionStore: store, - } - s := &OutputRoomEventConsumer{ - rsConsumer: consumer, - db: store, - acls: acls, - } - consumer.ProcessMessage = s.onMessage - - return s -} - -func (c *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Value, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return nil - } - - switch output.Type { - case api.OutputTypeNewRoomEvent: - return c.onNewRoomEvent(context.TODO(), *output.NewRoomEvent) - case api.OutputTypeNewInviteEvent: - case api.OutputTypeRetireInviteEvent: - case api.OutputTypeRedactedEvent: - return c.onRedactEvent(context.Background(), *output.RedactedEvent) - default: - log.WithField("type", output.Type).Debug( - "roomserver output log: ignoring unknown output type", - ) - } - return nil -} - -func (c *OutputRoomEventConsumer) onNewRoomEvent( - ctx context.Context, msg api.OutputNewRoomEvent, -) error { - ev := msg.Event - - if ev.Type() == "m.room.server_acl" && ev.StateKeyEquals("") { - defer c.acls.OnServerACLUpdate(&ev.Event) - } - - addsStateEvents := msg.AddsState() - - ev, err := c.updateStateEvent(ev) - if err != nil { - return err - } - - for i := range addsStateEvents { - addsStateEvents[i], err = c.updateStateEvent(addsStateEvents[i]) - if err != nil { - return err - } - } - - err = c.db.StoreStateEvents( - ctx, - addsStateEvents, - msg.RemovesStateEventIDs, - ) - if err != nil { - // panic rather than continue with an inconsistent database - log.WithFields(log.Fields{ - "event": string(ev.JSON()), - log.ErrorKey: err, - "add": msg.AddsStateEventIDs, - "del": msg.RemovesStateEventIDs, - }).Panicf("roomserver output log: write event failure") - } - return nil -} - -func (c *OutputRoomEventConsumer) onRedactEvent( - ctx context.Context, msg api.OutputRedactedEvent, -) error { - return c.db.RedactEvent(ctx, msg.RedactedEventID, msg.RedactedBecause) -} - -// Start consuming from room servers -func (c *OutputRoomEventConsumer) Start() error { - return c.rsConsumer.Start() -} - -func (c *OutputRoomEventConsumer) updateStateEvent(event gomatrixserverlib.HeaderedEvent) (gomatrixserverlib.HeaderedEvent, error) { - stateKey := "" - if event.StateKey() != nil { - stateKey = *event.StateKey() - } - - prevEvent, err := c.db.GetStateEvent( - context.TODO(), event.RoomID(), event.Type(), stateKey, - ) - if err != nil { - return event, err - } - - if prevEvent == nil { - return event, nil - } - - prev := types.PrevEventRef{ - PrevContent: prevEvent.Content(), - ReplacesState: prevEvent.EventID(), - PrevSender: prevEvent.Sender(), - } - - event.Event, err = event.SetUnsigned(prev) - return event, err -} diff --git a/currentstateserver/currentstateserver.go b/currentstateserver/currentstateserver.go deleted file mode 100644 index 196434eb8..000000000 --- a/currentstateserver/currentstateserver.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package currentstateserver - -import ( - "github.com/Shopify/sarama" - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/currentstateserver/acls" - "github.com/matrix-org/dendrite/currentstateserver/api" - "github.com/matrix-org/dendrite/currentstateserver/consumers" - "github.com/matrix-org/dendrite/currentstateserver/internal" - "github.com/matrix-org/dendrite/currentstateserver/inthttp" - "github.com/matrix-org/dendrite/currentstateserver/storage" - "github.com/matrix-org/dendrite/internal/config" - "github.com/sirupsen/logrus" -) - -// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions -// on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.CurrentStateInternalAPI) { - inthttp.AddRoutes(router, intAPI) -} - -// NewInternalAPI returns a concrete implementation of the internal API. Callers -// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI(cfg *config.CurrentStateServer, consumer sarama.Consumer) api.CurrentStateInternalAPI { - csDB, err := storage.NewDatabase(&cfg.Database) - if err != nil { - logrus.WithError(err).Panicf("failed to open database") - } - serverACLs := acls.NewServerACLs(csDB) - roomConsumer := consumers.NewOutputRoomEventConsumer( - cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent), consumer, csDB, serverACLs, - ) - if err = roomConsumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start room server consumer") - } - return &internal.CurrentStateInternalAPI{ - DB: csDB, - ServerACLs: serverACLs, - } -} diff --git a/currentstateserver/currentstateserver_test.go b/currentstateserver/currentstateserver_test.go deleted file mode 100644 index b83103f13..000000000 --- a/currentstateserver/currentstateserver_test.go +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package currentstateserver - -import ( - "bytes" - "context" - "crypto/ed25519" - "encoding/json" - "fmt" - "net/http" - "os" - "reflect" - "testing" - "time" - - "github.com/Shopify/sarama" - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/currentstateserver/api" - "github.com/matrix-org/dendrite/currentstateserver/internal" - "github.com/matrix-org/dendrite/currentstateserver/inthttp" - "github.com/matrix-org/dendrite/currentstateserver/storage" - "github.com/matrix-org/dendrite/internal/config" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/internal/test" - roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/naffka" - naffkaStorage "github.com/matrix-org/naffka/storage" -) - -var ( - testRoomVersion = gomatrixserverlib.RoomVersionV1 - testData = []json.RawMessage{ - []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), - // messages - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), - } - testEvents = []gomatrixserverlib.HeaderedEvent{} - testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]gomatrixserverlib.HeaderedEvent) - - kafkaPrefix = "Dendrite" - kafkaTopic = fmt.Sprintf("%s%s", kafkaPrefix, "OutputRoomEvent") -) - -func init() { - for _, j := range testData { - e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) - if err != nil { - panic("cannot load test data: " + err.Error()) - } - h := e.Headered(testRoomVersion) - testEvents = append(testEvents, h) - if e.StateKey() != nil { - testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: e.Type(), - StateKey: *e.StateKey(), - }] = h - } - } -} - -func waitForOffsetProcessed(t *testing.T, db storage.Database, offset int64) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - for { - poffsets, err := db.PartitionOffsets(ctx, kafkaTopic) - if err != nil { - t.Fatalf("failed to PartitionOffsets: %s", err) - } - for _, partition := range poffsets { - if partition.Offset >= offset { - return - } - } - time.Sleep(50 * time.Millisecond) - } -} - -func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *roomserverAPI.OutputNewRoomEvent) int64 { - value, err := json.Marshal(roomserverAPI.OutputEvent{ - Type: roomserverAPI.OutputTypeNewRoomEvent, - NewRoomEvent: out, - }) - if err != nil { - t.Fatalf("failed to marshal output event: %s", err) - } - _, offset, err := producer.SendMessage(&sarama.ProducerMessage{ - Topic: kafkaTopic, - Key: sarama.StringEncoder(out.Event.RoomID()), - Value: sarama.ByteEncoder(value), - }) - if err != nil { - t.Fatalf("failed to send message: %s", err) - } - return offset -} - -func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, storage.Database, sarama.SyncProducer, func()) { - cfg := &config.Dendrite{} - cfg.Defaults() - stateDBName := "test_state.db" - naffkaDBName := "test_naffka.db" - cfg.Global.ServerName = "kaer.morhen" - cfg.CurrentStateServer.Database.ConnectionString = config.DataSource("file:" + stateDBName) - cfg.Global.Kafka.TopicPrefix = kafkaPrefix - naffkaDB, err := naffkaStorage.NewDatabase("file:" + naffkaDBName) - if err != nil { - t.Fatalf("Failed to setup naffka database: %s", err) - } - naff, err := naffka.New(naffkaDB) - if err != nil { - t.Fatalf("Failed to create naffka consumer: %s", err) - } - stateAPI := NewInternalAPI(&cfg.CurrentStateServer, naff) - // type-cast to pull out the DB - stateAPIVal := stateAPI.(*internal.CurrentStateInternalAPI) - return stateAPI, stateAPIVal.DB, naff, func() { - os.Remove(naffkaDBName) - os.Remove(stateDBName) - } -} - -func TestQueryCurrentState(t *testing.T) { - currStateAPI, db, producer, cancel := MustMakeInternalAPI(t) - defer cancel() - plTuple := gomatrixserverlib.StateKeyTuple{ - EventType: "m.room.power_levels", - StateKey: "", - } - plEvent := testEvents[4] - offset := MustWriteOutputEvent(t, producer, &roomserverAPI.OutputNewRoomEvent{ - Event: plEvent, - AddsStateEventIDs: []string{plEvent.EventID()}, - }) - waitForOffsetProcessed(t, db, offset) - - testCases := []struct { - req api.QueryCurrentStateRequest - wantRes api.QueryCurrentStateResponse - wantErr error - }{ - { - req: api.QueryCurrentStateRequest{ - RoomID: plEvent.RoomID(), - StateTuples: []gomatrixserverlib.StateKeyTuple{ - plTuple, - }, - }, - wantRes: api.QueryCurrentStateResponse{ - StateEvents: map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent{ - plTuple: &plEvent, - }, - }, - }, - } - - runCases := func(testAPI api.CurrentStateInternalAPI) { - for _, tc := range testCases { - var gotRes api.QueryCurrentStateResponse - gotErr := testAPI.QueryCurrentState(context.TODO(), &tc.req, &gotRes) - if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil { - t.Errorf("QueryCurrentState error, got %s want %s", gotErr, tc.wantErr) - continue - } - for tuple, wantEvent := range tc.wantRes.StateEvents { - gotEvent, ok := gotRes.StateEvents[tuple] - if !ok { - t.Errorf("QueryCurrentState want tuple %+v but it is missing from the response", tuple) - continue - } - gotCanon, err := gomatrixserverlib.CanonicalJSON(gotEvent.JSON()) - if err != nil { - t.Errorf("CanonicalJSON failed: %w", err) - continue - } - if !bytes.Equal(gotCanon, wantEvent.JSON()) { - t.Errorf("QueryCurrentState tuple %+v got event JSON %s want %s", tuple, string(gotCanon), string(wantEvent.JSON())) - } - } - } - } - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - AddInternalRoutes(router, currStateAPI) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() - httpAPI, err := inthttp.NewCurrentStateAPIClient(apiURL, &http.Client{}) - if err != nil { - t.Fatalf("failed to create HTTP client") - } - runCases(httpAPI) - }) - t.Run("Monolith", func(t *testing.T) { - runCases(currStateAPI) - }) -} - -func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *roomserverAPI.OutputNewRoomEvent { - eb := gomatrixserverlib.EventBuilder{ - RoomID: roomID, - Sender: userID, - StateKey: &userID, - Type: "m.room.member", - Content: []byte(`{"membership":"` + membership + `"}`), - } - _, pkey, err := ed25519.GenerateKey(nil) - if err != nil { - t.Fatalf("failed to make ed25519 key: %s", err) - } - roomVer := gomatrixserverlib.RoomVersionV5 - ev, err := eb.Build( - time.Now(), gomatrixserverlib.ServerName("localhost"), gomatrixserverlib.KeyID("ed25519:test"), - pkey, roomVer, - ) - if err != nil { - t.Fatalf("mustMakeMembershipEvent failed: %s", err) - } - - return &roomserverAPI.OutputNewRoomEvent{ - Event: ev.Headered(roomVer), - AddsStateEventIDs: []string{ev.EventID()}, - } -} - -// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets. -func TestQuerySharedUsers(t *testing.T) { - currStateAPI, db, producer, cancel := MustMakeInternalAPI(t) - defer cancel() - MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join")) - MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join")) - - MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@alice:localhost", "join")) - MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@charlie:localhost", "join")) - - MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@alice:localhost", "join")) - MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join")) - MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave")) - - offset := MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join")) - waitForOffsetProcessed(t, db, offset) - - testCases := []struct { - req api.QuerySharedUsersRequest - wantRes api.QuerySharedUsersResponse - }{ - // Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A:4,B:2,C:1) - { - req: api.QuerySharedUsersRequest{ - UserID: "@alice:localhost", - }, - wantRes: api.QuerySharedUsersResponse{ - UserIDsToCount: map[string]int{ - "@alice:localhost": 4, - "@bob:localhost": 2, - "@charlie:localhost": 1, - }, - }, - }, - - // Exclude (A,C): sharing (A,B) (A,B) (A) produces (A:3,B:2) - { - req: api.QuerySharedUsersRequest{ - UserID: "@alice:localhost", - ExcludeRoomIDs: []string{"!foo2:bar"}, - }, - wantRes: api.QuerySharedUsersResponse{ - UserIDsToCount: map[string]int{ - "@alice:localhost": 3, - "@bob:localhost": 2, - }, - }, - }, - - // Unknown user has no shared users - { - req: api.QuerySharedUsersRequest{ - UserID: "@unknownuser:localhost", - }, - wantRes: api.QuerySharedUsersResponse{ - UserIDsToCount: map[string]int{}, - }, - }, - - // left real user produces no shared users - { - req: api.QuerySharedUsersRequest{ - UserID: "@dave:localhost", - }, - wantRes: api.QuerySharedUsersResponse{ - UserIDsToCount: map[string]int{}, - }, - }, - - // left real user but with included room returns the included room member - { - req: api.QuerySharedUsersRequest{ - UserID: "@dave:localhost", - IncludeRoomIDs: []string{"!foo:bar"}, - }, - wantRes: api.QuerySharedUsersResponse{ - UserIDsToCount: map[string]int{ - "@alice:localhost": 1, - "@bob:localhost": 1, - }, - }, - }, - - // including a room more than once doesn't double counts - { - req: api.QuerySharedUsersRequest{ - UserID: "@dave:localhost", - IncludeRoomIDs: []string{"!foo:bar", "!foo:bar", "!foo:bar"}, - }, - wantRes: api.QuerySharedUsersResponse{ - UserIDsToCount: map[string]int{ - "@alice:localhost": 1, - "@bob:localhost": 1, - }, - }, - }, - } - - runCases := func(testAPI api.CurrentStateInternalAPI) { - for _, tc := range testCases { - var res api.QuerySharedUsersResponse - err := testAPI.QuerySharedUsers(context.Background(), &tc.req, &res) - if err != nil { - t.Errorf("QuerySharedUsers returned error: %s", err) - continue - } - if !reflect.DeepEqual(res.UserIDsToCount, tc.wantRes.UserIDsToCount) { - t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDsToCount, tc.wantRes.UserIDsToCount) - } - } - } - - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - AddInternalRoutes(router, currStateAPI) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() - httpAPI, err := inthttp.NewCurrentStateAPIClient(apiURL, &http.Client{}) - if err != nil { - t.Fatalf("failed to create HTTP client") - } - runCases(httpAPI) - }) - t.Run("Monolith", func(t *testing.T) { - runCases(currStateAPI) - }) -} diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go deleted file mode 100644 index 0a7e025e7..000000000 --- a/currentstateserver/internal/api.go +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "errors" - - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/currentstateserver/acls" - "github.com/matrix-org/dendrite/currentstateserver/api" - "github.com/matrix-org/dendrite/currentstateserver/storage" - "github.com/matrix-org/gomatrixserverlib" -) - -type CurrentStateInternalAPI struct { - DB storage.Database - ServerACLs *acls.ServerACLs -} - -func (a *CurrentStateInternalAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { - res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) - for _, tuple := range req.StateTuples { - ev, err := a.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) - if err != nil { - return err - } - if ev != nil { - res.StateEvents[tuple] = ev - } - } - return nil -} - -func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { - roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership) - if err != nil { - return err - } - res.RoomIDs = roomIDs - return nil -} - -func (a *CurrentStateInternalAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { - users, err := a.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit) - if err != nil { - return err - } - for _, user := range users { - res.Users = append(res.Users, authtypes.FullyQualifiedProfile{ - UserID: user, - }) - } - return nil -} - -func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { - events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) - if err != nil { - return err - } - res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) - for _, ev := range events { - if res.Rooms[ev.RoomID] == nil { - res.Rooms[ev.RoomID] = make(map[gomatrixserverlib.StateKeyTuple]string) - } - room := res.Rooms[ev.RoomID] - room[gomatrixserverlib.StateKeyTuple{ - EventType: ev.EventType, - StateKey: ev.StateKey, - }] = ev.ContentValue - res.Rooms[ev.RoomID] = room - } - return nil -} - -func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { - roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, "join") - if err != nil { - return err - } - roomIDs = append(roomIDs, req.IncludeRoomIDs...) - excludeMap := make(map[string]bool) - for _, roomID := range req.ExcludeRoomIDs { - excludeMap[roomID] = true - } - // filter out excluded rooms - j := 0 - for i := range roomIDs { - // move elements to include to the beginning of the slice - // then trim elements on the right - if !excludeMap[roomIDs[i]] { - roomIDs[j] = roomIDs[i] - j++ - } - } - roomIDs = roomIDs[:j] - - users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs) - if err != nil { - return err - } - res.UserIDsToCount = users - return nil -} - -func (a *CurrentStateInternalAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error { - if a.ServerACLs == nil { - return errors.New("no server ACL tracking") - } - res.Banned = a.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID) - return nil -} diff --git a/currentstateserver/inthttp/client.go b/currentstateserver/inthttp/client.go deleted file mode 100644 index 6d54f5484..000000000 --- a/currentstateserver/inthttp/client.go +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "context" - "errors" - "net/http" - - "github.com/matrix-org/dendrite/currentstateserver/api" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/opentracing/opentracing-go" -) - -// HTTP paths for the internal HTTP APIs -const ( - QueryCurrentStatePath = "/currentstateserver/queryCurrentState" - QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser" - QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent" - QuerySharedUsersPath = "/currentstateserver/querySharedUsers" - QueryKnownUsersPath = "/currentstateserver/queryKnownUsers" - QueryServerBannedFromRoomPath = "/currentstateserver/queryServerBannedFromRoom" -) - -// NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API. -// If httpClient is nil an error is returned -func NewCurrentStateAPIClient( - apiURL string, - httpClient *http.Client, -) (api.CurrentStateInternalAPI, error) { - if httpClient == nil { - return nil, errors.New("NewCurrentStateAPIClient: httpClient is ") - } - return &httpCurrentStateInternalAPI{ - apiURL: apiURL, - httpClient: httpClient, - }, nil -} - -type httpCurrentStateInternalAPI struct { - apiURL string - httpClient *http.Client -} - -func (h *httpCurrentStateInternalAPI) QueryCurrentState( - ctx context.Context, - request *api.QueryCurrentStateRequest, - response *api.QueryCurrentStateResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState") - defer span.Finish() - - apiURL := h.apiURL + QueryCurrentStatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -func (h *httpCurrentStateInternalAPI) QueryRoomsForUser( - ctx context.Context, - request *api.QueryRoomsForUserRequest, - response *api.QueryRoomsForUserResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser") - defer span.Finish() - - apiURL := h.apiURL + QueryRoomsForUserPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -func (h *httpCurrentStateInternalAPI) QueryBulkStateContent( - ctx context.Context, - request *api.QueryBulkStateContentRequest, - response *api.QueryBulkStateContentResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent") - defer span.Finish() - - apiURL := h.apiURL + QueryBulkStateContentPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - -func (h *httpCurrentStateInternalAPI) QuerySharedUsers( - ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers") - defer span.Finish() - - apiURL := h.apiURL + QuerySharedUsersPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) -} - -func (h *httpCurrentStateInternalAPI) QueryKnownUsers( - ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers") - defer span.Finish() - - apiURL := h.apiURL + QueryKnownUsersPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) -} - -func (h *httpCurrentStateInternalAPI) QueryServerBannedFromRoom( - ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom") - defer span.Finish() - - apiURL := h.apiURL + QueryServerBannedFromRoomPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) -} diff --git a/currentstateserver/inthttp/server.go b/currentstateserver/inthttp/server.go deleted file mode 100644 index 1cf8cd2ac..000000000 --- a/currentstateserver/inthttp/server.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package inthttp - -import ( - "encoding/json" - "net/http" - - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/currentstateserver/api" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/util" -) - -func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) { - internalAPIMux.Handle(QueryCurrentStatePath, - httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse { - request := api.QueryCurrentStateRequest{} - response := api.QueryCurrentStateResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.QueryCurrentState(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryRoomsForUserPath, - httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse { - request := api.QueryRoomsForUserRequest{} - response := api.QueryRoomsForUserResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.QueryRoomsForUser(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryBulkStateContentPath, - httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { - request := api.QueryBulkStateContentRequest{} - response := api.QueryBulkStateContentResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.QueryBulkStateContent(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QuerySharedUsersPath, - httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse { - request := api.QuerySharedUsersRequest{} - response := api.QuerySharedUsersResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.QuerySharedUsers(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QuerySharedUsersPath, - httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { - request := api.QueryKnownUsersRequest{} - response := api.QueryKnownUsersResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.QueryKnownUsers(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryServerBannedFromRoomPath, - httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse { - request := api.QueryServerBannedFromRoomRequest{} - response := api.QueryServerBannedFromRoomResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) -} diff --git a/currentstateserver/storage/interface.go b/currentstateserver/storage/interface.go deleted file mode 100644 index 7c1c83b79..000000000 --- a/currentstateserver/storage/interface.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - - "github.com/matrix-org/dendrite/currentstateserver/storage/tables" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database interface { - internal.PartitionStorer - // StoreStateEvents updates the database with new events from the roomserver. - StoreStateEvents(ctx context.Context, addStateEvents []gomatrixserverlib.HeaderedEvent, removeStateEventIDs []string) error - // GetStateEvent returns the state event of a given type for a given room with a given state key - // If no event could be found, returns nil - // If there was an issue during the retrieval, returns an error - GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) - // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). - GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) - // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. - // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. - GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) - // Redact a state event - RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error - // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. - JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) - // GetKnownUsers searches all users that userID knows about. - GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) - // GetKnownRooms returns a list of all rooms we know about. - GetKnownRooms(ctx context.Context) ([]string, error) -} diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go deleted file mode 100644 index 7d2056711..000000000 --- a/currentstateserver/storage/postgres/current_room_state_table.go +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - - "github.com/lib/pq" - "github.com/matrix-org/dendrite/currentstateserver/storage/tables" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" -) - -var currentRoomStateSchema = ` --- Stores the current room state for every room. -CREATE TABLE IF NOT EXISTS currentstate_current_room_state ( - -- The 'room_id' key for the state event. - room_id TEXT NOT NULL, - -- The state event ID - event_id TEXT NOT NULL, - -- The state event type e.g 'm.room.member' - type TEXT NOT NULL, - -- The 'sender' property of the event. - sender TEXT NOT NULL, - -- The state_key value for this state event e.g '' - state_key TEXT NOT NULL, - -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. - headered_event_json TEXT NOT NULL, - -- A piece of extracted content e.g membership for m.room.member events - content_value TEXT NOT NULL DEFAULT '', - -- Clobber based on 3-uple of room_id, type and state_key - CONSTRAINT currentstate_current_room_state_unique UNIQUE (room_id, type, state_key) -); --- for event deletion -CREATE UNIQUE INDEX IF NOT EXISTS currentstate_event_id_idx ON currentstate_current_room_state(event_id, room_id, type, sender); --- for querying membership states of users -CREATE INDEX IF NOT EXISTS currentstate_membership_idx ON currentstate_current_room_state(type, state_key, content_value) -WHERE type='m.room.member' AND content_value IS NOT NULL AND content_value != 'leave'; -` - -const upsertRoomStateSQL = "" + - "INSERT INTO currentstate_current_room_state (room_id, event_id, type, sender, state_key, headered_event_json, content_value)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7)" + - " ON CONFLICT ON CONSTRAINT currentstate_current_room_state_unique" + - " DO UPDATE SET event_id = $2, sender=$4, headered_event_json = $6, content_value = $7" - -const deleteRoomStateByEventIDSQL = "" + - "DELETE FROM currentstate_current_room_state WHERE event_id = $1" - -const selectRoomIDsWithMembershipSQL = "" + - "SELECT room_id FROM currentstate_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND content_value = $2" - -const selectStateEventSQL = "" + - "SELECT headered_event_json FROM currentstate_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" - -const selectEventsWithEventIDsSQL = "" + - "SELECT headered_event_json FROM currentstate_current_room_state WHERE event_id = ANY($1)" - -const selectBulkStateContentSQL = "" + - "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2) AND state_key = ANY($3)" - -const selectBulkStateContentWildSQL = "" + - "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)" - -const selectJoinedUsersSetForRoomsSQL = "" + - "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" + - " type = 'm.room.member' and content_value = 'join' GROUP BY state_key" - -const selectKnownRoomsSQL = "" + - "SELECT DISTINCT room_id FROM currentstate_current_room_state" - -// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is -// joined to. Since this information is used to populate the user directory, we will -// only return users that the user would ordinarily be able to see anyway. -const selectKnownUsersSQL = "" + - "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY(" + - " SELECT DISTINCT room_id FROM currentstate_current_room_state WHERE state_key=$1 AND TYPE='m.room.member' AND content_value='join'" + - ") AND TYPE='m.room.member' AND content_value='join' AND state_key LIKE $2 LIMIT $3" - -type currentRoomStateStatements struct { - upsertRoomStateStmt *sql.Stmt - deleteRoomStateByEventIDStmt *sql.Stmt - selectRoomIDsWithMembershipStmt *sql.Stmt - selectEventsWithEventIDsStmt *sql.Stmt - selectStateEventStmt *sql.Stmt - selectBulkStateContentStmt *sql.Stmt - selectBulkStateContentWildStmt *sql.Stmt - selectJoinedUsersSetForRoomsStmt *sql.Stmt - selectKnownRoomsStmt *sql.Stmt - selectKnownUsersStmt *sql.Stmt -} - -func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { - s := ¤tRoomStateStatements{} - _, err := db.Exec(currentRoomStateSchema) - if err != nil { - return nil, err - } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { - return nil, err - } - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - if s.selectBulkStateContentStmt, err = db.Prepare(selectBulkStateContentSQL); err != nil { - return nil, err - } - if s.selectBulkStateContentWildStmt, err = db.Prepare(selectBulkStateContentWildSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { - return nil, err - } - if s.selectKnownRoomsStmt, err = db.Prepare(selectKnownRoomsSQL); err != nil { - return nil, err - } - if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { - rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") - result := make(map[string]int) - for rows.Next() { - var userID string - var count int - if err := rows.Scan(&userID, &count); err != nil { - return nil, err - } - result[userID] = count - } - return result, rows.Err() -} - -// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. -func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( - ctx context.Context, - txn *sql.Tx, - userID string, - contentVal string, -) ([]string, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) - rows, err := stmt.QueryContext(ctx, userID, contentVal) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed") - - var result []string - for rows.Next() { - var roomID string - if err := rows.Scan(&roomID); err != nil { - return nil, err - } - result = append(result, roomID) - } - return result, rows.Err() -} - -func (s *currentRoomStateStatements) DeleteRoomStateByEventID( - ctx context.Context, txn *sql.Tx, eventID string, -) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - _, err := stmt.ExecContext(ctx, eventID) - return err -} - -func (s *currentRoomStateStatements) UpsertRoomState( - ctx context.Context, txn *sql.Tx, - event gomatrixserverlib.HeaderedEvent, contentVal string, -) error { - headeredJSON, err := json.Marshal(event) - if err != nil { - return err - } - - // upsert state event - stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) - _, err = stmt.ExecContext( - ctx, - event.RoomID(), - event.EventID(), - event.Type(), - event.Sender(), - *event.StateKey(), - headeredJSON, - contentVal, - ) - return err -} - -func (s *currentRoomStateStatements) SelectEventsWithEventIDs( - ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]gomatrixserverlib.HeaderedEvent, error) { - stmt := sqlutil.TxStmt(txn, s.selectEventsWithEventIDsStmt) - rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") - result := []gomatrixserverlib.HeaderedEvent{} - for rows.Next() { - var eventBytes []byte - if err := rows.Scan(&eventBytes); err != nil { - return nil, err - } - var ev gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal(eventBytes, &ev); err != nil { - return nil, err - } - result = append(result, ev) - } - return result, rows.Err() -} - -func (s *currentRoomStateStatements) SelectStateEvent( - ctx context.Context, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { - stmt := s.selectStateEventStmt - var res []byte - err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, err - } - var ev gomatrixserverlib.HeaderedEvent - if err = json.Unmarshal(res, &ev); err != nil { - return nil, err - } - return &ev, err -} - -func (s *currentRoomStateStatements) SelectBulkStateContent( - ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool, -) ([]tables.StrippedEvent, error) { - hasWildcards := false - eventTypeSet := make(map[string]bool) - stateKeySet := make(map[string]bool) - var eventTypes []string - var stateKeys []string - for _, tuple := range tuples { - if !eventTypeSet[tuple.EventType] { - eventTypeSet[tuple.EventType] = true - eventTypes = append(eventTypes, tuple.EventType) - } - if !stateKeySet[tuple.StateKey] { - stateKeySet[tuple.StateKey] = true - stateKeys = append(stateKeys, tuple.StateKey) - } - if tuple.StateKey == "*" { - hasWildcards = true - } - } - var rows *sql.Rows - var err error - if hasWildcards && allowWildcards { - rows, err = s.selectBulkStateContentWildStmt.QueryContext(ctx, pq.StringArray(roomIDs), pq.StringArray(eventTypes)) - } else { - rows, err = s.selectBulkStateContentStmt.QueryContext( - ctx, pq.StringArray(roomIDs), pq.StringArray(eventTypes), pq.StringArray(stateKeys), - ) - } - if err != nil { - return nil, err - } - strippedEvents := []tables.StrippedEvent{} - defer internal.CloseAndLogIfError(ctx, rows, "SelectBulkStateContent: rows.close() failed") - for rows.Next() { - var roomID string - var eventType string - var stateKey string - var contentVal string - if err = rows.Scan(&roomID, &eventType, &stateKey, &contentVal); err != nil { - return nil, err - } - strippedEvents = append(strippedEvents, tables.StrippedEvent{ - RoomID: roomID, - ContentValue: contentVal, - EventType: eventType, - StateKey: stateKey, - }) - } - return strippedEvents, rows.Err() -} - -func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) - if err != nil { - return nil, err - } - result := []string{} - defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - result = append(result, userID) - } - return result, rows.Err() -} - -func (s *currentRoomStateStatements) SelectKnownRooms(ctx context.Context) ([]string, error) { - rows, err := s.selectKnownRoomsStmt.QueryContext(ctx) - if err != nil { - return nil, err - } - result := []string{} - defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownRooms: rows.close() failed") - for rows.Next() { - var roomID string - if err := rows.Scan(&roomID); err != nil { - return nil, err - } - result = append(result, roomID) - } - return result, rows.Err() -} diff --git a/currentstateserver/storage/postgres/storage.go b/currentstateserver/storage/postgres/storage.go deleted file mode 100644 index cb5ebff0d..000000000 --- a/currentstateserver/storage/postgres/storage.go +++ /dev/null @@ -1,39 +0,0 @@ -package postgres - -import ( - "database/sql" - - "github.com/matrix-org/dendrite/currentstateserver/storage/shared" - "github.com/matrix-org/dendrite/internal/config" - "github.com/matrix-org/dendrite/internal/sqlutil" -) - -type Database struct { - shared.Database - db *sql.DB - writer sqlutil.Writer - sqlutil.PartitionOffsetStatements -} - -// NewDatabase creates a new sync server database -func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { - var d Database - var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err - } - d.writer = sqlutil.NewDummyWriter() - if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil { - return nil, err - } - currRoomState, err := NewPostgresCurrentRoomStateTable(d.db) - if err != nil { - return nil, err - } - d.Database = shared.Database{ - DB: d.db, - Writer: d.writer, - CurrentRoomState: currRoomState, - } - return &d, nil -} diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go deleted file mode 100644 index 2cf40ccc4..000000000 --- a/currentstateserver/storage/shared/storage.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shared - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/currentstateserver/storage/tables" - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database struct { - DB *sql.DB - Writer sqlutil.Writer - CurrentRoomState tables.CurrentRoomState -} - -func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { - return d.CurrentRoomState.SelectStateEvent(ctx, roomID, evType, stateKey) -} - -func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { - return d.CurrentRoomState.SelectBulkStateContent(ctx, roomIDs, tuples, allowWildcards) -} - -func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error { - events, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, nil, []string{redactedEventID}) - if err != nil { - return err - } - if len(events) != 1 { - // this will happen for all non-state events - return nil - } - redactionEvent := redactedBecause.Unwrap() - eventBeingRedacted := events[0].Unwrap() - redactedEvent, err := eventutil.RedactEvent(&redactionEvent, &eventBeingRedacted) - if err != nil { - return fmt.Errorf("RedactEvent failed: %w", err) - } - // replace the state event with a redacted version of itself - return d.StoreStateEvents(ctx, []gomatrixserverlib.HeaderedEvent{redactedEvent.Headered(redactedBecause.RoomVersion)}, []string{redactedEventID}) -} - -func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatrixserverlib.HeaderedEvent, - removeStateEventIDs []string) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. - for _, eventID := range removeStateEventIDs { - if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { - return err - } - } - - for _, event := range addStateEvents { - if event.StateKey() == nil { - // ignore non state events - continue - } - contentVal := tables.ExtractContentValue(&event) - - if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, contentVal); err != nil { - return err - } - } - return nil - }) -} - -func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { - return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) -} - -func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { - return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs) -} - -func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { - return d.CurrentRoomState.SelectKnownUsers(ctx, userID, searchString, limit) -} - -func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { - return d.CurrentRoomState.SelectKnownRooms(ctx) -} diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go deleted file mode 100644 index 1985bf422..000000000 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ /dev/null @@ -1,365 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "strings" - - "github.com/matrix-org/dendrite/currentstateserver/storage/tables" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" -) - -const currentRoomStateSchema = ` --- Stores the current room state for every room. -CREATE TABLE IF NOT EXISTS currentstate_current_room_state ( - room_id TEXT NOT NULL, - event_id TEXT NOT NULL, - type TEXT NOT NULL, - sender TEXT NOT NULL, - state_key TEXT NOT NULL, - headered_event_json TEXT NOT NULL, - content_value TEXT NOT NULL DEFAULT '', - UNIQUE (room_id, type, state_key) -); --- for event deletion -CREATE UNIQUE INDEX IF NOT EXISTS currentstate_event_id_idx ON currentstate_current_room_state(event_id, room_id, type, sender); -` - -const upsertRoomStateSQL = "" + - "INSERT INTO currentstate_current_room_state (room_id, event_id, type, sender, state_key, headered_event_json, content_value)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7)" + - " ON CONFLICT (event_id, room_id, type, sender)" + - " DO UPDATE SET event_id = $2, sender=$4, headered_event_json = $6, content_value = $7" - -const deleteRoomStateByEventIDSQL = "" + - "DELETE FROM currentstate_current_room_state WHERE event_id = $1" - -const selectRoomIDsWithMembershipSQL = "" + - "SELECT room_id FROM currentstate_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND content_value = $2" - -const selectStateEventSQL = "" + - "SELECT headered_event_json FROM currentstate_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" - -const selectEventsWithEventIDsSQL = "" + - "SELECT headered_event_json FROM currentstate_current_room_state WHERE event_id IN ($1)" - -const selectBulkStateContentSQL = "" + - "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2) AND state_key IN ($3)" - -const selectBulkStateContentWildSQL = "" + - "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)" - -const selectJoinedUsersSetForRoomsSQL = "" + - "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key" - -const selectKnownRoomsSQL = "" + - "SELECT DISTINCT room_id FROM currentstate_current_room_state" - -// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is -// joined to. Since this information is used to populate the user directory, we will -// only return users that the user would ordinarily be able to see anyway. -const selectKnownUsersSQL = "" + - "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN (" + - " SELECT DISTINCT room_id FROM currentstate_current_room_state WHERE state_key=$1 AND TYPE='m.room.member' AND content_value='join'" + - ") AND TYPE='m.room.member' AND content_value='join' AND state_key LIKE $2 LIMIT $3" - -type currentRoomStateStatements struct { - db *sql.DB - upsertRoomStateStmt *sql.Stmt - deleteRoomStateByEventIDStmt *sql.Stmt - selectRoomIDsWithMembershipStmt *sql.Stmt - selectStateEventStmt *sql.Stmt - selectJoinedUsersSetForRoomsStmt *sql.Stmt - selectKnownRoomsStmt *sql.Stmt - selectKnownUsersStmt *sql.Stmt -} - -func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { - s := ¤tRoomStateStatements{ - db: db, - } - _, err := db.Exec(currentRoomStateSchema) - if err != nil { - return nil, err - } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { - return nil, err - } - if s.selectKnownRoomsStmt, err = db.Prepare(selectKnownRoomsSQL); err != nil { - return nil, err - } - if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { - iRoomIDs := make([]interface{}, len(roomIDs)) - for i, v := range roomIDs { - iRoomIDs[i] = v - } - query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) - rows, err := s.db.QueryContext(ctx, query, iRoomIDs...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") - result := make(map[string]int) - for rows.Next() { - var userID string - var count int - if err := rows.Scan(&userID, &count); err != nil { - return nil, err - } - result[userID] = count - } - return result, rows.Err() -} - -// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. -func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( - ctx context.Context, - txn *sql.Tx, - userID string, - membership string, -) ([]string, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) - rows, err := stmt.QueryContext(ctx, userID, membership) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed") - - var result []string - for rows.Next() { - var roomID string - if err := rows.Scan(&roomID); err != nil { - return nil, err - } - result = append(result, roomID) - } - return result, nil -} - -func (s *currentRoomStateStatements) DeleteRoomStateByEventID( - ctx context.Context, txn *sql.Tx, eventID string, -) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - _, err := stmt.ExecContext(ctx, eventID) - return err -} - -func (s *currentRoomStateStatements) UpsertRoomState( - ctx context.Context, txn *sql.Tx, - event gomatrixserverlib.HeaderedEvent, contentVal string, -) error { - headeredJSON, err := json.Marshal(event) - if err != nil { - return err - } - - // upsert state event - stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) - _, err = stmt.ExecContext( - ctx, - event.RoomID(), - event.EventID(), - event.Type(), - event.Sender(), - *event.StateKey(), - headeredJSON, - contentVal, - ) - return err -} - -func (s *currentRoomStateStatements) SelectEventsWithEventIDs( - ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]gomatrixserverlib.HeaderedEvent, error) { - iEventIDs := make([]interface{}, len(eventIDs)) - for k, v := range eventIDs { - iEventIDs[k] = v - } - query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) - var rows *sql.Rows - var err error - if txn != nil { - rows, err = txn.QueryContext(ctx, query, iEventIDs...) - } else { - rows, err = s.db.QueryContext(ctx, query, iEventIDs...) - } - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") - result := []gomatrixserverlib.HeaderedEvent{} - for rows.Next() { - var eventBytes []byte - if err := rows.Scan(&eventBytes); err != nil { - return nil, err - } - var ev gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal(eventBytes, &ev); err != nil { - return nil, err - } - result = append(result, ev) - } - return result, nil -} - -func (s *currentRoomStateStatements) SelectStateEvent( - ctx context.Context, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { - stmt := s.selectStateEventStmt - var res []byte - err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, err - } - var ev gomatrixserverlib.HeaderedEvent - if err = json.Unmarshal(res, &ev); err != nil { - return nil, err - } - return &ev, err -} - -func (s *currentRoomStateStatements) SelectBulkStateContent( - ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool, -) ([]tables.StrippedEvent, error) { - hasWildcards := false - eventTypeSet := make(map[string]bool) - stateKeySet := make(map[string]bool) - var eventTypes []string - var stateKeys []string - for _, tuple := range tuples { - if !eventTypeSet[tuple.EventType] { - eventTypeSet[tuple.EventType] = true - eventTypes = append(eventTypes, tuple.EventType) - } - if !stateKeySet[tuple.StateKey] { - stateKeySet[tuple.StateKey] = true - stateKeys = append(stateKeys, tuple.StateKey) - } - if tuple.StateKey == "*" { - hasWildcards = true - } - } - - iRoomIDs := make([]interface{}, len(roomIDs)) - for i, v := range roomIDs { - iRoomIDs[i] = v - } - iEventTypes := make([]interface{}, len(eventTypes)) - for i, v := range eventTypes { - iEventTypes[i] = v - } - iStateKeys := make([]interface{}, len(stateKeys)) - for i, v := range stateKeys { - iStateKeys[i] = v - } - - var query string - var args []interface{} - if hasWildcards && allowWildcards { - query = strings.Replace(selectBulkStateContentWildSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) - query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(iEventTypes), len(iRoomIDs)), 1) - args = append(iRoomIDs, iEventTypes...) - } else { - query = strings.Replace(selectBulkStateContentSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) - query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(iEventTypes), len(iRoomIDs)), 1) - query = strings.Replace(query, "($3)", sqlutil.QueryVariadicOffset(len(iStateKeys), len(iEventTypes)+len(iRoomIDs)), 1) - args = append(iRoomIDs, iEventTypes...) - args = append(args, iStateKeys...) - } - rows, err := s.db.QueryContext(ctx, query, args...) - - if err != nil { - return nil, err - } - strippedEvents := []tables.StrippedEvent{} - defer internal.CloseAndLogIfError(ctx, rows, "SelectBulkStateContent: rows.close() failed") - for rows.Next() { - var roomID string - var eventType string - var stateKey string - var contentVal string - if err = rows.Scan(&roomID, &eventType, &stateKey, &contentVal); err != nil { - return nil, err - } - strippedEvents = append(strippedEvents, tables.StrippedEvent{ - RoomID: roomID, - ContentValue: contentVal, - EventType: eventType, - StateKey: stateKey, - }) - } - return strippedEvents, rows.Err() -} - -func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) - if err != nil { - return nil, err - } - result := []string{} - defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - result = append(result, userID) - } - return result, rows.Err() -} - -func (s *currentRoomStateStatements) SelectKnownRooms(ctx context.Context) ([]string, error) { - rows, err := s.selectKnownRoomsStmt.QueryContext(ctx) - if err != nil { - return nil, err - } - result := []string{} - defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownRooms: rows.close() failed") - for rows.Next() { - var roomID string - if err := rows.Scan(&roomID); err != nil { - return nil, err - } - result = append(result, roomID) - } - return result, rows.Err() -} diff --git a/currentstateserver/storage/sqlite3/storage.go b/currentstateserver/storage/sqlite3/storage.go deleted file mode 100644 index e79afd70a..000000000 --- a/currentstateserver/storage/sqlite3/storage.go +++ /dev/null @@ -1,40 +0,0 @@ -package sqlite3 - -import ( - "database/sql" - - "github.com/matrix-org/dendrite/currentstateserver/storage/shared" - "github.com/matrix-org/dendrite/internal/config" - "github.com/matrix-org/dendrite/internal/sqlutil" -) - -type Database struct { - shared.Database - db *sql.DB - writer sqlutil.Writer - sqlutil.PartitionOffsetStatements -} - -// NewDatabase creates a new sync server database -// nolint: gocyclo -func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { - var d Database - var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err - } - d.writer = sqlutil.NewExclusiveWriter() - if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil { - return nil, err - } - currRoomState, err := NewSqliteCurrentRoomStateTable(d.db) - if err != nil { - return nil, err - } - d.Database = shared.Database{ - DB: d.db, - Writer: d.writer, - CurrentRoomState: currRoomState, - } - return &d, nil -} diff --git a/currentstateserver/storage/storage.go b/currentstateserver/storage/storage.go deleted file mode 100644 index e0707def4..000000000 --- a/currentstateserver/storage/storage.go +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build !wasm - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/currentstateserver/storage/postgres" - "github.com/matrix-org/dendrite/currentstateserver/storage/sqlite3" - "github.com/matrix-org/dendrite/internal/config" -) - -// NewDatabase opens a database connection. -func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/currentstateserver/storage/storage_wasm.go b/currentstateserver/storage/storage_wasm.go deleted file mode 100644 index 46a5abd64..000000000 --- a/currentstateserver/storage/storage_wasm.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/currentstateserver/storage/sqlite3" - "github.com/matrix-org/dendrite/internal/config" -) - -// NewDatabase opens a database connection. -func NewDatabase(dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go deleted file mode 100644 index cc5c6e808..000000000 --- a/currentstateserver/storage/tables/interface.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tables - -import ( - "context" - "database/sql" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/tidwall/gjson" -) - -type CurrentRoomState interface { - SelectStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) - // SelectEventsWithEventIDs returns the events for the given event IDs. If the event(s) are missing, they are not returned - // and no error is returned. - SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) - // UpsertRoomState stores the given event in the database, along with an extracted piece of content. - // The piece of content will vary depending on the event type, and table implementations may use this information to optimise - // lookups e.g membership lookups. The mapped value of `contentVal` is outlined in ExtractContentValue. An empty `contentVal` - // means there is nothing to store for this field. - UpsertRoomState(ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, contentVal string) error - DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error - // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. - SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) - SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error) - // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the - // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) - // SelectKnownUsers searches all users that userID knows about. - SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) - // SelectKnownRooms returns all rooms that we know about. - SelectKnownRooms(ctx context.Context) ([]string, error) -} - -// StrippedEvent represents a stripped event for returning extracted content values. -type StrippedEvent struct { - RoomID string - EventType string - StateKey string - ContentValue string -} - -// ExtractContentValue from the given state event. For example, given an m.room.name event with: -// content: { name: "Foo" } -// this returns "Foo". -func ExtractContentValue(ev *gomatrixserverlib.HeaderedEvent) string { - content := ev.Content() - key := "" - switch ev.Type() { - case gomatrixserverlib.MRoomCreate: - key = "creator" - case gomatrixserverlib.MRoomCanonicalAlias: - key = "alias" - case gomatrixserverlib.MRoomHistoryVisibility: - key = "history_visibility" - case gomatrixserverlib.MRoomJoinRules: - key = "join_rule" - case gomatrixserverlib.MRoomMember: - key = "membership" - case gomatrixserverlib.MRoomName: - key = "name" - case "m.room.avatar": - key = "url" - case "m.room.topic": - key = "topic" - case "m.room.guest_access": - key = "guest_access" - } - result := gjson.GetBytes(content, key) - if !result.Exists() { - return "" - } - // this returns the empty string if this is not a string type - return result.Str -} diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 570669c1a..be0972e4a 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -141,17 +141,6 @@ client_api: threshold: 5 cooloff_ms: 500 -# Configuration for the Current State Server. -current_state_server: - internal_api: - listen: http://localhost:7782 - connect: http://localhost:7782 - database: - connection_string: file:currentstate.db - max_open_conns: 100 - max_idle_conns: 2 - conn_max_lifetime: -1 - # Configuration for the EDU server. edu_server: internal_api: diff --git a/docs/INSTALL.md b/docs/INSTALL.md index dedcf1517..7a7fb03ee 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -109,7 +109,7 @@ Assuming that Postgres 9.5 (or later) is installed: * Create the component databases: ```bash - for i in account device mediaapi syncapi roomserver serverkey federationsender currentstate appservice e2ekey naffka; do + for i in account device mediaapi syncapi roomserver serverkey federationsender appservice e2ekey naffka; do sudo -u postgres createdb -O dendrite dendrite_$i done ``` @@ -239,16 +239,6 @@ This is what implements the room DAG. Clients do not talk to this. ./bin/dendrite-room-server --config=dendrite.yaml ``` -#### Current state server - -This tracks the current state of rooms which various components need to know. For example, -`/publicRooms` implemented by client API asks this server for the room names, joined member -counts, etc. - -```bash -./bin/dendrite-current-state-server --config=dendrite.yaml -``` - #### Federation sender This sends events from our users to other servers. This is only required if diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 9193685a8..944e2797c 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -16,7 +16,6 @@ package federationapi import ( "github.com/gorilla/mux" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" @@ -38,12 +37,11 @@ func AddPublicRoutes( rsAPI roomserverAPI.RoomserverInternalAPI, federationSenderAPI federationSenderAPI.FederationSenderInternalAPI, eduAPI eduserverAPI.EDUServerInputAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, keyAPI keyserverAPI.KeyInternalAPI, ) { routing.Setup( fedRouter, keyRouter, cfg, rsAPI, eduAPI, federationSenderAPI, keyRing, - federation, userAPI, stateAPI, keyAPI, + federation, userAPI, keyAPI, ) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 45346bc0f..3c2e5bbb0 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -31,7 +31,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { fsAPI := base.FederationSenderHTTPClient() // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base.PublicFederationAPIMux, base.PublicKeyAPIMux, &cfg.FederationAPI, nil, nil, keyRing, nil, fsAPI, nil, nil, nil) + federationapi.AddPublicRoutes(base.PublicFederationAPIMux, base.PublicKeyAPIMux, &cfg.FederationAPI, nil, nil, keyRing, nil, fsAPI, nil, nil) baseURL, cancel := test.ListenAndServe(t, base.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/routing/publicrooms.go b/federationapi/routing/publicrooms.go index 3807a5183..d923a236a 100644 --- a/federationapi/routing/publicrooms.go +++ b/federationapi/routing/publicrooms.go @@ -7,7 +7,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -24,7 +23,7 @@ type filter struct { } // GetPostPublicRooms implements GET and POST /publicRooms -func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI) util.JSONResponse { +func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.RoomserverInternalAPI) util.JSONResponse { var request PublicRoomReq if fillErr := fillPublicRoomsReq(req, &request); fillErr != nil { return *fillErr @@ -32,7 +31,7 @@ func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.RoomserverInterna if request.Limit == 0 { request.Limit = 50 } - response, err := publicRooms(req.Context(), request, rsAPI, stateAPI) + response, err := publicRooms(req.Context(), request, rsAPI) if err != nil { return jsonerror.InternalServerError() } @@ -42,8 +41,9 @@ func GetPostPublicRooms(req *http.Request, rsAPI roomserverAPI.RoomserverInterna } } -func publicRooms(ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.RoomserverInternalAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI) (*gomatrixserverlib.RespPublicRooms, error) { +func publicRooms( + ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI.RoomserverInternalAPI, +) (*gomatrixserverlib.RespPublicRooms, error) { var response gomatrixserverlib.RespPublicRooms var limit int16 @@ -80,7 +80,7 @@ func publicRooms(ctx context.Context, request PublicRoomReq, rsAPI roomserverAPI nextIndex = len(queryRes.RoomIDs) } roomIDs := queryRes.RoomIDs[offset:nextIndex] - response.Chunk, err = fillInRooms(ctx, roomIDs, stateAPI) + response.Chunk, err = fillInRooms(ctx, roomIDs, rsAPI) return &response, err } @@ -112,7 +112,7 @@ func fillPublicRoomsReq(httpReq *http.Request, request *PublicRoomReq) *util.JSO // due to lots of switches // nolint:gocyclo -func fillInRooms(ctx context.Context, roomIDs []string, stateAPI currentstateAPI.CurrentStateInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { +func fillInRooms(ctx context.Context, roomIDs []string, rsAPI roomserverAPI.RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} @@ -121,8 +121,8 @@ func fillInRooms(ctx context.Context, roomIDs []string, stateAPI currentstateAPI visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} - var stateRes currentstateAPI.QueryBulkStateContentResponse - err := stateAPI.QueryBulkStateContent(ctx, ¤tstateAPI.QueryBulkStateContentRequest{ + var stateRes roomserverAPI.QueryBulkStateContentResponse + err := rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ RoomIDs: roomIDs, AllowWildcards: true, StateTuples: []gomatrixserverlib.StateKeyTuple{ diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 5ea190a17..71a09d421 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -19,7 +19,6 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" @@ -48,7 +47,6 @@ func Setup( keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, userAPI userapi.UserInternalAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, keyAPI keyserverAPI.KeyInternalAPI, ) { v2keysmux := keyMux.PathPrefix("/v2").Subrouter() @@ -76,7 +74,7 @@ func Setup( func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { return Send( httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), - cfg, rsAPI, eduAPI, keyAPI, stateAPI, keys, federation, + cfg, rsAPI, eduAPI, keyAPI, keys, federation, ) }, )).Methods(http.MethodPut, http.MethodOptions) @@ -84,7 +82,7 @@ func Setup( v1fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -100,7 +98,7 @@ func Setup( v2fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -140,7 +138,7 @@ func Setup( v1fedmux.Handle("/state/{roomID}", httputil.MakeFedAPI( "federation_get_state", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -155,7 +153,7 @@ func Setup( v1fedmux.Handle("/state_ids/{roomID}", httputil.MakeFedAPI( "federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -170,7 +168,7 @@ func Setup( v1fedmux.Handle("/event_auth/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -212,7 +210,7 @@ func Setup( v1fedmux.Handle("/make_join/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_make_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -243,7 +241,7 @@ func Setup( v1fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -275,7 +273,7 @@ func Setup( v2fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_send_join", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -292,7 +290,7 @@ func Setup( v1fedmux.Handle("/make_leave/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -309,7 +307,7 @@ func Setup( v1fedmux.Handle("/send_leave/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -341,7 +339,7 @@ func Setup( v2fedmux.Handle("/send_leave/{roomID}/{eventID}", httputil.MakeFedAPI( "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -365,7 +363,7 @@ func Setup( v1fedmux.Handle("/get_missing_events/{roomID}", httputil.MakeFedAPI( "federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -378,7 +376,7 @@ func Setup( v1fedmux.Handle("/backfill/{roomID}", httputil.MakeFedAPI( "federation_backfill", cfg.Matrix.ServerName, keys, wakeup, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { - if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) { + if roomserverAPI.IsServerBannedFromRoom(httpReq.Context(), rsAPI, vars["roomID"], request.Origin()) { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("Forbidden by server ACLs"), @@ -390,7 +388,7 @@ func Setup( v1fedmux.Handle("/publicRooms", httputil.MakeExternalAPI("federation_public_rooms", func(req *http.Request) util.JSONResponse { - return GetPostPublicRooms(req, rsAPI, stateAPI) + return GetPostPublicRooms(req, rsAPI) }), ).Methods(http.MethodGet) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 570062adc..9def7c3c3 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -19,9 +19,9 @@ import ( "encoding/json" "fmt" "net/http" + "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/internal/config" keyapi "github.com/matrix-org/dendrite/keyserver/api" @@ -40,15 +40,12 @@ func Send( rsAPI api.RoomserverInternalAPI, eduAPI eduserverAPI.EDUServerInputAPI, keyAPI keyapi.KeyInternalAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, keys gomatrixserverlib.JSONVerifier, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { t := txnReq{ - context: httpReq.Context(), rsAPI: rsAPI, eduAPI: eduAPI, - stateAPI: stateAPI, keys: keys, federation: federation, haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), @@ -85,7 +82,7 @@ func Send( util.GetLogger(httpReq.Context()).Infof("Received transaction %q containing %d PDUs, %d EDUs", txnID, len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction() + resp, jsonErr := t.processTransaction(httpReq.Context()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -103,11 +100,9 @@ func Send( type txnReq struct { gomatrixserverlib.Transaction - context context.Context rsAPI api.RoomserverInternalAPI eduAPI eduserverAPI.EDUServerInputAPI keyAPI keyapi.KeyInternalAPI - stateAPI currentstateAPI.CurrentStateInternalAPI keys gomatrixserverlib.JSONVerifier federation txnFederationClient // local cache of events for auth checks, etc - this may include events @@ -128,7 +123,7 @@ type txnFederationClient interface { roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } -func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONResponse) { +func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { results := make(map[string]gomatrixserverlib.PDUResult) pdus := []gomatrixserverlib.HeaderedEvent{} @@ -137,15 +132,15 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe RoomID string `json:"room_id"` } if err := json.Unmarshal(pdu, &header); err != nil { - util.GetLogger(t.context).WithError(err).Warn("Transaction: Failed to extract room ID from event") + util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to extract room ID from event") // We don't know the event ID at this point so we can't return the // failure in the PDU results continue } verReq := api.QueryRoomVersionForRoomRequest{RoomID: header.RoomID} verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(t.context, &verReq, &verRes); err != nil { - util.GetLogger(t.context).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) + if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) // We don't know the event ID at this point so we can't return the // failure in the PDU results continue @@ -165,17 +160,17 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe JSON: jsonerror.BadJSON("PDU contains bad JSON"), } } - util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu)) continue } - if currentstateAPI.IsServerBannedFromRoom(t.context, t.stateAPI, event.RoomID(), t.Origin) { + if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: "Forbidden by server ACLs", } continue } - if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { - util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: err.Error(), } @@ -186,7 +181,7 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe // Process the events. for _, e := range pdus { - if err := t.processEvent(e.Unwrap(), true); err != nil { + if err := t.processEvent(ctx, e.Unwrap(), true); err != nil { // If the error is due to the event itself being bad then we skip // it and move onto the next event. We report an error so that the // sender knows that we have skipped processing it. @@ -205,7 +200,7 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe if isProcessingErrorFatal(err) { // Any other error should be the result of a temporary error in // our server so we should bail processing the transaction entirely. - util.GetLogger(t.context).Warnf("Processing %s failed fatally: %s", e.EventID(), err) + util.GetLogger(ctx).Warnf("Processing %s failed fatally: %s", e.EventID(), err) jsonErr := util.ErrorResponse(err) return nil, &jsonErr } else { @@ -215,7 +210,7 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe if rejected { errMsg = "" } - util.GetLogger(t.context).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn( + util.GetLogger(ctx).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn( "Failed to process incoming federation event, skipping", ) results[e.EventID()] = gomatrixserverlib.PDUResult{ @@ -227,9 +222,9 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe } } - t.processEDUs(t.EDUs) + t.processEDUs(ctx) if c := len(results); c > 0 { - util.GetLogger(t.context).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID) + util.GetLogger(ctx).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID) } return &gomatrixserverlib.RespSend{PDUs: results}, nil } @@ -288,8 +283,9 @@ func (t *txnReq) haveEventIDs() map[string]bool { return result } -func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { - for _, e := range edus { +// nolint:gocyclo +func (t *txnReq) processEDUs(ctx context.Context) { + for _, e := range t.EDUs { switch e.Type { case gomatrixserverlib.MTyping: // https://matrix.org/docs/spec/server_server/latest#typing-notifications @@ -299,24 +295,24 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { Typing bool `json:"typing"` } if err := json.Unmarshal(e.Content, &typingPayload); err != nil { - util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal typing event") + util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal typing event") continue } - if err := eduserverAPI.SendTyping(t.context, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { - util.GetLogger(t.context).WithError(err).Error("Failed to send typing event to edu server") + if err := eduserverAPI.SendTyping(ctx, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to edu server") } case gomatrixserverlib.MDirectToDevice: // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema var directPayload gomatrixserverlib.ToDeviceMessage if err := json.Unmarshal(e.Content, &directPayload); err != nil { - util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal send-to-device events") + util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal send-to-device events") continue } for userID, byUser := range directPayload.Messages { for deviceID, message := range byUser { // TODO: check that the user and the device actually exist here - if err := eduserverAPI.SendToDevice(t.context, t.eduAPI, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { - util.GetLogger(t.context).WithError(err).WithFields(logrus.Fields{ + if err := eduserverAPI.SendToDevice(ctx, t.eduAPI, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ "sender": directPayload.Sender, "user_id": userID, "device_id": deviceID, @@ -325,17 +321,17 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { } } case gomatrixserverlib.MDeviceListUpdate: - t.processDeviceListUpdate(e) + t.processDeviceListUpdate(ctx, e) default: - util.GetLogger(t.context).WithField("type", e.Type).Debug("Unhandled EDU") + util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") } } } -func (t *txnReq) processDeviceListUpdate(e gomatrixserverlib.EDU) { +func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverlib.EDU) { var payload gomatrixserverlib.DeviceListUpdateEvent if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal device list update event") + util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal device list update event") return } var inputRes keyapi.InputDeviceListUpdateResponse @@ -343,11 +339,11 @@ func (t *txnReq) processDeviceListUpdate(e gomatrixserverlib.EDU) { Event: payload, }, &inputRes) if inputRes.Error != nil { - util.GetLogger(t.context).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate") + util.GetLogger(ctx).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate") } } -func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) error { +func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, isInboundTxn bool) error { prevEventIDs := e.PrevEventIDs() // Fetch the state needed to authenticate the event. @@ -358,7 +354,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro StateToFetch: needed.Tuples(), } var stateResp api.QueryStateAfterEventsResponse - if err := t.rsAPI.QueryStateAfterEvents(t.context, &stateReq, &stateResp); err != nil { + if err := t.rsAPI.QueryStateAfterEvents(ctx, &stateReq, &stateResp); err != nil { return err } @@ -373,7 +369,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro } if !stateResp.PrevEventsExist { - return t.processEventWithMissingState(e, stateResp.RoomVersion, isInboundTxn) + return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn) } // Check that the event is allowed by the state at the event. @@ -383,7 +379,8 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro // pass the event to the roomserver return api.SendEvents( - t.context, t.rsAPI, + context.Background(), + t.rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(stateResp.RoomVersion), }, @@ -403,7 +400,12 @@ func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserver return gomatrixserverlib.Allowed(e, &authUsingState) } -func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error { +func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) error { + // Do this with a fresh context, so that we keep working even if the + // original request times out. With any luck, by the time the remote + // side retries, we'll have fetched the missing state. + gmectx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: @@ -424,7 +426,7 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer // - fill in the gap completely then process event `e` returning no backwards extremity // - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - backwardsExtremity, err := t.getMissingEvents(e, roomVersion, isInboundTxn) + backwardsExtremity, err := t.getMissingEvents(gmectx, e, roomVersion, isInboundTxn) if err != nil { return err } @@ -441,16 +443,16 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{*backwardsExtremity}).Tuples() for _, prevEventID := range backwardsExtremity.PrevEventIDs() { var prevState *gomatrixserverlib.RespState - prevState, err = t.lookupStateAfterEvent(roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) + prevState, err = t.lookupStateAfterEvent(gmectx, roomVersion, backwardsExtremity.RoomID(), prevEventID, needed) if err != nil { - util.GetLogger(t.context).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) + util.GetLogger(ctx).WithError(err).Errorf("Failed to lookup state after prev_event: %s", prevEventID) return err } states = append(states, prevState) } - resolvedState, err := t.resolveStatesAndCheck(roomVersion, states, backwardsExtremity) + resolvedState, err := t.resolveStatesAndCheck(gmectx, roomVersion, states, backwardsExtremity) if err != nil { - util.GetLogger(t.context).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) + util.GetLogger(ctx).WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) return err } @@ -461,21 +463,26 @@ func (t *txnReq) processEventWithMissingState(e gomatrixserverlib.Event, roomVer // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // added into the mix. -func (t *txnReq) lookupStateAfterEvent(roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) { +func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) (*gomatrixserverlib.RespState, error) { // try doing all this locally before we resort to querying federation - respState := t.lookupStateAfterEventLocally(roomID, eventID, needed) + respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID, needed) if respState != nil { return respState, nil } - respState, err := t.lookupStateBeforeEvent(roomVersion, roomID, eventID) + respState, err := t.lookupStateBeforeEvent(ctx, roomVersion, roomID, eventID) if err != nil { return nil, err } // fetch the event we're missing and add it to the pile - h, err := t.lookupEvent(roomVersion, eventID, false) - if err != nil { + h, err := t.lookupEvent(ctx, roomVersion, eventID, false) + switch err.(type) { + case verifySigError: + return respState, nil + case nil: + // do nothing + default: return nil, err } t.haveEvents[h.EventID()] = h @@ -497,15 +504,15 @@ func (t *txnReq) lookupStateAfterEvent(roomVersion gomatrixserverlib.RoomVersion return respState, nil } -func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState { +func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string, needed []gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.RespState { var res api.QueryStateAfterEventsResponse - err := t.rsAPI.QueryStateAfterEvents(t.context, &api.QueryStateAfterEventsRequest{ + err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ RoomID: roomID, PrevEventIDs: []string{eventID}, StateToFetch: needed, }, &res) if err != nil || !res.PrevEventsExist { - util.GetLogger(t.context).WithError(err).Warnf("failed to query state after %s locally", eventID) + util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID) return nil } for i, ev := range res.StateEvents { @@ -532,9 +539,9 @@ func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []g queryReq := api.QueryEventsByIDRequest{ EventIDs: missingEventList, } - util.GetLogger(t.context).Infof("Fetching missing auth events: %v", missingEventList) + util.GetLogger(ctx).Infof("Fetching missing auth events: %v", missingEventList) var queryRes api.QueryEventsByIDResponse - if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil } for i := range queryRes.Events { @@ -552,22 +559,22 @@ func (t *txnReq) lookupStateAfterEventLocally(roomID, eventID string, needed []g // lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what // the server supports. -func (t *txnReq) lookupStateBeforeEvent(roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( +func (t *txnReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( respState *gomatrixserverlib.RespState, err error) { - util.GetLogger(t.context).Infof("lookupStateBeforeEvent %s", eventID) + util.GetLogger(ctx).Infof("lookupStateBeforeEvent %s", eventID) // Attempt to fetch the missing state using /state_ids and /events - respState, err = t.lookupMissingStateViaStateIDs(roomID, eventID, roomVersion) + respState, err = t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) if err != nil { // Fallback to /state - util.GetLogger(t.context).WithError(err).Warn("lookupStateBeforeEvent failed to /state_ids, falling back to /state") - respState, err = t.lookupMissingStateViaState(roomID, eventID, roomVersion) + util.GetLogger(ctx).WithError(err).Warn("lookupStateBeforeEvent failed to /state_ids, falling back to /state") + respState, err = t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) } return } -func (t *txnReq) resolveStatesAndCheck(roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { +func (t *txnReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { var authEventList []gomatrixserverlib.Event var stateEventList []gomatrixserverlib.Event for _, state := range states { @@ -583,11 +590,19 @@ retryAllowedState: if err = checkAllowedByState(*backwardsExtremity, resolvedStateEvents); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: - h, err2 := t.lookupEvent(roomVersion, missing.AuthEventID, true) - if err2 != nil { + h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true) + switch err2.(type) { + case verifySigError: + return &gomatrixserverlib.RespState{ + AuthEvents: authEventList, + StateEvents: resolvedStateEvents, + }, nil + case nil: + // do nothing + default: return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) } - util.GetLogger(t.context).Infof("fetched event %s", missing.AuthEventID) + util.GetLogger(ctx).Infof("fetched event %s", missing.AuthEventID) resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) goto retryAllowedState default: @@ -604,12 +619,12 @@ retryAllowedState: // begin from. Returns an error only if we should terminate the transaction which initiated /get_missing_events // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. // This means that we may recursively call this function, as we spider back up prev_events to the min depth. -func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) { +func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) { if !isInboundTxn { // we've recursed here, so just take a state snapshot please! return &e, nil } - logger := util.GetLogger(t.context).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) + logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) // query latest events (our trusted forward extremities) req := api.QueryLatestEventsAndStateRequest{ @@ -617,7 +632,7 @@ func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatri StateToFetch: needed.Tuples(), } var res api.QueryLatestEventsAndStateResponse - if err = t.rsAPI.QueryLatestEventsAndState(t.context, &req, &res); err != nil { + if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil { logger.WithError(err).Warn("Failed to query latest events") return &e, nil } @@ -630,7 +645,7 @@ func (t *txnReq) getMissingEvents(e gomatrixserverlib.Event, roomVersion gomatri if minDepth < 0 { minDepth = 0 } - missingResp, err := t.federation.LookupMissingEvents(t.context, t.Origin, e.RoomID(), gomatrixserverlib.MissingEvents{ + missingResp, err := t.federation.LookupMissingEvents(ctx, t.Origin, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, // synapse uses the min depth they've ever seen in that room MinDepth: minDepth, @@ -689,7 +704,7 @@ Event: } // process the missing events then the event which started this whole thing for _, ev := range append(newEvents, e) { - err := t.processEvent(ev, false) + err := t.processEvent(ctx, ev, false) if err != nil { return nil, err } @@ -699,24 +714,24 @@ Event: return nil, nil } -func (t *txnReq) lookupMissingStateViaState(roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( +func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( respState *gomatrixserverlib.RespState, err error) { - state, err := t.federation.LookupState(t.context, t.Origin, roomID, eventID, roomVersion) + state, err := t.federation.LookupState(ctx, t.Origin, roomID, eventID, roomVersion) if err != nil { return nil, err } // Check that the returned state is valid. - if err := state.Check(t.context, t.keys, nil); err != nil { + if err := state.Check(ctx, t.keys, nil); err != nil { return nil, err } return &state, nil } -func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( +func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( *gomatrixserverlib.RespState, error) { - util.GetLogger(t.context).Infof("lookupMissingStateViaStateIDs %s", eventID) + util.GetLogger(ctx).Infof("lookupMissingStateViaStateIDs %s", eventID) // fetch the state event IDs at the time of the event - stateIDs, err := t.federation.LookupStateIDs(t.context, t.Origin, roomID, eventID) + stateIDs, err := t.federation.LookupStateIDs(ctx, t.Origin, roomID, eventID) if err != nil { return nil, err } @@ -738,7 +753,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersi EventIDs: missingEventList, } var queryRes api.QueryEventsByIDResponse - if err = t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { + if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil, err } for i := range queryRes.Events { @@ -749,7 +764,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersi } } - util.GetLogger(t.context).WithFields(logrus.Fields{ + util.GetLogger(ctx).WithFields(logrus.Fields{ "missing": len(missing), "event_id": eventID, "room_id": roomID, @@ -759,8 +774,13 @@ func (t *txnReq) lookupMissingStateViaStateIDs(roomID, eventID string, roomVersi for missingEventID := range missing { var h *gomatrixserverlib.HeaderedEvent - h, err = t.lookupEvent(roomVersion, missingEventID, false) - if err != nil { + h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false) + switch err.(type) { + case verifySigError: + continue + case nil: + // do nothing + default: return nil, err } t.haveEvents[h.EventID()] = h @@ -797,33 +817,33 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat return &respState, nil } -func (t *txnReq) lookupEvent(roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { +func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { if localFirst { // fetch from the roomserver queryReq := api.QueryEventsByIDRequest{ EventIDs: []string{missingEventID}, } var queryRes api.QueryEventsByIDResponse - if err := t.rsAPI.QueryEventsByID(t.context, &queryReq, &queryRes); err != nil { - util.GetLogger(t.context).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) + if err := t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) } else if len(queryRes.Events) == 1 { return &queryRes.Events[0], nil } } - txn, err := t.federation.GetEvent(t.context, t.Origin, missingEventID) + txn, err := t.federation.GetEvent(ctx, t.Origin, missingEventID) if err != nil || len(txn.PDUs) == 0 { - util.GetLogger(t.context).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID") + util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("failed to get missing /event for event ID") return nil, err } pdu := txn.PDUs[0] var event gomatrixserverlib.Event event, err = gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) if err != nil { - util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID()) + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %q", event.EventID()) return nil, unmarshalError{err} } - if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { - util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []gomatrixserverlib.Event{event}, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) return nil, verifySigError{event.EventID(), err} } h := event.Headered(roomVersion) diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 6dc8621b2..a714d07ee 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" eduAPI "github.com/matrix-org/dendrite/eduserver/api" fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/test" @@ -113,6 +112,13 @@ func (t *testRoomserverAPI) PerformJoin( ) { } +func (t *testRoomserverAPI) PerformPeek( + ctx context.Context, + req *api.PerformPeekRequest, + res *api.PerformPeekResponse, +) { +} + func (t *testRoomserverAPI) PerformPublish( ctx context.Context, req *api.PerformPublishRequest, @@ -320,33 +326,6 @@ func (t *testRoomserverAPI) QueryServerBannedFromRoom(ctx context.Context, req * return nil } -type testStateAPI struct { -} - -func (t *testStateAPI) QueryCurrentState(ctx context.Context, req *currentstateAPI.QueryCurrentStateRequest, res *currentstateAPI.QueryCurrentStateResponse) error { - return nil -} - -func (t *testStateAPI) QueryRoomsForUser(ctx context.Context, req *currentstateAPI.QueryRoomsForUserRequest, res *currentstateAPI.QueryRoomsForUserResponse) error { - return fmt.Errorf("not implemented") -} - -func (t *testStateAPI) QueryBulkStateContent(ctx context.Context, req *currentstateAPI.QueryBulkStateContentRequest, res *currentstateAPI.QueryBulkStateContentResponse) error { - return fmt.Errorf("not implemented") -} - -func (t *testStateAPI) QuerySharedUsers(ctx context.Context, req *currentstateAPI.QuerySharedUsersRequest, res *currentstateAPI.QuerySharedUsersResponse) error { - return fmt.Errorf("not implemented") -} - -func (t *testStateAPI) QueryKnownUsers(ctx context.Context, req *currentstateAPI.QueryKnownUsersRequest, res *currentstateAPI.QueryKnownUsersResponse) error { - return fmt.Errorf("not implemented") -} - -func (t *testStateAPI) QueryServerBannedFromRoom(ctx context.Context, req *currentstateAPI.QueryServerBannedFromRoomRequest, res *currentstateAPI.QueryServerBannedFromRoomResponse) error { - return nil -} - type txnFedClient struct { state map[string]gomatrixserverlib.RespState // event_id to response stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response @@ -391,12 +370,10 @@ func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserver return c.getMissingEvents(missing) } -func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { +func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { t := &txnReq{ - context: context.Background(), rsAPI: rsAPI, eduAPI: &testEDUProducer{}, - stateAPI: stateAPI, keys: &test.NopJSONVerifier{}, federation: fedClient, haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), @@ -410,7 +387,7 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, stateAPI currentstat } func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) { - res, err := txn.processTransaction() + res, err := txn.processTransaction(context.Background()) if err != nil { t.Errorf("txn.processTransaction returned an error: %v", err) return @@ -476,11 +453,10 @@ func TestBasicTransaction(t *testing.T) { } }, } - stateAPI := &testStateAPI{} pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } - txn := mustCreateTransaction(rsAPI, stateAPI, &txnFedClient{}, pdus) + txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) mustProcessTransaction(t, txn, nil) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) } @@ -499,11 +475,10 @@ func TestTransactionFailAuthChecks(t *testing.T) { } }, } - stateAPI := &testStateAPI{} pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } - txn := mustCreateTransaction(rsAPI, stateAPI, &txnFedClient{}, pdus) + txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) mustProcessTransaction(t, txn, []string{ // expect the event to have an error testEvents[len(testEvents)-1].EventID(), @@ -558,8 +533,6 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { }, } - stateAPI := &testStateAPI{} - cli := &txnFedClient{ getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) { @@ -579,7 +552,7 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { pdus := []json.RawMessage{ inputEvent.JSON(), } - txn := mustCreateTransaction(rsAPI, stateAPI, cli, pdus) + txn := mustCreateTransaction(rsAPI, cli, pdus) mustProcessTransaction(t, txn, nil) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) } @@ -729,12 +702,10 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { }, } - stateAPI := &testStateAPI{} - pdus := []json.RawMessage{ eventD.JSON(), } - txn := mustCreateTransaction(rsAPI, stateAPI, cli, pdus) + txn := mustCreateTransaction(rsAPI, cli, pdus) mustProcessTransaction(t, txn, nil) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) } diff --git a/federationsender/api/api.go b/federationsender/api/api.go index 1dec0aa86..343d8301b 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -30,7 +30,7 @@ type FederationClientError struct { } func (e *FederationClientError) Error() string { - return fmt.Sprintf("%s - (retry_after=%d, blacklisted=%v)", e.Err, e.RetryAfter, e.Blacklisted) + return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted) } // FederationSenderInternalAPI is used to query information from the federation sender. diff --git a/federationsender/consumers/keychange.go b/federationsender/consumers/keychange.go index 4f206f5fb..28244e923 100644 --- a/federationsender/consumers/keychange.go +++ b/federationsender/consumers/keychange.go @@ -20,12 +20,12 @@ import ( "fmt" "github.com/Shopify/sarama" - stateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/keyserver/api" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -36,7 +36,7 @@ type KeyChangeConsumer struct { db storage.Database queues *queue.OutgoingQueues serverName gomatrixserverlib.ServerName - stateAPI stateapi.CurrentStateInternalAPI + rsAPI roomserverAPI.RoomserverInternalAPI } // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. @@ -45,7 +45,7 @@ func NewKeyChangeConsumer( kafkaConsumer sarama.Consumer, queues *queue.OutgoingQueues, store storage.Database, - stateAPI stateapi.CurrentStateInternalAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, ) *KeyChangeConsumer { c := &KeyChangeConsumer{ consumer: &internal.ContinualConsumer{ @@ -57,7 +57,7 @@ func NewKeyChangeConsumer( queues: queues, db: store, serverName: cfg.Matrix.ServerName, - stateAPI: stateAPI, + rsAPI: rsAPI, } c.consumer.ProcessMessage = c.onMessage @@ -92,8 +92,8 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { return nil } - var queryRes stateapi.QueryRoomsForUserResponse - err = t.stateAPI.QueryRoomsForUser(context.Background(), &stateapi.QueryRoomsForUserRequest{ + var queryRes roomserverAPI.QueryRoomsForUserResponse + err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{ UserID: m.UserID, WantMembership: "join", }, &queryRes) diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go index 5794d40a2..2f1223284 100644 --- a/federationsender/federationsender.go +++ b/federationsender/federationsender.go @@ -16,7 +16,6 @@ package federationsender import ( "github.com/gorilla/mux" - stateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/consumers" "github.com/matrix-org/dendrite/federationsender/internal" @@ -42,7 +41,6 @@ func NewInternalAPI( base *setup.BaseDendrite, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, - stateAPI stateapi.CurrentStateInternalAPI, keyRing *gomatrixserverlib.KeyRing, ) api.FederationSenderInternalAPI { cfg := &base.Cfg.FederationSender @@ -59,7 +57,7 @@ func NewInternalAPI( queues := queue.NewOutgoingQueues( federationSenderDB, cfg.Matrix.ServerName, federation, - rsAPI, stateAPI, stats, + rsAPI, stats, &queue.SigningInfo{ KeyID: cfg.Matrix.KeyID, PrivateKey: cfg.Matrix.PrivateKey, @@ -82,7 +80,7 @@ func NewInternalAPI( logrus.WithError(err).Panic("failed to start typing server consumer") } keyConsumer := consumers.NewKeyChangeConsumer( - &base.Cfg.KeyServer, base.KafkaConsumer, queues, federationSenderDB, stateAPI, + &base.Cfg.KeyServer, base.KafkaConsumer, queues, federationSenderDB, rsAPI, ) if err := keyConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start key server consumer") diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go index 61663be31..2a70f7ed3 100644 --- a/federationsender/internal/api.go +++ b/federationsender/internal/api.go @@ -70,7 +70,10 @@ func failBlacklistableError(err error, stats *statistics.ServerStatistics) (unti if !ok { return stats.Failure() } - if mxerr.Code >= 500 && mxerr.Code < 600 { + if mxerr.Code == 401 { // invalid signature in X-Matrix header + return stats.Failure() + } + if mxerr.Code >= 500 && mxerr.Code < 600 { // internal server errors return stats.Failure() } return diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index 0a2ff824a..5d887beee 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -185,20 +185,21 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer( // Check that the send_join response was valid. joinCtx := perform.JoinContext(r.federation, r.keyRing) - if err = joinCtx.CheckSendJoinResponse( - ctx, event, serverName, respSendJoin, - ); err != nil { + respState, err := joinCtx.CheckSendJoinResponse( + ctx, event, serverName, respMakeJoin, respSendJoin, + ) + if err != nil { return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err) } // If we successfully performed a send_join above then the other // server now thinks we're a part of the room. Send the newly // returned state to the roomserver to update our local view. - respState := respSendJoin.ToRespState() if err = roomserverAPI.SendEventWithState( ctx, r.rsAPI, - &respState, - event.Headered(respMakeJoin.RoomVersion), nil, + respState, + event.Headered(respMakeJoin.RoomVersion), + nil, ); err != nil { return fmt.Errorf("r.producer.SendEventWithState: %w", err) } diff --git a/federationsender/internal/perform/join.go b/federationsender/internal/perform/join.go index 22840c90d..d03a57df6 100644 --- a/federationsender/internal/perform/join.go +++ b/federationsender/internal/perform/join.go @@ -43,7 +43,7 @@ func (r joinContext) CheckSendJoinResponse( event gomatrixserverlib.Event, server gomatrixserverlib.ServerName, respSendJoin gomatrixserverlib.RespSendJoin, -) error { +) (*gomatrixserverlib.RespState, error) { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. retries := map[string][]gomatrixserverlib.Event{} @@ -110,8 +110,9 @@ func (r joinContext) CheckSendJoinResponse( // TODO: Can we expand Check here to return a list of missing auth // events rather than failing one at a time? - if err := respSendJoin.Check(ctx, r.keyRing, event, missingAuth); err != nil { - return fmt.Errorf("respSendJoin: %w", err) + rs, err := respSendJoin.Check(ctx, r.keyRing, event, missingAuth) + if err != nil { + return nil, fmt.Errorf("respSendJoin: %w", err) } - return nil + return rs, nil } diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 6561251df..04cb57e70 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -20,8 +20,8 @@ import ( "encoding/json" "fmt" "sync" + "time" - stateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/federationsender/statistics" "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/roomserver/api" @@ -35,7 +35,6 @@ import ( type OutgoingQueues struct { db storage.Database rsAPI api.RoomserverInternalAPI - stateAPI stateapi.CurrentStateInternalAPI origin gomatrixserverlib.ServerName client *gomatrixserverlib.FederationClient statistics *statistics.Statistics @@ -50,14 +49,12 @@ func NewOutgoingQueues( origin gomatrixserverlib.ServerName, client *gomatrixserverlib.FederationClient, rsAPI api.RoomserverInternalAPI, - stateAPI stateapi.CurrentStateInternalAPI, statistics *statistics.Statistics, signing *SigningInfo, ) *OutgoingQueues { queues := &OutgoingQueues{ db: db, rsAPI: rsAPI, - stateAPI: stateAPI, origin: origin, client: client, statistics: statistics, @@ -65,26 +62,28 @@ func NewOutgoingQueues( queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, } // Look up which servers we have pending items for and then rehydrate those queues. - serverNames := map[gomatrixserverlib.ServerName]struct{}{} - if names, err := db.GetPendingPDUServerNames(context.Background()); err == nil { - for _, serverName := range names { - serverNames[serverName] = struct{}{} + time.AfterFunc(time.Second*5, func() { + serverNames := map[gomatrixserverlib.ServerName]struct{}{} + if names, err := db.GetPendingPDUServerNames(context.Background()); err == nil { + for _, serverName := range names { + serverNames[serverName] = struct{}{} + } + } else { + log.WithError(err).Error("Failed to get PDU server names for destination queue hydration") } - } else { - log.WithError(err).Error("Failed to get PDU server names for destination queue hydration") - } - if names, err := db.GetPendingEDUServerNames(context.Background()); err == nil { - for _, serverName := range names { - serverNames[serverName] = struct{}{} + if names, err := db.GetPendingEDUServerNames(context.Background()); err == nil { + for _, serverName := range names { + serverNames[serverName] = struct{}{} + } + } else { + log.WithError(err).Error("Failed to get EDU server names for destination queue hydration") } - } else { - log.WithError(err).Error("Failed to get EDU server names for destination queue hydration") - } - for serverName := range serverNames { - if !queues.getQueue(serverName).statistics.Blacklisted() { - queues.getQueue(serverName).wakeQueueIfNeeded() + for serverName := range serverNames { + if !queues.getQueue(serverName).statistics.Blacklisted() { + queues.getQueue(serverName).wakeQueueIfNeeded() + } } - } + }) return queues } @@ -141,9 +140,9 @@ func (oqs *OutgoingQueues) SendEvent( // Check if any of the destinations are prohibited by server ACLs. for destination := range destmap { - if stateapi.IsServerBannedFromRoom( + if api.IsServerBannedFromRoom( context.TODO(), - oqs.stateAPI, + oqs.rsAPI, ev.RoomID(), destination, ) { @@ -205,9 +204,9 @@ func (oqs *OutgoingQueues) SendEDU( // ACLs. if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() { for destination := range destmap { - if stateapi.IsServerBannedFromRoom( + if api.IsServerBannedFromRoom( context.TODO(), - oqs.stateAPI, + oqs.rsAPI, result.Str, destination, ) { diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go index f2c4ab098..8de66cc71 100644 --- a/federationsender/storage/shared/storage.go +++ b/federationsender/storage/shared/storage.go @@ -138,7 +138,7 @@ func (d *Database) StoreJSON( var err error _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nid, err = d.FederationSenderQueueJSON.InsertQueueJSON(ctx, txn, js) - return nil + return err }) if err != nil { return nil, fmt.Errorf("d.insertQueueJSON: %w", err) diff --git a/go.mod b/go.mod index 3a9fef9f5..f1cb3c9be 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd - github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750 + github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.2 @@ -29,6 +29,7 @@ require ( github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 github.com/opentracing/opentracing-go v1.2.0 github.com/pkg/errors v0.9.1 + github.com/pressly/goose v2.7.0-rc5+incompatible github.com/prometheus/client_golang v1.7.1 github.com/sirupsen/logrus v1.6.0 github.com/tidwall/gjson v1.6.1 diff --git a/go.sum b/go.sum index 33b4f591a..ac7827d9f 100644 --- a/go.sum +++ b/go.sum @@ -567,8 +567,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750 h1:k5vsLfpylXHOXgN51N0QNbak9i+4bT33Puk/ZJgcdDw= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6 h1:43gla6bLt4opWY1mQkAasF/LUCipZl7x2d44TY0wf40= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= @@ -725,6 +725,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pressly/goose v2.7.0-rc5+incompatible h1:txvo810iG1P/rafOx31LYDlOyikBK8A/8prKP4j066w= +github.com/pressly/goose v2.7.0-rc5+incompatible/go.mod h1:m+QHWCqxR3k8D9l7qfzuC/djtlfzxr34mozWDYEu1z8= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= diff --git a/internal/config/config.go b/internal/config/config.go index 4f8128853..d7470f873 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -51,19 +51,18 @@ type Dendrite struct { // been a breaking change to the config file format. Version int `yaml:"version"` - Global Global `yaml:"global"` - AppServiceAPI AppServiceAPI `yaml:"app_service_api"` - ClientAPI ClientAPI `yaml:"client_api"` - CurrentStateServer CurrentStateServer `yaml:"current_state_server"` - EDUServer EDUServer `yaml:"edu_server"` - FederationAPI FederationAPI `yaml:"federation_api"` - FederationSender FederationSender `yaml:"federation_sender"` - KeyServer KeyServer `yaml:"key_server"` - MediaAPI MediaAPI `yaml:"media_api"` - RoomServer RoomServer `yaml:"room_server"` - ServerKeyAPI ServerKeyAPI `yaml:"server_key_api"` - SyncAPI SyncAPI `yaml:"sync_api"` - UserAPI UserAPI `yaml:"user_api"` + Global Global `yaml:"global"` + AppServiceAPI AppServiceAPI `yaml:"app_service_api"` + ClientAPI ClientAPI `yaml:"client_api"` + EDUServer EDUServer `yaml:"edu_server"` + FederationAPI FederationAPI `yaml:"federation_api"` + FederationSender FederationSender `yaml:"federation_sender"` + KeyServer KeyServer `yaml:"key_server"` + MediaAPI MediaAPI `yaml:"media_api"` + RoomServer RoomServer `yaml:"room_server"` + ServerKeyAPI ServerKeyAPI `yaml:"server_key_api"` + SyncAPI SyncAPI `yaml:"sync_api"` + UserAPI UserAPI `yaml:"user_api"` // The config for tracing the dendrite servers. Tracing struct { @@ -289,7 +288,6 @@ func (c *Dendrite) Defaults() { c.Global.Defaults() c.ClientAPI.Defaults() - c.CurrentStateServer.Defaults() c.EDUServer.Defaults() c.FederationAPI.Defaults() c.FederationSender.Defaults() @@ -309,7 +307,7 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { Verify(configErrs *ConfigErrors, isMonolith bool) } for _, c := range []verifiable{ - &c.Global, &c.ClientAPI, &c.CurrentStateServer, + &c.Global, &c.ClientAPI, &c.EDUServer, &c.FederationAPI, &c.FederationSender, &c.KeyServer, &c.MediaAPI, &c.RoomServer, &c.ServerKeyAPI, &c.SyncAPI, &c.UserAPI, @@ -321,7 +319,6 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { func (c *Dendrite) Wiring() { c.ClientAPI.Matrix = &c.Global - c.CurrentStateServer.Matrix = &c.Global c.EDUServer.Matrix = &c.Global c.FederationAPI.Matrix = &c.Global c.FederationSender.Matrix = &c.Global @@ -512,15 +509,6 @@ func (config *Dendrite) UserAPIURL() string { return string(config.UserAPI.InternalAPI.Connect) } -// CurrentStateAPIURL returns an HTTP URL for where the currentstateserver is listening. -func (config *Dendrite) CurrentStateAPIURL() string { - // Hard code the currentstateserver to talk HTTP for now. - // If we support HTTPS we need to think of a practical way to do certificate validation. - // People setting up servers shouldn't need to get a certificate valid for the public - // internet for an internal API. - return string(config.CurrentStateServer.InternalAPI.Connect) -} - // EDUServerURL returns an HTTP URL for where the EDU server is listening. func (config *Dendrite) EDUServerURL() string { // Hard code the EDU server to talk HTTP for now. diff --git a/internal/config/config_currentstate.go b/internal/config/config_currentstate.go deleted file mode 100644 index c07ebe158..000000000 --- a/internal/config/config_currentstate.go +++ /dev/null @@ -1,24 +0,0 @@ -package config - -type CurrentStateServer struct { - Matrix *Global `yaml:"-"` - - InternalAPI InternalAPIOptions `yaml:"internal_api"` - - // The CurrentState database stores the current state of all rooms. - // It is accessed by the CurrentStateServer. - Database DatabaseOptions `yaml:"database"` -} - -func (c *CurrentStateServer) Defaults() { - c.InternalAPI.Listen = "http://localhost:7782" - c.InternalAPI.Connect = "http://localhost:7782" - c.Database.Defaults() - c.Database.ConnectionString = "file:currentstate.db" -} - -func (c *CurrentStateServer) Verify(configErrs *ConfigErrors, isMonolith bool) { - checkURL(configErrs, "current_state_server.internal_api.listen", string(c.InternalAPI.Listen)) - checkURL(configErrs, "current_state_server.internal_api.connect", string(c.InternalAPI.Connect)) - checkNotEmpty(configErrs, "current_state_server.database.connection_string", string(c.Database.ConnectionString)) -} diff --git a/internal/setup/base.go b/internal/setup/base.go index ec2bbc4cf..ef956dd2a 100644 --- a/internal/setup/base.go +++ b/internal/setup/base.go @@ -21,7 +21,6 @@ import ( "net/url" "time" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/gomatrixserverlib" @@ -38,7 +37,6 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" asinthttp "github.com/matrix-org/dendrite/appservice/inthttp" - currentstateinthttp "github.com/matrix-org/dendrite/currentstateserver/inthttp" eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" eduinthttp "github.com/matrix-org/dendrite/eduserver/inthttp" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" @@ -188,15 +186,6 @@ func (b *BaseDendrite) UserAPIClient() userapi.UserInternalAPI { return userAPI } -// CurrentStateAPIClient returns CurrentStateInternalAPI for hitting the currentstateserver over HTTP. -func (b *BaseDendrite) CurrentStateAPIClient() currentstateAPI.CurrentStateInternalAPI { - stateAPI, err := currentstateinthttp.NewCurrentStateAPIClient(b.Cfg.CurrentStateAPIURL(), b.httpClient) - if err != nil { - logrus.WithError(err).Panic("UserAPIClient failed", b.httpClient) - } - return stateAPI -} - // EDUServerClient returns EDUServerInputAPI for hitting the EDU server over HTTP func (b *BaseDendrite) EDUServerClient() eduServerAPI.EDUServerInputAPI { e, err := eduinthttp.NewEDUServerClient(b.Cfg.EDUServerURL(), b.httpClient) diff --git a/internal/setup/monolith.go b/internal/setup/monolith.go index f79ebae45..2274283e6 100644 --- a/internal/setup/monolith.go +++ b/internal/setup/monolith.go @@ -20,7 +20,6 @@ import ( appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi" "github.com/matrix-org/dendrite/clientapi/api" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" eduServerAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/federationapi" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" @@ -53,7 +52,6 @@ type Monolith struct { RoomserverAPI roomserverAPI.RoomserverInternalAPI ServerKeyAPI serverKeyAPI.ServerKeyInternalAPI UserAPI userapi.UserInternalAPI - StateAPI currentstateAPI.CurrentStateInternalAPI KeyAPI keyAPI.KeyInternalAPI // Optional @@ -65,17 +63,17 @@ func (m *Monolith) AddAllPublicRoutes(csMux, ssMux, keyMux, mediaMux *mux.Router clientapi.AddPublicRoutes( csMux, &m.Config.ClientAPI, m.KafkaProducer, m.AccountDB, m.FedClient, m.RoomserverAPI, - m.EDUInternalAPI, m.AppserviceAPI, m.StateAPI, transactions.New(), + m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), m.FederationSenderAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, ) federationapi.AddPublicRoutes( ssMux, keyMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationSenderAPI, - m.EDUInternalAPI, m.StateAPI, m.KeyAPI, + m.EDUInternalAPI, m.KeyAPI, ) mediaapi.AddPublicRoutes(mediaMux, &m.Config.MediaAPI, m.UserAPI, m.Client) syncapi.AddPublicRoutes( csMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI, - m.KeyAPI, m.StateAPI, m.FedClient, &m.Config.SyncAPI, + m.KeyAPI, m.FedClient, &m.Config.SyncAPI, ) } diff --git a/internal/sqlutil/trace.go b/internal/sqlutil/trace.go index 235296388..23359b500 100644 --- a/internal/sqlutil/trace.go +++ b/internal/sqlutil/trace.go @@ -22,7 +22,10 @@ import ( "io" "os" "regexp" + "runtime" + "strconv" "strings" + "sync" "time" "github.com/matrix-org/dendrite/internal/config" @@ -31,6 +34,7 @@ import ( ) var tracingEnabled = os.Getenv("DENDRITE_TRACE_SQL") == "1" +var goidToWriter sync.Map type traceInterceptor struct { sqlmw.NullInterceptor @@ -40,6 +44,8 @@ func (in *traceInterceptor) StmtQueryContext(ctx context.Context, stmt driver.St startedAt := time.Now() rows, err := stmt.QueryContext(ctx, args) + trackGoID(query) + logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) return rows, err @@ -49,6 +55,8 @@ func (in *traceInterceptor) StmtExecContext(ctx context.Context, stmt driver.Stm startedAt := time.Now() result, err := stmt.ExecContext(ctx, args) + trackGoID(query) + logrus.WithField("duration", time.Since(startedAt)).WithField(logrus.ErrorKey, err).Debug("executed sql query ", query, " args: ", args) return result, err @@ -75,6 +83,19 @@ func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest [ return err } +func trackGoID(query string) { + thisGoID := goid() + if _, ok := goidToWriter.Load(thisGoID); ok { + return // we're on a writer goroutine + } + + q := strings.TrimSpace(query) + if strings.HasPrefix(q, "SELECT") { + return // SELECTs can go on other goroutines + } + logrus.Warnf("unsafe goid: SQL executed not on an ExclusiveWriter: %s", q) +} + // Open opens a database specified by its database driver name and a driver-specific data source name, // usually consisting of at least a database name and connection information. Includes tracing driver // if DENDRITE_TRACE_SQL=1 @@ -119,3 +140,14 @@ func Open(dbProperties *config.DatabaseOptions) (*sql.DB, error) { func init() { registerDrivers() } + +func goid() int { + var buf [64]byte + n := runtime.Stack(buf[:], false) + idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] + id, err := strconv.Atoi(idField) + if err != nil { + panic(fmt.Sprintf("cannot get goroutine id: %v", err)) + } + return id +} diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go index 002bc32cf..91dd77e4d 100644 --- a/internal/sqlutil/writer_exclusive.go +++ b/internal/sqlutil/writer_exclusive.go @@ -60,6 +60,12 @@ func (w *ExclusiveWriter) run() { if !w.running.CAS(false, true) { return } + if tracingEnabled { + gid := goid() + goidToWriter.Store(gid, w) + defer goidToWriter.Delete(gid) + } + defer w.running.Store(false) for task := range w.todo { if task.db != nil && task.txn != nil { diff --git a/internal/test/config.go b/internal/test/config.go index e2106de40..72cd0e6e4 100644 --- a/internal/test/config.go +++ b/internal/test/config.go @@ -87,7 +87,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con // the table names are globally unique. But we might not want to // rely on that in the future. cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(database) - cfg.CurrentStateServer.Database.ConnectionString = config.DataSource(database) cfg.FederationSender.Database.ConnectionString = config.DataSource(database) cfg.KeyServer.Database.ConnectionString = config.DataSource(database) cfg.MediaAPI.Database.ConnectionString = config.DataSource(database) @@ -98,7 +97,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database) cfg.AppServiceAPI.InternalAPI.Listen = assignAddress() - cfg.CurrentStateServer.InternalAPI.Listen = assignAddress() cfg.EDUServer.InternalAPI.Listen = assignAddress() cfg.FederationAPI.InternalAPI.Listen = assignAddress() cfg.FederationSender.InternalAPI.Listen = assignAddress() @@ -110,7 +108,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con cfg.UserAPI.InternalAPI.Listen = assignAddress() cfg.AppServiceAPI.InternalAPI.Connect = cfg.AppServiceAPI.InternalAPI.Listen - cfg.CurrentStateServer.InternalAPI.Connect = cfg.CurrentStateServer.InternalAPI.Listen cfg.EDUServer.InternalAPI.Connect = cfg.EDUServer.InternalAPI.Listen cfg.FederationAPI.InternalAPI.Connect = cfg.FederationAPI.InternalAPI.Listen cfg.FederationSender.InternalAPI.Connect = cfg.FederationSender.InternalAPI.Listen diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 2e5613632..78420db1f 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -48,10 +48,11 @@ func NewInternalAPI( DB: db, } updater := internal.NewDeviceListUpdater(db, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable - err = updater.Start() - if err != nil { - logrus.WithError(err).Panicf("failed to start device list updater") - } + go func() { + if err := updater.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start device list updater") + } + }() return &internal.KeyInternalAPI{ DB: db, ThisServer: cfg.Matrix.ServerName, diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 779d02c03..e5bec8f6f 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -53,7 +53,7 @@ const selectDeviceKeysSQL = "" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index d4915afc1..de757f294 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -41,7 +41,7 @@ func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID str func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) - return nil + return err }) return } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 195429f08..e7ff9976d 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -50,7 +50,7 @@ const selectDeviceKeysSQL = "" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" diff --git a/roomserver/api/output.go b/roomserver/api/output.go index d74a37b36..013ebdc83 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -204,7 +204,7 @@ type OutputRedactedEvent struct { // An OutputNewPeek is written whenever a user starts peeking into a room // using a given device. type OutputNewPeek struct { - RoomID string - UserID string + RoomID string + UserID string DeviceID string } diff --git a/roomserver/api/query.go b/roomserver/api/query.go index d0d0474d8..67a217c82 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -303,6 +303,39 @@ type QueryServerBannedFromRoomResponse struct { Banned bool `json:"banned"` } +// MarshalJSON stringifies the room ID and StateKeyTuple keys so they can be sent over the wire in HTTP API mode. +func (r *QueryBulkStateContentResponse) MarshalJSON() ([]byte, error) { + se := make(map[string]string) + for roomID, tupleToEvent := range r.Rooms { + for tuple, event := range tupleToEvent { + // use 0x1F (unit separator) as the delimiter between room ID/type/state key, + se[fmt.Sprintf("%s\x1F%s\x1F%s", roomID, tuple.EventType, tuple.StateKey)] = event + } + } + return json.Marshal(se) +} + +func (r *QueryBulkStateContentResponse) UnmarshalJSON(data []byte) error { + wireFormat := make(map[string]string) + err := json.Unmarshal(data, &wireFormat) + if err != nil { + return err + } + r.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) + for roomTuple, value := range wireFormat { + fields := strings.Split(roomTuple, "\x1F") + roomID := fields[0] + if r.Rooms[roomID] == nil { + r.Rooms[roomID] = make(map[gomatrixserverlib.StateKeyTuple]string) + } + r.Rooms[roomID][gomatrixserverlib.StateKeyTuple{ + EventType: fields[1], + StateKey: fields[2], + }] = value + } + return nil +} + // MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index b5ef59352..8dc1a170b 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -41,6 +41,7 @@ func NewRoomserverAPI( outputRoomEventTopic string, caches caching.RoomServerCaches, keyRing gomatrixserverlib.JSONVerifier, ) *RoomserverInternalAPI { + serverACLs := acls.NewServerACLs(roomserverDB) a := &RoomserverInternalAPI{ DB: roomserverDB, Cfg: cfg, @@ -50,13 +51,14 @@ func NewRoomserverAPI( Queryer: &query.Queryer{ DB: roomserverDB, Cache: caches, - ServerACLs: acls.NewServerACLs(roomserverDB), + ServerACLs: serverACLs, }, Inputer: &input.Inputer{ DB: roomserverDB, OutputRoomEventTopic: outputRoomEventTopic, Producer: producer, ServerName: cfg.Matrix.ServerName, + ACLs: serverACLs, }, // perform-er structs get initialised when we have a federation sender to use } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 7a44ff42c..51d20ad39 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -22,6 +22,7 @@ import ( "time" "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" @@ -33,6 +34,7 @@ type Inputer struct { DB storage.Database Producer sarama.SyncProducer ServerName gomatrixserverlib.ServerName + ACLs *acls.ServerACLs OutputRoomEventTopic string workers sync.Map // room ID -> *inputWorker @@ -88,6 +90,10 @@ func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) er "send_as_server": updates[i].NewRoomEvent.SendAsServer, "sender": updates[i].NewRoomEvent.Event.Sender(), }) + if updates[i].NewRoomEvent.Event.Type() == "m.room.server_acl" && updates[i].NewRoomEvent.Event.StateKeyEquals("") { + ev := updates[i].NewRoomEvent.Event.Unwrap() + defer r.ACLs.OnServerACLUpdate(&ev) + } } logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic) messages[i] = &sarama.ProducerMessage{ diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 8f61ba1eb..ab6d17b03 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -20,8 +20,8 @@ import ( "fmt" "strings" - "github.com/matrix-org/dendrite/internal/config" fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" @@ -39,7 +39,6 @@ type Peeker struct { Inputer *input.Inputer } - // PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationsender. func (r *Peeker) PerformPeek( ctx context.Context, @@ -152,27 +151,11 @@ func (r *Peeker) performPeekRoomByID( } } - // handle federated peeks + // If the server name in the room ID isn't ours then it's a + // possible candidate for finding the room via federation. Add + // it to the list of servers to try. if domain != r.Cfg.Matrix.ServerName { - // If the server name in the room ID isn't ours then it's a - // possible candidate for finding the room via federation. Add - // it to the list of servers to try. req.ServerNames = append(req.ServerNames, domain) - - // Try peeking by all of the supplied server names. - fedReq := fsAPI.PerformPeekRequest{ - RoomID: req.RoomIDOrAlias, // the room ID to try and peek - ServerNames: req.ServerNames, // the servers to try peeking via - } - fedRes := fsAPI.PerformPeekResponse{} - r.FSAPI.PerformPeek(ctx, &fedReq, &fedRes) - if fedRes.LastError != nil { - return &api.PerformError{ - Code: api.PerformErrRemote, - Msg: fedRes.LastError.Message, - RemoteCode: fedRes.LastError.Code, - } - } } // If this room isn't world_readable, we reject. @@ -180,7 +163,7 @@ func (r *Peeker) performPeekRoomByID( // XXX: we should probably factor out history_visibility checks into a common utility method somewhere // which handles the default value etc. var worldReadable = false - ev, err := r.DB.GetStateEvent(ctx, roomID, "m.room.history_visibility", "") + ev, _ := r.DB.GetStateEvent(ctx, roomID, "m.room.history_visibility", "") if ev != nil { content := map[string]string{} if err = json.Unmarshal(ev.Content(), &content); err != nil { @@ -195,7 +178,7 @@ func (r *Peeker) performPeekRoomByID( if !worldReadable { return "", &api.PerformError{ Code: api.PerformErrorNotAllowed, - Msg: "Room is not world-readable", + Msg: "Room is not world-readable", } } @@ -205,8 +188,8 @@ func (r *Peeker) performPeekRoomByID( { Type: api.OutputTypeNewPeek, NewPeek: &api.OutputNewPeek{ - RoomID: roomID, - UserID: req.UserID, + RoomID: roomID, + UserID: req.UserID, DeviceID: req.DeviceID, }, }, @@ -219,5 +202,5 @@ func (r *Peeker) performPeekRoomByID( // it will have been overwritten with a room ID by performPeekRoomByAlias. // We should now include this in the response so that the CS API can // return the right room ID. - return roomID, nil; + return roomID, nil } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index f76c93166..b34ae7701 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -140,6 +140,9 @@ func (r *Queryer) QueryMembershipForUser( if err != nil { return err } + if info == nil { + return fmt.Errorf("QueryMembershipForUser: unknown room %s", request.RoomID) + } membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) if err != nil { diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index ebfb296d8..5816d4d82 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -63,6 +63,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverPerformPeekPath, + httputil.MakeInternalAPI("performPeek", func(req *http.Request) util.JSONResponse { + var request api.PerformPeekRequest + var response api.PerformPeekResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + r.PerformPeek(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(RoomserverPerformPublishPath, httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse { var request api.PerformPublishRequest @@ -364,7 +375,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle(RoomserverQuerySharedUsersPath, + internalAPIMux.Handle(RoomserverQueryKnownUsersPath, httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { request := api.QueryKnownUsersRequest{} response := api.QueryKnownUsersResponse{} diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index c4119f7ed..be724da62 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -17,9 +17,9 @@ package storage import ( "context" - "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index 037d98fe7..02d6ad079 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -117,7 +118,8 @@ func (s *eventTypeStatements) InsertEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := txn.Stmt(s.insertEventTypeNIDStmt).QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + stmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt) + err := stmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } @@ -125,7 +127,8 @@ func (s *eventTypeStatements) SelectEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := txn.Stmt(s.selectEventTypeNIDStmt).QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + stmt := sqlutil.TxStmt(txn, s.selectEventTypeNIDStmt) + err := stmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) return types.EventTypeNID(eventTypeNID), err } diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index ef1b7891a..ce635210e 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -79,10 +79,10 @@ const selectRoomIDsSQL = "" + "SELECT room_id FROM roomserver_rooms" const bulkSelectRoomIDsSQL = "" + - "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + "SELECT room_id FROM roomserver_rooms WHERE room_nid = ANY($1)" const bulkSelectRoomNIDsSQL = "" + - "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" + "SELECT room_nid FROM roomserver_rooms WHERE room_id = ANY($1)" type roomStatements struct { insertRoomNIDStmt *sql.Stmt diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 329813bfc..834af6069 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -79,89 +79,103 @@ func (u *MembershipUpdater) IsLeave() bool { // SetToInvite implements types.MembershipUpdater func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) - if err != nil { - return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) - } - inserted, err := u.d.InvitesTable.InsertInviteEvent( - u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), - ) - if err != nil { - return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) - } - if u.membership != tables.MembershipStateInvite { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, - ); err != nil { - return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + var inserted bool + err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + if err != nil { + return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } - } - return inserted, nil + inserted, err = u.d.InvitesTable.InsertInviteEvent( + u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) + } + if u.membership != tables.MembershipStateInvite { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, + ); err != nil { + return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + } + } + return nil + }) + return inserted, err } // SetToJoin implements types.MembershipUpdater func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { var inviteEventIDs []string - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) - } - - // If this is a join event update, there is no invite to update - if !isUpdate { - inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) + err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { - return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) + return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } - } - // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, fmt.Errorf("u.d.EventNIDs: %w", err) - } - - if u.membership != tables.MembershipStateJoin || isUpdate { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - tables.MembershipStateJoin, nIDs[eventID], - ); err != nil { - return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) + } } - } - return inviteEventIDs, nil + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return fmt.Errorf("u.d.EventNIDs: %w", err) + } + + if u.membership != tables.MembershipStateJoin || isUpdate { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + tables.MembershipStateJoin, nIDs[eventID], + ); err != nil { + return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + } + } + + return nil + }) + + return inviteEventIDs, err } // SetToLeave implements types.MembershipUpdater func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) - } - inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err) - } + var inviteEventIDs []string - // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, fmt.Errorf("u.d.EventNIDs: %w", err) - } - - if u.membership != tables.MembershipStateLeaveOrBan { - if err = u.d.MembershipTable.UpdateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - tables.MembershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { - return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } - } - return inviteEventIDs, nil + inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err) + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return fmt.Errorf("u.d.EventNIDs: %w", err) + } + + if u.membership != tables.MembershipStateLeaveOrBan { + if err = u.d.MembershipTable.UpdateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + tables.MembershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) + } + } + + return nil + }) + return inviteEventIDs, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index a3b33a4fe..262b0f2f8 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -7,7 +7,6 @@ import ( "fmt" "sort" - csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -360,7 +359,7 @@ func (d *Database) MembershipUpdater( var updater *MembershipUpdater _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { updater, err = NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion) - return nil + return err }) return updater, err } @@ -375,7 +374,7 @@ func (d *Database) GetLatestEventsForUpdate( var updater *LatestEventsUpdater _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) - return nil + return err }) return updater, err } @@ -724,6 +723,10 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s return nil, err } eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err == sql.ErrNoRows { + // No rooms have an event of this type, otherwise we'd have an event type NID + return nil, nil + } if err != nil { return nil, err } @@ -754,7 +757,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s } } - return nil, fmt.Errorf("GetStateEvent: no event type '%s' with key '%s' exists in room %s", evType, stateKey, roomID) + return nil, nil } // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). @@ -774,15 +777,18 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership } stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) } roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) if err != nil { - return nil, err + return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) } roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) if err != nil { - return nil, err + return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) } if len(roomIDs) != len(roomNIDs) { return nil, fmt.Errorf("GetRoomsByMembership: missing room IDs, got %d want %d", len(roomIDs), len(roomNIDs)) @@ -792,8 +798,90 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. -func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]csstables.StrippedEvent, error) { - return nil, fmt.Errorf("not implemented yet") +// nolint:gocyclo +func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { + eventTypes := make([]string, 0, len(tuples)) + for _, tuple := range tuples { + eventTypes = append(eventTypes, tuple.EventType) + } + // we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which + // isn't a failure. + eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) + if err != nil { + return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err) + } + typeNIDSet := make(map[types.EventTypeNID]bool) + for _, nid := range eventTypeNIDMap { + typeNIDSet[nid] = true + } + + allowWildcard := make(map[types.EventTypeNID]bool) + eventStateKeys := make([]string, 0, len(tuples)) + for _, tuple := range tuples { + if allowWildcards && tuple.StateKey == "*" { + allowWildcard[eventTypeNIDMap[tuple.EventType]] = true + continue + } + eventStateKeys = append(eventStateKeys, tuple.StateKey) + + } + + eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) + if err != nil { + return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) + } + stateKeyNIDSet := make(map[types.EventStateKeyNID]bool) + for _, nid := range eventStateKeyNIDMap { + stateKeyNIDSet[nid] = true + } + + var eventNIDs []types.EventNID + eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion) + // TODO: This feels like this is going to be really slow... + for _, roomID := range roomIDs { + roomInfo, err2 := d.RoomInfo(ctx, roomID) + if err2 != nil { + return nil, fmt.Errorf("GetBulkStateContent: failed to load room info for room %s : %w", roomID, err2) + } + // for unknown rooms or rooms which we don't have the current state, skip them. + if roomInfo == nil || roomInfo.IsStub { + continue + } + entries, err2 := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err2 != nil { + return nil, fmt.Errorf("GetBulkStateContent: failed to load state for room %s : %w", roomID, err2) + } + for _, entry := range entries { + if typeNIDSet[entry.EventTypeNID] { + if allowWildcard[entry.EventTypeNID] || stateKeyNIDSet[entry.EventStateKeyNID] { + eventNIDs = append(eventNIDs, entry.EventNID) + eventNIDToVer[entry.EventNID] = roomInfo.RoomVersion + } + } + } + } + + events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + if err != nil { + return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err) + } + result := make([]tables.StrippedEvent, len(events)) + for i := range events { + roomVer := eventNIDToVer[events[i].EventNID] + ev, err := gomatrixserverlib.NewEventFromTrustedJSON(events[i].EventJSON, false, roomVer) + if err != nil { + return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event NID %v : %w", events[i].EventNID, err) + } + hev := ev.Headered(roomVer) + result[i] = tables.StrippedEvent{ + EventType: ev.Type(), + RoomID: ev.RoomID(), + StateKey: *ev.StateKey(), + ContentValue: tables.ExtractContentValue(&hev), + } + } + + return result, nil } // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 0d5ce516d..bb1ab39aa 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -41,7 +41,7 @@ const membershipSchema = ` ` var selectJoinedUsersSetForRoomsSQL = "" + - "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" // Insert a row in to membership table so that it can be locked by the diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index a142f2b1a..adb06212a 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" ) type EventJSONPair struct { @@ -155,3 +156,45 @@ type Redactions interface { // successfully redacted the event JSON. MarkRedactionValidated(ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool) error } + +// StrippedEvent represents a stripped event for returning extracted content values. +type StrippedEvent struct { + RoomID string + EventType string + StateKey string + ContentValue string +} + +// ExtractContentValue from the given state event. For example, given an m.room.name event with: +// content: { name: "Foo" } +// this returns "Foo". +func ExtractContentValue(ev *gomatrixserverlib.HeaderedEvent) string { + content := ev.Content() + key := "" + switch ev.Type() { + case gomatrixserverlib.MRoomCreate: + key = "creator" + case gomatrixserverlib.MRoomCanonicalAlias: + key = "alias" + case gomatrixserverlib.MRoomHistoryVisibility: + key = "history_visibility" + case gomatrixserverlib.MRoomJoinRules: + key = "join_rule" + case gomatrixserverlib.MRoomMember: + key = "membership" + case gomatrixserverlib.MRoomName: + key = "name" + case "m.room.avatar": + key = "url" + case "m.room.topic": + key = "topic" + case "m.room.guest_access": + key = "guest_access" + } + result := gjson.GetBytes(content, key) + if !result.Exists() { + return "" + } + // this returns the empty string if this is not a string type + return result.Str +} diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 93fa822df..200ac85cc 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -20,9 +20,9 @@ import ( "sync" "github.com/Shopify/sarama" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" syncinternal "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" syncapi "github.com/matrix-org/dendrite/syncapi/sync" @@ -36,7 +36,7 @@ type OutputKeyChangeEventConsumer struct { keyChangeConsumer *internal.ContinualConsumer db storage.Database serverName gomatrixserverlib.ServerName // our server name - currentStateAPI currentstateAPI.CurrentStateInternalAPI + rsAPI roomserverAPI.RoomserverInternalAPI keyAPI api.KeyInternalAPI partitionToOffset map[int32]int64 partitionToOffsetMu sync.Mutex @@ -51,7 +51,7 @@ func NewOutputKeyChangeEventConsumer( kafkaConsumer sarama.Consumer, n *syncapi.Notifier, keyAPI api.KeyInternalAPI, - currentStateAPI currentstateAPI.CurrentStateInternalAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, store storage.Database, ) *OutputKeyChangeEventConsumer { @@ -67,7 +67,7 @@ func NewOutputKeyChangeEventConsumer( db: store, serverName: serverName, keyAPI: keyAPI, - currentStateAPI: currentStateAPI, + rsAPI: rsAPI, partitionToOffset: make(map[int32]int64), partitionToOffsetMu: sync.Mutex{}, notifier: n, @@ -105,8 +105,8 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er return err } // work out who we need to notify about the new key - var queryRes currentstateAPI.QuerySharedUsersResponse - err := s.currentStateAPI.QuerySharedUsers(context.Background(), ¤tstateAPI.QuerySharedUsersRequest{ + var queryRes roomserverAPI.QuerySharedUsersResponse + err := s.rsAPI.QuerySharedUsers(context.Background(), &roomserverAPI.QuerySharedUsersRequest{ UserID: output.UserID, }, &queryRes) if err != nil { @@ -115,7 +115,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er } // TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams posUpdate := types.NewStreamToken(0, 0, map[string]*types.LogPosition{ - syncinternal.DeviceListLogName: &types.LogPosition{ + syncinternal.DeviceListLogName: { Offset: msg.Offset, Partition: msg.Partition, }, @@ -129,7 +129,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.HeaderedEvent) { // work out who we are now sharing rooms with which we previously were not and notify them about the joining // users keys: - changed, _, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), []string{ev.RoomID()}, nil) + changed, _, err := syncinternal.TrackChangedUsers(context.Background(), s.rsAPI, *ev.StateKey(), []string{ev.RoomID()}, nil) if err != nil { log.WithError(err).Error("OnJoinEvent: failed to work out changed users") return @@ -142,7 +142,7 @@ func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.Headere func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.HeaderedEvent) { // work out who we are no longer sharing any rooms with and notify them about the leaving user - _, left, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), nil, []string{ev.RoomID()}) + _, left, err := syncinternal.TrackChangedUsers(context.Background(), s.rsAPI, *ev.StateKey(), nil, []string{ev.RoomID()}) if err != nil { log.WithError(err).Error("OnLeaveEvent: failed to work out left users") return diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index b205e1e90..b6ab9bd50 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -17,6 +17,7 @@ package consumers import ( "context" "encoding/json" + "fmt" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" @@ -26,11 +27,13 @@ import ( "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { + cfg *config.SyncAPI rsAPI api.RoomserverInternalAPI rsConsumer *internal.ContinualConsumer db storage.Database @@ -55,6 +58,7 @@ func NewOutputRoomEventConsumer( PartitionStore: store, } s := &OutputRoomEventConsumer{ + cfg: cfg, rsConsumer: &consumer, db: store, notifier: n, @@ -164,6 +168,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write event failure") return nil } + + if pduPos, err = s.notifyJoinedPeeks(ctx, &ev, pduPos); err != nil { + logrus.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) + return err + } + s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) s.notifyKeyChanges(&ev) @@ -186,6 +196,37 @@ func (s *OutputRoomEventConsumer) notifyKeyChanges(ev *gomatrixserverlib.Headere } } +func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, sp types.StreamPosition) (types.StreamPosition, error) { + if ev.Type() != gomatrixserverlib.MRoomMember { + return sp, nil + } + membership, err := ev.Membership() + if err != nil { + return sp, fmt.Errorf("ev.Membership: %w", err) + } + // TODO: check that it's a join and not a profile change (means unmarshalling prev_content) + if membership == gomatrixserverlib.Join { + // check it's a local join + _, domain, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + if err != nil { + return sp, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) + } + if domain != s.cfg.Matrix.ServerName { + return sp, nil + } + + // cancel any peeks for it + peekSP, peekErr := s.db.DeletePeeks(ctx, ev.RoomID(), *ev.StateKey()) + if peekErr != nil { + return sp, fmt.Errorf("s.db.DeletePeeks: %w", peekErr) + } + if peekSP > 0 { + sp = peekSP + } + } + return sp, nil +} + func (s *OutputRoomEventConsumer) onNewInviteEvent( ctx context.Context, msg api.OutputNewInviteEvent, ) error { diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 7d127aa83..090e0c658 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -19,9 +19,9 @@ import ( "strings" "github.com/Shopify/sarama" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -48,7 +48,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, // be already filled in with join/leave information. // nolint:gocyclo func DeviceListCatchup( - ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, + ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, userID string, res *types.Response, from, to types.StreamingToken, ) (hasNew bool, err error) { @@ -56,7 +56,7 @@ func DeviceListCatchup( newlyJoinedRooms := joinedRooms(res, userID) newlyLeftRooms := leftRooms(res) if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 { - changed, left, err := TrackChangedUsers(ctx, stateAPI, userID, newlyJoinedRooms, newlyLeftRooms) + changed, left, err := TrackChangedUsers(ctx, rsAPI, userID, newlyJoinedRooms, newlyLeftRooms) if err != nil { return false, err } @@ -97,7 +97,7 @@ func DeviceListCatchup( } // QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user. var sharedUsersMap map[string]int - sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, stateAPI, userID, queryRes.UserIDs) + sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) util.GetLogger(ctx).Debugf( "QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v", partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs, @@ -142,7 +142,7 @@ func DeviceListCatchup( // TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response. // nolint:gocyclo func TrackChangedUsers( - ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string, + ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string, ) (changed, left []string, err error) { // process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users. @@ -151,16 +151,16 @@ func TrackChangedUsers( // - Get users in newly left room. - QueryCurrentState // - Loop set of users and decrement by 1 for each user in newly left room. // - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync. - var queryRes currentstateAPI.QuerySharedUsersResponse - err = stateAPI.QuerySharedUsers(ctx, ¤tstateAPI.QuerySharedUsersRequest{ + var queryRes roomserverAPI.QuerySharedUsersResponse + err = rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ UserID: userID, IncludeRoomIDs: newlyLeftRooms, }, &queryRes) if err != nil { return nil, nil, err } - var stateRes currentstateAPI.QueryBulkStateContentResponse - err = stateAPI.QueryBulkStateContent(ctx, ¤tstateAPI.QueryBulkStateContentRequest{ + var stateRes roomserverAPI.QueryBulkStateContentResponse + err = rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ RoomIDs: newlyLeftRooms, StateTuples: []gomatrixserverlib.StateKeyTuple{ { @@ -193,14 +193,14 @@ func TrackChangedUsers( // - Loop set of users in newly joined room, do they appear in the set of users prior to joining? // - If yes: then they already shared a room in common, do nothing. // - If no: then they are a brand new user so inform BOTH parties of this via 'changed=[...]' - err = stateAPI.QuerySharedUsers(ctx, ¤tstateAPI.QuerySharedUsersRequest{ + err = rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ UserID: userID, ExcludeRoomIDs: newlyJoinedRooms, }, &queryRes) if err != nil { return nil, left, err } - err = stateAPI.QueryBulkStateContent(ctx, ¤tstateAPI.QueryBulkStateContentRequest{ + err = rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ RoomIDs: newlyJoinedRooms, StateTuples: []gomatrixserverlib.StateKeyTuple{ { @@ -228,11 +228,11 @@ func TrackChangedUsers( } func filterSharedUsers( - ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, usersWithChangedKeys []string, + ctx context.Context, rsAPI roomserverAPI.RoomserverInternalAPI, userID string, usersWithChangedKeys []string, ) (map[string]int, []string) { var result []string - var sharedUsersRes currentstateAPI.QuerySharedUsersResponse - err := stateAPI.QuerySharedUsers(ctx, ¤tstateAPI.QuerySharedUsersRequest{ + var sharedUsersRes roomserverAPI.QuerySharedUsersResponse + err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ UserID: userID, }, &sharedUsersRes) if err != nil { diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index baf60ef05..c25011814 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -7,8 +7,8 @@ import ( "testing" "github.com/Shopify/sarama" - "github.com/matrix-org/dendrite/currentstateserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -49,25 +49,18 @@ func (k *mockKeyAPI) InputDeviceListUpdate(ctx context.Context, req *keyapi.Inpu } -type mockCurrentStateAPI struct { +type mockRoomserverAPI struct { + api.RoomserverInternalAPITrace roomIDToJoinedMembers map[string][]string } -func (s *mockCurrentStateAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { - return nil -} - -func (s *mockCurrentStateAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { - return nil -} - // QueryRoomsForUser retrieves a list of room IDs matching the given query. -func (s *mockCurrentStateAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { +func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { return nil } // QueryBulkStateContent does a bulk query for state event content in the given rooms. -func (s *mockCurrentStateAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { +func (s *mockRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) if req.AllowWildcards && len(req.StateTuples) == 1 && req.StateTuples[0].EventType == gomatrixserverlib.MRoomMember && req.StateTuples[0].StateKey == "*" { for _, roomID := range req.RoomIDs { @@ -84,7 +77,7 @@ func (s *mockCurrentStateAPI) QueryBulkStateContent(ctx context.Context, req *ap } // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. -func (s *mockCurrentStateAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { +func (s *mockRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { roomsToQuery := req.IncludeRoomIDs for roomID, members := range s.roomIDToJoinedMembers { exclude := false @@ -114,10 +107,6 @@ func (s *mockCurrentStateAPI) QuerySharedUsers(ctx context.Context, req *api.Que return nil } -func (t *mockCurrentStateAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error { - return nil -} - type wantCatchup struct { hasNew bool changed []string @@ -185,12 +174,13 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { syncResponse := types.NewResponse() syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) - hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ + rsAPI := &mockRoomserverAPI{ roomIDToJoinedMembers: map[string][]string{ newlyJoinedRoom: {syncingUser, newShareUser}, "!another:room": {syncingUser}, }, - }, syncingUser, syncResponse, emptyToken, newestToken) + } + hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -207,12 +197,13 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { syncResponse := types.NewResponse() syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ + rsAPI := &mockRoomserverAPI{ roomIDToJoinedMembers: map[string][]string{ newlyLeftRoom: {removeUser}, "!another:room": {syncingUser}, }, - }, syncingUser, syncResponse, emptyToken, newestToken) + } + hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -229,12 +220,13 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { syncResponse := types.NewResponse() syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) - hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ + rsAPI := &mockRoomserverAPI{ roomIDToJoinedMembers: map[string][]string{ newlyJoinedRoom: {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser}, }, - }, syncingUser, syncResponse, emptyToken, newestToken) + } + hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -250,12 +242,13 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { syncResponse := types.NewResponse() syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ + rsAPI := &mockRoomserverAPI{ roomIDToJoinedMembers: map[string][]string{ newlyLeftRoom: {existingUser}, "!another:room": {syncingUser, existingUser}, }, - }, syncingUser, syncResponse, emptyToken, newestToken) + } + hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -309,11 +302,12 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { jr.Timeline.Events = roomTimelineEvents syncResponse.Rooms.Join[roomID] = jr - hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ + rsAPI := &mockRoomserverAPI{ roomIDToJoinedMembers: map[string][]string{ roomID: {syncingUser, existingUser}, }, - }, syncingUser, syncResponse, emptyToken, newestToken) + } + hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -334,13 +328,14 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) { syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom}) syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom}) - hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ + rsAPI := &mockRoomserverAPI{ roomIDToJoinedMembers: map[string][]string{ newlyJoinedRoom: {syncingUser, newShareUser, newShareUser2}, newlyLeftRoom: {newlyLeftUser, newlyLeftUser2}, "!another:room": {syncingUser}, }, - }, syncingUser, syncResponse, emptyToken, newestToken) + } + hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -419,12 +414,15 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) { lr.Timeline.Events = roomEvents syncResponse.Rooms.Leave[roomID] = lr - hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{ + rsAPI := &mockRoomserverAPI{ roomIDToJoinedMembers: map[string][]string{ roomID: {newShareUser, newShareUser2}, "!another:room": {syncingUser}, }, - }, syncingUser, syncResponse, emptyToken, newestToken) + } + hasNew, err := DeviceListCatchup( + context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken, + ) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 3559a9e4b..807c7f5e4 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -86,6 +86,9 @@ type Database interface { // AddPeek adds a new peek to our DB for a given room by a given user's device. // Returns an error if there was a problem communicating with the database. AddPeek(ctx context.Context, RoomID, UserID, DeviceID string) (types.StreamPosition, error) + // DeletePeek deletes all peeks for a given room by a given user + // Returns an error if there was a problem communicating with the database. + DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) // SetTypingTimeoutCallback sets a callback function that is called right after // a user is removed from the typing user list due to timeout. SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go new file mode 100644 index 000000000..75eeac986 --- /dev/null +++ b/syncapi/storage/postgres/peeks_table.go @@ -0,0 +1,186 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +const peeksSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_peeks ( + id BIGINT DEFAULT nextval('syncapi_stream_id'), + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + deleted BOOL NOT NULL DEFAULT false, + -- When the peek was created in UNIX epoch ms. + creation_ts BIGINT NOT NULL, + UNIQUE(room_id, user_id, device_id) +); + +CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id); +CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id); +` + +const insertPeekSQL = "" + + "INSERT INTO syncapi_peeks" + + " (room_id, user_id, device_id, creation_ts)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (room_id, user_id, device_id) DO UPDATE SET deleted=false, creation_ts=$4" + + " RETURNING id" + +const deletePeekSQL = "" + + "UPDATE syncapi_peeks SET deleted=true, id=nextval('syncapi_stream_id') WHERE room_id = $1 AND user_id = $2 AND device_id = $3 RETURNING id" + +const deletePeeksSQL = "" + + "UPDATE syncapi_peeks SET deleted=true, id=nextval('syncapi_stream_id') WHERE room_id = $1 AND user_id = $2 RETURNING id" + +// we care about all the peeks which were created in this range, deleted in this range, +// or were created before this range but haven't been deleted yet. +const selectPeeksInRangeSQL = "" + + "SELECT room_id, deleted, (id > $3 AND id <= $4) AS changed FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted) OR (id > $3 AND id <= $4))" + +const selectPeekingDevicesSQL = "" + + "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false" + +const selectMaxPeekIDSQL = "" + + "SELECT MAX(id) FROM syncapi_peeks" + +type peekStatements struct { + db *sql.DB + insertPeekStmt *sql.Stmt + deletePeekStmt *sql.Stmt + deletePeeksStmt *sql.Stmt + selectPeeksInRangeStmt *sql.Stmt + selectPeekingDevicesStmt *sql.Stmt + selectMaxPeekIDStmt *sql.Stmt +} + +func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { + _, err := db.Exec(peeksSchema) + if err != nil { + return nil, err + } + s := &peekStatements{ + db: db, + } + if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { + return nil, err + } + if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { + return nil, err + } + if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { + return nil, err + } + if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { + return nil, err + } + if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { + return nil, err + } + if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *peekStatements) InsertPeek( + ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string, +) (streamPos types.StreamPosition, err error) { + nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + stmt := sqlutil.TxStmt(txn, s.insertPeekStmt) + err = stmt.QueryRowContext(ctx, roomID, userID, deviceID, nowMilli).Scan(&streamPos) + return +} + +func (s *peekStatements) DeletePeek( + ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string, +) (streamPos types.StreamPosition, err error) { + stmt := sqlutil.TxStmt(txn, s.deletePeekStmt) + err = stmt.QueryRowContext(ctx, roomID, userID, deviceID).Scan(&streamPos) + return +} + +func (s *peekStatements) DeletePeeks( + ctx context.Context, txn *sql.Tx, roomID, userID string, +) (streamPos types.StreamPosition, err error) { + stmt := sqlutil.TxStmt(txn, s.deletePeeksStmt) + err = stmt.QueryRowContext(ctx, roomID, userID).Scan(&streamPos) + return +} + +func (s *peekStatements) SelectPeeksInRange( + ctx context.Context, txn *sql.Tx, userID, deviceID string, r types.Range, +) (peeks []types.Peek, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High()) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeksInRange: rows.close() failed") + + for rows.Next() { + peek := types.Peek{} + var changed bool + if err = rows.Scan(&peek.RoomID, &peek.Deleted, &changed); err != nil { + return + } + peek.New = changed && !peek.Deleted + peeks = append(peeks, peek) + } + + return peeks, rows.Err() +} + +func (s *peekStatements) SelectPeekingDevices( + ctx context.Context, +) (peekingDevices map[string][]types.PeekingDevice, err error) { + rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPeekingDevices: rows.close() failed") + + result := make(map[string][]types.PeekingDevice) + for rows.Next() { + var roomID, userID, deviceID string + if err := rows.Scan(&roomID, &userID, &deviceID); err != nil { + return nil, err + } + devices := result[roomID] + devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID}) + result[roomID] = devices + } + return result, nil +} + +func (s *peekStatements) SelectMaxPeekID( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + var nullableID sql.NullInt64 + stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt) + err = stmt.QueryRowContext(ctx).Scan(&nullableID) + if nullableID.Valid { + id = nullableID.Int64 + } + return +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index e7f2c9440..7f19722ae 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -62,6 +62,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if err != nil { return nil, err } + peeks, err := NewPostgresPeeksTable(d.db) + if err != nil { + return nil, err + } topology, err := NewPostgresTopologyTable(d.db) if err != nil { return nil, err @@ -82,6 +86,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e DB: d.db, Writer: d.writer, Invites: invites, + Peeks: peeks, AccountData: accountData, OutputEvents: events, Topology: topology, diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index f11b00b33..695556bf6 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" ) // Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite @@ -146,7 +146,7 @@ func (d *Database) AddInviteEvent( ) (sp types.StreamPosition, err error) { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) - return nil + return err }) return } @@ -158,7 +158,7 @@ func (d *Database) RetireInviteEvent( ) (sp types.StreamPosition, err error) { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID) - return nil + return err }) return } @@ -169,13 +169,30 @@ func (d *Database) RetireInviteEvent( func (d *Database) AddPeek( ctx context.Context, roomID, userID, deviceID string, ) (sp types.StreamPosition, err error) { - _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - sp, err = d.Peeks.InsertPeek(ctx, nil, roomID, userID, deviceID) - return nil + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID) + return err }) return } +// DeletePeeks tracks the fact that a user has stopped peeking from all devices +// If the peeks was successfully deleted this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. +func (d *Database) DeletePeeks( + ctx context.Context, roomID, userID string, +) (sp types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID) + return err + }) + if err == sql.ErrNoRows { + sp = 0 + err = nil + } + return +} + // GetAccountDataInRange returns all account data for a given user inserted or // updated between two given positions // Returns a map following the format data[roomID] = []dataTypes @@ -214,7 +231,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea "transaction_id", in[i].TransactionID.TransactionID, ) if err != nil { - logrus.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "event_id": out[i].EventID(), }).WithError(err).Warnf("Failed to add transaction ID to event") } @@ -407,7 +424,6 @@ func (d *Database) EventPositionInTopology( func (d *Database) syncPositionTx( ctx context.Context, txn *sql.Tx, ) (sp types.StreamingToken, err error) { - maxEventID, err := d.OutputEvents.SelectMaxEventID(ctx, txn) if err != nil { return sp, err @@ -426,6 +442,13 @@ func (d *Database) syncPositionTx( if maxInviteID > maxEventID { maxEventID = maxInviteID } + maxPeekID, err := d.Peeks.SelectMaxPeekID(ctx, txn) + if err != nil { + return sp, err + } + if maxPeekID > maxEventID { + maxEventID = maxPeekID + } sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), nil) return } @@ -602,7 +625,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda return err } if len(redactedEvents) == 0 { - logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction") + log.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction") return nil } eventToRedact := redactedEvents[0].Unwrap() @@ -675,19 +698,21 @@ func (d *Database) getResponseWithPDUsForCompleteSync( } // Add peeked rooms. - peeks, err := d.Peeks.SelectPeeks(ctx, txn, userID, deviceID) + peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, deviceID, r) if err != nil { return } for _, peek := range peeks { - var jr *types.JoinResponse - jr, err = d.getJoinResponseForCompleteSync( - ctx, txn, peek.RoomID, r, &stateFilter, numRecentEventsPerRoom, - ) - if err != nil { - return + if !peek.Deleted { + var jr *types.JoinResponse + jr, err = d.getJoinResponseForCompleteSync( + ctx, txn, peek.RoomID, r, &stateFilter, numRecentEventsPerRoom, + ) + if err != nil { + return + } + res.Rooms.Peek[peek.RoomID] = *jr } - res.Rooms.Peek[peek.RoomID] = *jr } if err = d.addInvitesToResponse(ctx, txn, userID, r, res); err != nil { @@ -698,9 +723,9 @@ func (d *Database) getResponseWithPDUsForCompleteSync( return //res, toPos, joinedRoomIDs, err } -func (d* Database) getJoinResponseForCompleteSync( +func (d *Database) getJoinResponseForCompleteSync( ctx context.Context, txn *sql.Tx, - roomID string, + roomID string, r types.Range, stateFilter *gomatrixserverlib.StateFilter, numRecentEventsPerRoom int, @@ -798,8 +823,10 @@ func (d *Database) addInvitesToResponse( res.Rooms.Invite[roomID] = *ir } for roomID := range retiredInvites { - lr := types.NewLeaveResponse() - res.Rooms.Leave[roomID] = *lr + if _, ok := res.Rooms.Join[roomID]; !ok { + lr := types.NewLeaveResponse() + res.Rooms.Leave[roomID] = *lr + } } return nil } @@ -985,6 +1012,7 @@ func (d *Database) fetchMissingStateEvents( // exclusive of oldPos, inclusive of newPos, for the rooms in which // the user has new membership events. // A list of joined room IDs is also returned in case the caller needs it. +// nolint:gocyclo func (d *Database) getStateDeltas( ctx context.Context, device *userapi.Device, txn *sql.Tx, r types.Range, userID string, @@ -1012,16 +1040,13 @@ func (d *Database) getStateDeltas( // find out which rooms this user is peeking, if any. // We do this before joins so any peeks get overwritten - peeks, err := d.Peeks.SelectPeeks(ctx, txn, userID, device.ID) + peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) if err != nil { return nil, nil, err } // add peek blocks - peeking := make(map[string]bool) - newPeeks := false for _, peek := range peeks { - peeking[peek.RoomID] = true if peek.New { // send full room state down instead of a delta var s []types.StreamEvent @@ -1030,22 +1055,13 @@ func (d *Database) getStateDeltas( return nil, nil, err } state[peek.RoomID] = s - newPeeks = true } - - deltas = append(deltas, stateDelta{ - membership: gomatrixserverlib.Peek, - stateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), - roomID: peek.RoomID, - }) - } - - if newPeeks { - err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { - return d.Peeks.MarkPeeksAsOld(ctx, txn, userID, device.ID) - }) - if err != nil { - return nil, nil, err + if !peek.Deleted { + deltas = append(deltas, stateDelta{ + membership: gomatrixserverlib.Peek, + stateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), + roomID: peek.RoomID, + }) } } @@ -1112,35 +1128,33 @@ func (d *Database) getStateDeltas( // requests with full_state=true. // Fetches full state for all joined rooms and uses selectStateInRange to get // updates for other rooms. +// nolint:gocyclo func (d *Database) getStateDeltasForFullStateSync( ctx context.Context, device *userapi.Device, txn *sql.Tx, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter, ) ([]stateDelta, []string, error) { - joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return nil, nil, err - } - - peeks, err := d.Peeks.SelectPeeks(ctx, txn, userID, device.ID) - if err != nil { - return nil, nil, err - } - // Use a reasonable initial capacity - deltas := make([]stateDelta, 0, len(joinedRoomIDs) + len(peeks)) + deltas := make(map[string]stateDelta) - // Add full states for all joined rooms - for _, joinedRoomID := range joinedRoomIDs { - s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) - if stateErr != nil { - return nil, nil, stateErr + peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) + if err != nil { + return nil, nil, err + } + + // Add full states for all peeking rooms + for _, peek := range peeks { + if !peek.Deleted { + s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) + if stateErr != nil { + return nil, nil, stateErr + } + deltas[peek.RoomID] = stateDelta{ + membership: gomatrixserverlib.Peek, + stateEvents: d.StreamEventsToEvents(device, s), + roomID: peek.RoomID, + } } - deltas = append(deltas, stateDelta{ - membership: gomatrixserverlib.Join, - stateEvents: d.StreamEventsToEvents(device, s), - roomID: joinedRoomID, - }) } // Add full states for all peeking rooms @@ -1183,12 +1197,12 @@ func (d *Database) getStateDeltasForFullStateSync( for _, ev := range stateStreamEvents { if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. - deltas = append(deltas, stateDelta{ + deltas[roomID] = stateDelta{ membership: membership, membershipPos: ev.StreamPosition, stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), roomID: roomID, - }) + } } break @@ -1196,7 +1210,33 @@ func (d *Database) getStateDeltasForFullStateSync( } } - return deltas, joinedRoomIDs, nil + joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) + if err != nil { + return nil, nil, err + } + + // Add full states for all joined rooms + for _, joinedRoomID := range joinedRoomIDs { + s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) + if stateErr != nil { + return nil, nil, stateErr + } + deltas[joinedRoomID] = stateDelta{ + membership: gomatrixserverlib.Join, + stateEvents: d.StreamEventsToEvents(device, s), + roomID: joinedRoomID, + } + } + + // Create a response array. + result := make([]stateDelta, len(deltas)) + i := 0 + for _, delta := range deltas { + result[i] = delta + i++ + } + + return result, joinedRoomIDs, nil } func (d *Database) currentStateStreamEventsForRoom( diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index c013a6350..d755e28c2 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -27,13 +27,14 @@ import ( const peeksSchema = ` CREATE TABLE IF NOT EXISTS syncapi_peeks ( - id INTEGER PRIMARY KEY, + id INTEGER, room_id TEXT NOT NULL, user_id TEXT NOT NULL, device_id TEXT NOT NULL, - new BOOL NOT NULL DEFAULT true, + deleted BOOL NOT NULL DEFAULT false, -- When the peek was created in UNIX epoch ms. - creation_ts INTEGER NOT NULL + creation_ts INTEGER NOT NULL, + UNIQUE(room_id, user_id, device_id) ); CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id); @@ -41,34 +42,37 @@ CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks( ` const insertPeekSQL = "" + - "INSERT INTO syncapi_peeks" + - " (room_id, user_id, device_id, creation_ts)" + - " VALUES ($1, $2, $3, $4)" + "INSERT OR REPLACE INTO syncapi_peeks" + + " (id, room_id, user_id, device_id, creation_ts, deleted)" + + " VALUES ($1, $2, $3, $4, $5, false)" const deletePeekSQL = "" + - "DELETE FROM syncapi_peeks WHERE room_id = $1 AND user_id = $2 and device_id = $3" + "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4" const deletePeeksSQL = "" + - "DELETE FROM syncapi_peeks WHERE room_id = $1 AND user_id = $2" + "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3" -const selectPeeksSQL = "" + - "SELECT room_id, new FROM syncapi_peeks WHERE user_id = $1 and device_id = $2" +// we care about all the peeks which were created in this range, deleted in this range, +// or were created before this range but haven't been deleted yet. +// BEWARE: sqlite chokes on out of order substitution strings. +const selectPeeksInRangeSQL = "" + + "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))" const selectPeekingDevicesSQL = "" + - "SELECT room_id, user_id, device_id FROM syncapi_peeks" + "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false" -const markPeeksAsOldSQL = "" + - "UPDATE syncapi_peeks SET new=false WHERE user_id = $1 and device_id = $2" +const selectMaxPeekIDSQL = "" + + "SELECT MAX(id) FROM syncapi_peeks" type peekStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - insertPeekStmt *sql.Stmt - deletePeekStmt *sql.Stmt - deletePeeksStmt *sql.Stmt - selectPeeksStmt *sql.Stmt - selectPeekingDevicesStmt *sql.Stmt - markPeeksAsOldStmt *sql.Stmt + db *sql.DB + streamIDStatements *streamIDStatements + insertPeekStmt *sql.Stmt + deletePeekStmt *sql.Stmt + deletePeeksStmt *sql.Stmt + selectPeeksInRangeStmt *sql.Stmt + selectPeekingDevicesStmt *sql.Stmt + selectMaxPeekIDStmt *sql.Stmt } func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { @@ -77,7 +81,7 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks return nil, err } s := &peekStatements{ - db: db, + db: db, streamIDStatements: streamID, } if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { @@ -89,13 +93,13 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { return nil, err } - if s.selectPeeksStmt, err = db.Prepare(selectPeeksSQL); err != nil { + if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { return nil, err } if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { return nil, err } - if s.markPeeksAsOldStmt, err = db.Prepare(markPeeksAsOldSQL); err != nil { + if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { return nil, err } return s, nil @@ -109,7 +113,7 @@ func (s *peekStatements) InsertPeek( return } nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - _, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, roomID, userID, deviceID, nowMilli) + _, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID, nowMilli) return } @@ -120,48 +124,53 @@ func (s *peekStatements) DeletePeek( if err != nil { return } - _, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, roomID, userID, deviceID) + _, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID) return } func (s *peekStatements) DeletePeeks( ctx context.Context, txn *sql.Tx, roomID, userID string, -) (streamPos types.StreamPosition, err error) { - streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) +) (types.StreamPosition, error) { + streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn) if err != nil { - return + return 0, err } - _, err = sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, roomID, userID) - return + result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID) + if err != nil { + return 0, err + } + numAffected, err := result.RowsAffected() + if err != nil { + return 0, err + } + if numAffected == 0 { + return 0, sql.ErrNoRows + } + return streamPos, nil } -func (s *peekStatements) SelectPeeks( - ctx context.Context, txn *sql.Tx, userID, deviceID string, +func (s *peekStatements) SelectPeeksInRange( + ctx context.Context, txn *sql.Tx, userID, deviceID string, r types.Range, ) (peeks []types.Peek, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectPeeksStmt).QueryContext(ctx, userID, deviceID) + rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High()) if err != nil { return } - defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeks: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeksInRange: rows.close() failed") for rows.Next() { peek := types.Peek{} - if err = rows.Scan(&peek.RoomID, &peek.New); err != nil { + var id types.StreamPosition + if err = rows.Scan(&id, &peek.RoomID, &peek.Deleted); err != nil { return } + peek.New = (id > r.Low() && id <= r.High()) && !peek.Deleted peeks = append(peeks, peek) } return peeks, rows.Err() } -func (s *peekStatements) MarkPeeksAsOld ( - ctx context.Context, txn *sql.Tx, userID, deviceID string, -) (err error) { - _, err = sqlutil.TxStmt(txn, s.markPeeksAsOldStmt).ExecContext(ctx, userID, deviceID) - return -} - func (s *peekStatements) SelectPeekingDevices( ctx context.Context, ) (peekingDevices map[string][]types.PeekingDevice, err error) { @@ -178,8 +187,20 @@ func (s *peekStatements) SelectPeekingDevices( return nil, err } devices := result[roomID] - devices = append(devices, types.PeekingDevice{userID, deviceID}) + devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID}) result[roomID] = devices } return result, nil } + +func (s *peekStatements) SelectMaxPeekID( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + var nullableID sql.NullInt64 + stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt) + err = stmt.QueryRowContext(ctx).Scan(&nullableID) + if nullableID.Valid { + id = nullableID.Int64 + } + return +} diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 2cbab26e4..86d83ec98 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -99,7 +99,7 @@ func (d *SyncServerDatasource) prepare() (err error) { DB: d.db, Writer: d.writer, Invites: invites, - Peeks: peeks, + Peeks: peeks, AccountData: accountData, OutputEvents: events, BackwardExtremities: bwExtrem, diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 1f8663b9d..631746c6e 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -43,9 +43,9 @@ type Peeks interface { InsertPeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) DeletePeeks(ctx context.Context, txn *sql.Tx, roomID, userID string) (streamPos types.StreamPosition, err error) - SelectPeeks(ctxt context.Context, txn *sql.Tx, userID, deviceID string) (peeks []types.Peek, err error) + SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) SelectPeekingDevices(ctxt context.Context) (peekingDevices map[string][]types.PeekingDevice, err error) - MarkPeeksAsOld(ctxt context.Context, txn *sql.Tx, userID, deviceID string) (err error) + SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error) } type Events interface { diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index fc8482037..fcac3f16c 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -246,12 +246,10 @@ func (n *Notifier) wakeupUsers(userIDs []string, peekingDevices []types.PeekingD } } - if peekingDevices != nil { - for _, peekingDevice := range peekingDevices { - // TODO: don't bother waking up for devices whose users we already woke up - if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.DeviceID, false); stream != nil { - stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream - } + for _, peekingDevice := range peekingDevices { + // TODO: don't bother waking up for devices whose users we already woke up + if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.DeviceID, false); stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } } } @@ -332,22 +330,22 @@ func (n *Notifier) joinedUsers(roomID string) (userIDs []string) { return n.roomIDToJoinedUsers[roomID].values() } - // Not thread-safe: must be called on the OnNewEvent goroutine only func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) { if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) } - n.roomIDToPeekingDevices[roomID].add(types.PeekingDevice{userID, deviceID}) + n.roomIDToPeekingDevices[roomID].add(types.PeekingDevice{UserID: userID, DeviceID: deviceID}) } // Not thread-safe: must be called on the OnNewEvent goroutine only +// nolint:unused func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) { if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) } // XXX: is this going to work as a key? - n.roomIDToPeekingDevices[roomID].remove(types.PeekingDevice{userID, deviceID}) + n.roomIDToPeekingDevices[roomID].remove(types.PeekingDevice{UserID: userID, DeviceID: deviceID}) } // Not thread-safe: must be called on the OnNewEvent goroutine only @@ -358,8 +356,6 @@ func (n *Notifier) PeekingDevices(roomID string) (peekingDevices []types.Peeking return n.roomIDToPeekingDevices[roomID].values() } - - // removeEmptyUserStreams iterates through the user stream map and removes any // that have been empty for a certain amount of time. This is a crude way of // ensuring that the userStreams map doesn't grow forver. @@ -414,6 +410,7 @@ func (s peekingDeviceSet) add(d types.PeekingDevice) { s[d] = true } +// nolint:unused func (s peekingDeviceSet) remove(d types.PeekingDevice) { delete(s, d) } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 357df240e..aaaf94917 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -23,8 +23,8 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" @@ -40,15 +40,15 @@ type RequestPool struct { userAPI userapi.UserInternalAPI notifier *Notifier keyAPI keyapi.KeyInternalAPI - stateAPI currentstateAPI.CurrentStateInternalAPI + rsAPI roomserverAPI.RoomserverInternalAPI } // NewRequestPool makes a new RequestPool func NewRequestPool( db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI, - stateAPI currentstateAPI.CurrentStateInternalAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, ) *RequestPool { - return &RequestPool{db, userAPI, n, keyAPI, stateAPI} + return &RequestPool{db, userAPI, n, keyAPI, rsAPI} } // OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be @@ -265,7 +265,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea func (rp *RequestPool) appendDeviceLists( data *types.Response, userID string, since, to types.StreamingToken, ) (*types.Response, error) { - _, err := internal.DeviceListCatchup(context.Background(), rp.keyAPI, rp.stateAPI, userID, data, since, to) + _, err := internal.DeviceListCatchup(context.Background(), rp.keyAPI, rp.rsAPI, userID, data, since, to) if err != nil { return nil, fmt.Errorf("internal.DeviceListCatchup: %w", err) } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 0f4ea828b..c77c55412 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -21,7 +21,6 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" - currentstateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/config" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" @@ -42,7 +41,6 @@ func AddPublicRoutes( userAPI userapi.UserInternalAPI, rsAPI api.RoomserverInternalAPI, keyAPI keyapi.KeyInternalAPI, - currentStateAPI currentstateapi.CurrentStateInternalAPI, federation *gomatrixserverlib.FederationClient, cfg *config.SyncAPI, ) { @@ -62,11 +60,11 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start notifier") } - requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI) + requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, rsAPI) keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( cfg.Matrix.ServerName, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputKeyChangeEvent)), - consumer, notifier, keyAPI, currentStateAPI, syncDB, + consumer, notifier, keyAPI, rsAPI, syncDB, ) if err = keyChangeConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start key change consumer") diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 23d88626c..2499976e5 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -511,11 +511,12 @@ type SendToDeviceEvent struct { } type PeekingDevice struct { - UserID string - DeviceID string + UserID string + DeviceID string } type Peek struct { - RoomID string - New bool -} \ No newline at end of file + RoomID string + New bool + Deleted bool +} diff --git a/sytest-whitelist b/sytest-whitelist index 7ce59fef6..0adeaee6f 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -460,3 +460,14 @@ If user leaves room, remote user changes device and rejoins we see update in /sy Can search public room list Can get remote public room list Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list +After changing password, can't log in with old password +After changing password, can log in with new password +After changing password, existing session still works +After changing password, different sessions can optionally be kept +After changing password, a different session no longer works by default +Local users can peek into world_readable rooms by room ID +We can't peek into rooms with shared history_visibility +We can't peek into rooms with invited history_visibility +We can't peek into rooms with joined history_visibility +Local users can peek by room alias +Peeked rooms only turn up in the sync for the device who peeked them \ No newline at end of file diff --git a/userapi/api/api.go b/userapi/api/api.go index e6d05c335..3baaa1002 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -26,6 +26,7 @@ import ( type UserInternalAPI interface { InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error + PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct { UserID string // The devices to delete. An empty slice means delete all devices. DeviceIDs []string + // The requesting device ID to exclude from deletion. This is needed + // so that a password change doesn't cause that client to be logged + // out. Only specify when DeviceIDs is empty. + ExceptDeviceID string } type PerformDeviceDeletionResponse struct { @@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct { Account *Account } +// PerformAccountCreationRequest is the request for PerformAccountCreation +type PerformPasswordUpdateRequest struct { + Localpart string // Required: The localpart for this account. + Password string // Required: The new password to set. +} + +// PerformAccountCreationResponse is the response for PerformAccountCreation +type PerformPasswordUpdateResponse struct { + PasswordUpdated bool + Account *Account +} + // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index b97f148e0..461c548cc 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P res.Account = acc return nil } + +func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { + if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + return err + } + res.PasswordUpdated = true + return nil +} + func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, @@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) + devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 5f4df0eb1..6dcaf7568 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -30,6 +30,7 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" + PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" @@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformPasswordUpdate( + ctx context.Context, + request *api.PerformPasswordUpdateRequest, + response *api.PerformPasswordUpdateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate") + defer span.Finish() + + apiURL := h.apiURL + PerformPasswordUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) PerformDeviceCreation( ctx context.Context, request *api.PerformDeviceCreationRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 47d68ff21..d26746788 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformAccountCreationPath, + httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformPasswordUpdateRequest{} + response := api.PerformPasswordUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceCreationPath, httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceCreationRequest{} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 86b91e603..49446f11f 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -28,6 +28,7 @@ type Database interface { internal.PartitionStorer GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetPassword(ctx context.Context, localpart string, plaintextPassword string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error // CreateAccount makes a new account with the given login name and password, and creates an empty profile diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 931ffb73d..8c8d32cf8 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; const insertAccountSQL = "" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + const selectAccountByLocalpartSQL = "" + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" @@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT nextval('numeric_username_seq')" -// TODO: Update password - type accountsStatements struct { insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt @@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } + if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { + return + } if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { return } @@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount( }, nil } +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + func (s *accountsStatements) selectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index b36264dd9..8b9ebef80 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -112,6 +112,17 @@ func (d *Database) SetDisplayName( return d.profiles.setDisplayName(ctx, localpart, displayName) } +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := hashPassword(plaintextPassword) + if err != nil { + return err + } + return d.accounts.updatePassword(ctx, localpart, hash) +} + // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 798a6de96..fbbdc3370 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts ( const insertAccountSQL = "" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + const selectAccountByLocalpartSQL = "" + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" @@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT COUNT(localpart) FROM account_accounts" -// TODO: Update password - type accountsStatements struct { db *sql.DB insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt @@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } + if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { + return + } if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { return } @@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount( }, nil } +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + func (s *accountsStatements) selectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 46106297b..4b66304c2 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -126,6 +126,18 @@ func (d *Database) SetDisplayName( }) } +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := hashPassword(plaintextPassword) + if err != nil { + return err + } + err = d.accounts.updatePassword(ctx, localpart, hash) + return err +} + // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 9b4261c9d..168c84c5c 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -36,5 +36,5 @@ type Database interface { RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) } diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 282466f8d..c06af7549 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -79,7 +79,7 @@ const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" @@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. func (s *devicesStatements) deleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) + _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) return err } @@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) if err != nil { return devices, err diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 04dae9864..c5bd5b6cf 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -175,14 +175,14 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart string, + ctx context.Context, localpart, exceptDeviceID string, ) (devices []api.Device, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) if err != nil { return err } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { return err } return nil diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index ecf43524a..c75e19825 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -68,7 +68,7 @@ const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" @@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices( } func (s *devicesStatements) deleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) + _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) return err } @@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID( } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) if err != nil { return devices, err diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index f775fb664..7c6645dd6 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -179,14 +179,14 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart string, + ctx context.Context, localpart, exceptDeviceID string, ) (devices []api.Device, err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) if err != nil { return err } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { return err } return nil