diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index 2ea09f636..59535c7b9 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -111,13 +111,12 @@ func (m *DendriteMonolith) Start() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := ygg.CreateFederationClient(base) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( @@ -153,7 +152,6 @@ func (m *DendriteMonolith) Start() { monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: ygg.CreateClient(base), FedClient: federation, KeyRing: keyRing, diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 6489c22e4..32c5234b1 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -1,5 +1,5 @@ FROM golang:1.13-stretch as build -RUN apt-get update && apt-get install sqlite3 +RUN apt-get update && apt-get install -y sqlite3 WORKDIR /build # Utilise Docker caching when downloading dependencies, this stops us needlessly diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 1a4307c18..fe6789fcb 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -30,7 +30,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) @@ -39,7 +38,6 @@ func AddPublicRoutes( router *mux.Router, cfg *config.ClientAPI, producer sarama.SyncProducer, - deviceDB devices.Database, accountsDB accounts.Database, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, @@ -59,7 +57,7 @@ func AddPublicRoutes( routing.Setup( router, cfg, eduInputAPI, rsAPI, asAPI, - accountsDB, deviceDB, userAPI, federation, + accountsDB, userAPI, federation, syncProducer, transactionsCache, fsAPI, stateAPI, keyAPI, extRoomsProvider, ) } diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index d0b3bdbe5..56886d57f 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -15,7 +15,6 @@ package routing import ( - "database/sql" "encoding/json" "io/ioutil" "net/http" @@ -23,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -50,57 +49,56 @@ type devicesDeleteJSON struct { // GetDeviceByID handles /devices/{deviceID} func GetDeviceByID( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var queryRes userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{ + UserID: device.UserID, + }, &queryRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") return jsonerror.InternalServerError() } - - ctx := req.Context() - dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) - if err == sql.ErrNoRows { + var targetDevice *userapi.Device + for _, device := range queryRes.Devices { + if device.ID == deviceID { + targetDevice = &device + break + } + } + if targetDevice == nil { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Unknown device"), } - } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed") - return jsonerror.InternalServerError() } return util.JSONResponse{ Code: http.StatusOK, JSON: deviceJSON{ - DeviceID: dev.ID, - DisplayName: dev.DisplayName, + DeviceID: targetDevice.ID, + DisplayName: targetDevice.DisplayName, }, } } // GetDevicesByLocalpart handles /devices func GetDevicesByLocalpart( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var queryRes userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{ + UserID: device.UserID, + }, &queryRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - ctx := req.Context() - deviceList, err := deviceDB.GetDevicesByLocalpart(ctx, localpart) - - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalpart failed") + util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") return jsonerror.InternalServerError() } res := devicesJSON{} - for _, dev := range deviceList { + for _, dev := range queryRes.Devices { res.Devices = append(res.Devices, deviceJSON{ DeviceID: dev.ID, DisplayName: dev.DisplayName, diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 3c7421bb2..c10113574 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -41,6 +41,17 @@ func JoinRoomByIDOrAlias( } joinRes := roomserverAPI.PerformJoinResponse{} + // Check to see if any ?server_name= query parameters were + // given in the request. + if serverNames, ok := req.URL.Query()["server_name"]; ok { + for _, serverName := range serverNames { + joinReq.ServerNames = append( + joinReq.ServerNames, + gomatrixserverlib.ServerName(serverName), + ) + } + } + // If content was provided in the request then include that // in the request. It'll get used as a part of the membership // event content. diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index 3ce47169e..cb300e9ff 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -19,23 +19,21 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) // Logout handles POST /logout func Logout( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var performRes userapi.PerformDeviceDeletionResponse + err := userAPI.PerformDeviceDeletion(req.Context(), &userapi.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: []string{device.ID}, + }, &performRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevice failed") + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") return jsonerror.InternalServerError() } @@ -47,16 +45,15 @@ func Logout( // LogoutAll handles POST /logout/all func LogoutAll( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var performRes userapi.PerformDeviceDeletionResponse + err := userAPI.PerformDeviceDeletion(req.Context(), &userapi.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: nil, + }, &performRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveAllDevices failed") + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 69c76cf93..3d4b99935 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -35,7 +35,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -52,7 +51,6 @@ func Setup( rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, accountDB accounts.Database, - deviceDB devices.Database, userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, syncProducer *producers.SyncAPIProducer, @@ -333,13 +331,13 @@ func Setup( r0mux.Handle("/logout", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return Logout(req, deviceDB, device) + return Logout(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/logout/all", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return LogoutAll(req, deviceDB, device) + return LogoutAll(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -643,7 +641,7 @@ func Setup( r0mux.Handle("/devices", httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return GetDevicesByLocalpart(req, deviceDB, device) + return GetDevicesByLocalpart(req, userAPI, device) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -653,7 +651,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetDeviceByID(req, deviceDB, device, vars["deviceID"]) + return GetDeviceByID(req, userAPI, device, vars["deviceID"]) }), ).Methods(http.MethodGet, http.MethodOptions) diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go index 4961b34e0..35dbb7745 100644 --- a/cmd/dendrite-client-api-server/main.go +++ b/cmd/dendrite-client-api-server/main.go @@ -27,7 +27,6 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := base.CreateFederationClient() asQuery := base.AppserviceHTTPClient() @@ -39,7 +38,7 @@ func main() { keyAPI := base.KeyServerHTTPClient() clientapi.AddPublicRoutes( - base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, deviceDB, accountDB, federation, + base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, accountDB, federation, rsAPI, eduInputAPI, asQuery, stateAPI, transactions.New(), fsAPI, userAPI, keyAPI, nil, ) diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 4934fe5f0..e2d23e895 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -140,10 +140,9 @@ func main() { defer base.Base.Close() // nolint: errcheck accountDB := base.Base.CreateAccountsDB() - deviceDB := base.Base.CreateDeviceDB() federation := createFederationClient(base) keyAPI := keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation, base.Base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) keyAPI.SetUserAPI(userAPI) serverKeyAPI := serverkeyapi.NewInternalAPI( @@ -175,7 +174,6 @@ func main() { monolith := setup.Monolith{ Config: base.Base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: createClient(base), FedClient: federation, KeyRing: keyRing, diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index e8745b3ec..26999ebed 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -94,14 +94,13 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := ygg.CreateFederationClient(base) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) keyAPI.SetUserAPI(userAPI) rsComponent := roomserver.NewInternalAPI( @@ -136,7 +135,6 @@ func main() { monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: ygg.CreateClient(base), FedClient: federation, KeyRing: keyRing, diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index e2d2de48c..815117463 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -69,7 +69,6 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := base.CreateFederationClient() serverKeyAPI := serverkeyapi.NewInternalAPI( @@ -110,7 +109,7 @@ func main() { rsImpl.SetFederationSenderAPI(fsAPI) keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, fsAPI, base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) keyAPI.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( @@ -130,7 +129,6 @@ func main() { monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: gomatrixserverlib.NewClient(cfg.FederationSender.DisableTLSValidation), FedClient: federation, KeyRing: keyRing, diff --git a/cmd/dendrite-user-api-server/main.go b/cmd/dendrite-user-api-server/main.go index c21525e60..c8e2e2a37 100644 --- a/cmd/dendrite-user-api-server/main.go +++ b/cmd/dendrite-user-api-server/main.go @@ -25,9 +25,8 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index ceb252d82..c95eb3fce 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -191,10 +191,9 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := createFederationClient(cfg, node) keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) keyAPI.SetUserAPI(userAPI) fetcher := &libp2pKeyFetcher{} @@ -218,7 +217,6 @@ func main() { monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: createClient(node), FedClient: federation, KeyRing: &keyRing, diff --git a/dendrite-config.yaml b/dendrite-config.yaml index e48035b58..23f142a83 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -302,6 +302,8 @@ user_api: conn_max_lifetime: -1 # Configuration for Opentracing. +# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on +# how this works and how to set it up. tracing: enabled: false jaeger: diff --git a/docs/tracing/jaeger.png b/docs/tracing/jaeger.png new file mode 100644 index 000000000..8b1e61feb Binary files /dev/null and b/docs/tracing/jaeger.png differ diff --git a/docs/opentracing.md b/docs/tracing/opentracing.md similarity index 100% rename from docs/opentracing.md rename to docs/tracing/opentracing.md diff --git a/docs/tracing/setup.md b/docs/tracing/setup.md new file mode 100644 index 000000000..2cab4d1ef --- /dev/null +++ b/docs/tracing/setup.md @@ -0,0 +1,49 @@ +## OpenTracing Setup + +![Trace when sending an event into a room](/docs/tracing/jaeger.png) + +Dendrite uses [Jaeger](https://www.jaegertracing.io/) for tracing between microservices. +Tracing shows the nesting of logical spans which provides visibility on how the microservices interact. +This document explains how to set up Jaeger locally on a single machine. + +### Set up the Jaeger backend + +The [easiest way](https://www.jaegertracing.io/docs/1.18/getting-started/) is to use the all-in-one Docker image: +``` +$ docker run -d --name jaeger \ + -e COLLECTOR_ZIPKIN_HTTP_PORT=9411 \ + -p 5775:5775/udp \ + -p 6831:6831/udp \ + -p 6832:6832/udp \ + -p 5778:5778 \ + -p 16686:16686 \ + -p 14268:14268 \ + -p 14250:14250 \ + -p 9411:9411 \ + jaegertracing/all-in-one:1.18 +``` + +### Configuring Dendrite to talk to Jaeger + +Modify your config to look like: (this will send every single span to Jaeger which will be slow on large instances, but for local testing it's fine) +``` +tracing: + enabled: true + jaeger: + serviceName: "dendrite" + disabled: false + rpc_metrics: true + tags: [] + sampler: + type: const + param: 1 +``` + +then run the monolith server with `--api true` to use polylith components which do tracing spans: +``` +$ ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml --api true +``` + +### Checking traces + +Visit http://localhost:16686 to see traces under `DendriteMonolith`. diff --git a/federationsender/storage/shared/storage_edus.go b/federationsender/storage/shared/storage_edus.go index 75a6dd51f..529b46aa9 100644 --- a/federationsender/storage/shared/storage_edus.go +++ b/federationsender/storage/shared/storage_edus.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -33,7 +32,7 @@ func (d *Database) AssociateEDUWithDestination( serverName gomatrixserverlib.ServerName, receipt *Receipt, ) error { - return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for _, nid := range receipt.nids { if err := d.FederationSenderQueueEDUs.InsertQueueEDU( ctx, // context @@ -60,7 +59,7 @@ func (d *Database) GetNextTransactionEDUs( receipt *Receipt, err error, ) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationSenderQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) if err != nil { return fmt.Errorf("SelectQueueEDUs: %w", err) @@ -99,7 +98,7 @@ func (d *Database) CleanEDUs( return errors.New("expected receipt") } - return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if err := d.FederationSenderQueueEDUs.DeleteQueueEDUs(ctx, txn, serverName, receipt.nids); err != nil { return err } diff --git a/federationsender/storage/shared/storage_pdus.go b/federationsender/storage/shared/storage_pdus.go index 005889561..9ab0b094c 100644 --- a/federationsender/storage/shared/storage_pdus.go +++ b/federationsender/storage/shared/storage_pdus.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -34,7 +33,7 @@ func (d *Database) AssociatePDUWithDestination( serverName gomatrixserverlib.ServerName, receipt *Receipt, ) error { - return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { for _, nid := range receipt.nids { if err := d.FederationSenderQueuePDUs.InsertQueuePDU( ctx, // context @@ -62,7 +61,12 @@ func (d *Database) GetNextTransactionPDUs( receipt *Receipt, err error, ) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + // Strictly speaking this doesn't need to be using the writer + // since we are only performing selects, but since we don't have + // a guarantee of transactional isolation, it's actually useful + // to know in SQLite mode that nothing else is trying to modify + // the database. + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { transactionID, err = d.FederationSenderQueuePDUs.SelectQueuePDUNextTransactionID(ctx, txn, serverName) if err != nil { return fmt.Errorf("SelectQueuePDUNextTransactionID: %w", err) @@ -111,7 +115,7 @@ func (d *Database) CleanPDUs( return errors.New("expected receipt") } - return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if err := d.FederationSenderQueuePDUs.DeleteQueuePDUs(ctx, txn, serverName, receipt.nids); err != nil { return err } diff --git a/internal/setup/base.go b/internal/setup/base.go index fc4083115..7bf06e748 100644 --- a/internal/setup/base.go +++ b/internal/setup/base.go @@ -32,7 +32,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/Shopify/sarama" "github.com/gorilla/mux" @@ -237,17 +236,6 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { return f } -// CreateDeviceDB creates a new instance of the device database. Should only be -// called once per component. -func (b *BaseDendrite) CreateDeviceDB() devices.Database { - db, err := devices.NewDatabase(&b.Cfg.UserAPI.DeviceDatabase, b.Cfg.Global.ServerName) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to devices db") - } - - return db -} - // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. func (b *BaseDendrite) CreateAccountsDB() accounts.Database { diff --git a/internal/setup/monolith.go b/internal/setup/monolith.go index 5e6c8fcfc..f79ebae45 100644 --- a/internal/setup/monolith.go +++ b/internal/setup/monolith.go @@ -33,7 +33,6 @@ import ( "github.com/matrix-org/dendrite/syncapi" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) @@ -41,7 +40,6 @@ import ( // all components of Dendrite, for use in monolith mode. type Monolith struct { Config *config.Dendrite - DeviceDB devices.Database AccountDB accounts.Database KeyRing *gomatrixserverlib.KeyRing Client *gomatrixserverlib.Client @@ -65,7 +63,7 @@ type Monolith struct { // AddAllPublicRoutes attaches all public paths to the given router func (m *Monolith) AddAllPublicRoutes(csMux, ssMux, keyMux, mediaMux *mux.Router) { clientapi.AddPublicRoutes( - csMux, &m.Config.ClientAPI, m.KafkaProducer, m.DeviceDB, m.AccountDB, + csMux, &m.Config.ClientAPI, m.KafkaProducer, m.AccountDB, m.FedClient, m.RoomserverAPI, m.EDUInternalAPI, m.AppserviceAPI, m.StateAPI, transactions.New(), m.FederationSenderAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 3fbf31f1e..4d1b1107c 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -341,8 +341,12 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam if err != nil { logger.WithError(err).WithField("user_id", userID).Error("failed to query device keys for user") fcerr, ok := err.(*fedsenderapi.FederationClientError) - if ok && fcerr.RetryAfter > 0 { - waitTime = fcerr.RetryAfter + if ok { + if fcerr.RetryAfter > 0 { + waitTime = fcerr.RetryAfter + } else if fcerr.Blacklisted { + waitTime = time.Hour * 8 + } } hasFailures = true continue diff --git a/mediaapi/fileutils/fileutils.go b/mediaapi/fileutils/fileutils.go index 39687b9d4..92ce64001 100644 --- a/mediaapi/fileutils/fileutils.go +++ b/mediaapi/fileutils/fileutils.go @@ -16,6 +16,7 @@ package fileutils import ( "bufio" + "context" "crypto/sha256" "encoding/base64" "fmt" @@ -27,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -104,18 +106,31 @@ func RemoveDir(dir types.Path, logger *log.Entry) { } } -// WriteTempFile writes to a new temporary file -func WriteTempFile(reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) { +// WriteTempFile writes to a new temporary file. +// The file is deleted if there was an error while writing. +func WriteTempFile( + ctx context.Context, reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path, +) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) { size = -1 - + logger := util.GetLogger(ctx) tmpFileWriter, tmpFile, tmpDir, err := createTempFileWriter(absBasePath) if err != nil { return } - defer (func() { err = tmpFile.Close() })() + defer func() { + err2 := tmpFile.Close() + if err == nil { + err = err2 + } + }() - // The amount of data read is limited to maxFileSizeBytes. At this point, if there is more data it will be truncated. - limitedReader := io.LimitReader(reqReader, int64(maxFileSizeBytes)) + // If the max_file_size_bytes configuration option is set to a positive + // number then limit the upload to that size. Otherwise, just read the + // whole file. + limitedReader := reqReader + if maxFileSizeBytes > 0 { + limitedReader = io.LimitReader(reqReader, int64(maxFileSizeBytes)) + } // Hash the file data. The hash will be returned. The hash is useful as a // method of deduplicating files to save storage, as well as a way to conduct // integrity checks on the file data in the repository. @@ -123,11 +138,13 @@ func WriteTempFile(reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, a teeReader := io.TeeReader(limitedReader, hasher) bytesWritten, err := io.Copy(tmpFileWriter, teeReader) if err != nil && err != io.EOF { + RemoveDir(tmpDir, logger) return } err = tmpFileWriter.Flush() if err != nil { + RemoveDir(tmpDir, logger) return } diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index e61fa82b0..be0419048 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -728,12 +728,11 @@ func (r *downloadRequest) fetchRemoteFile( // method of deduplicating files to save storage, as well as a way to conduct // integrity checks on the file data in the repository. // Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK. - hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(resp.Body, maxFileSizeBytes, absBasePath) + hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, resp.Body, maxFileSizeBytes, absBasePath) if err != nil { r.Logger.WithError(err).WithFields(log.Fields{ "MaxFileSizeBytes": maxFileSizeBytes, }).Warn("Error while downloading file from remote server") - fileutils.RemoveDir(tmpDir, r.Logger) return "", false, errors.New("file could not be downloaded from remote server") } diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 75f195cd3..4b6d2fd75 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -53,8 +53,8 @@ func Setup( uploadHandler := httputil.MakeAuthAPI( "upload", userAPI, - func(req *http.Request, _ *userapi.Device) util.JSONResponse { - return Upload(req, cfg, db, activeThumbnailGeneration) + func(req *http.Request, dev *userapi.Device) util.JSONResponse { + return Upload(req, cfg, dev, db, activeThumbnailGeneration) }, ) diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index b74d17323..1724ad255 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" "github.com/matrix-org/dendrite/mediaapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -55,8 +56,8 @@ type uploadResponse struct { // This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large. // Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory. // TODO: We should time out requests if they have not received any data within a configured timeout period. -func Upload(req *http.Request, cfg *config.MediaAPI, db storage.Database, activeThumbnailGeneration *types.ActiveThumbnailGeneration) util.JSONResponse { - r, resErr := parseAndValidateRequest(req, cfg) +func Upload(req *http.Request, cfg *config.MediaAPI, dev *userapi.Device, db storage.Database, activeThumbnailGeneration *types.ActiveThumbnailGeneration) util.JSONResponse { + r, resErr := parseAndValidateRequest(req, cfg, dev) if resErr != nil { return *resErr } @@ -76,13 +77,14 @@ func Upload(req *http.Request, cfg *config.MediaAPI, db storage.Database, active // parseAndValidateRequest parses the incoming upload request to validate and extract // all the metadata about the media being uploaded. // Returns either an uploadRequest or an error formatted as a util.JSONResponse -func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI) (*uploadRequest, *util.JSONResponse) { +func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI, dev *userapi.Device) (*uploadRequest, *util.JSONResponse) { r := &uploadRequest{ MediaMetadata: &types.MediaMetadata{ Origin: cfg.Matrix.ServerName, FileSizeBytes: types.FileSizeBytes(req.ContentLength), ContentType: types.ContentType(req.Header.Get("Content-Type")), UploadName: types.Filename(url.PathEscape(req.FormValue("filename"))), + UserID: types.MatrixUserID(dev.UserID), }, Logger: util.GetLogger(req.Context()).WithField("Origin", cfg.Matrix.ServerName), } @@ -138,12 +140,18 @@ func (r *uploadRequest) doUpload( // method of deduplicating files to save storage, as well as a way to conduct // integrity checks on the file data in the repository. // Data is truncated to maxFileSizeBytes. Content-Length was reported as 0 < Content-Length <= maxFileSizeBytes so this is OK. - hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(reqReader, *cfg.MaxFileSizeBytes, cfg.AbsBasePath) + // + // TODO: This has a bad API shape where you either need to call: + // fileutils.RemoveDir(tmpDir, r.Logger) + // or call: + // r.storeFileAndMetadata(ctx, tmpDir, ...) + // before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of + // nested function to guarantee either storage or cleanup. + hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, *cfg.MaxFileSizeBytes, cfg.AbsBasePath) if err != nil { r.Logger.WithError(err).WithFields(log.Fields{ "MaxFileSizeBytes": *cfg.MaxFileSizeBytes, }).Warn("Error while transferring file") - fileutils.RemoveDir(tmpDir, r.Logger) return &util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown("Failed to upload"), @@ -157,11 +165,14 @@ func (r *uploadRequest) doUpload( ctx, hash, r.MediaMetadata.Origin, ) if err != nil { + fileutils.RemoveDir(tmpDir, r.Logger) r.Logger.WithError(err).Error("Error querying the database by hash.") resErr := jsonerror.InternalServerError() return &resErr } if existingMetadata != nil { + // The file already exists, delete the uploaded temporary file. + defer fileutils.RemoveDir(tmpDir, r.Logger) // The file already exists. Make a new media ID up for it. mediaID, merr := r.generateMediaID(ctx, db) if merr != nil { @@ -181,15 +192,13 @@ func (r *uploadRequest) doUpload( Base64Hash: hash, UserID: r.MediaMetadata.UserID, } - - // Clean up the uploaded temporary file. - fileutils.RemoveDir(tmpDir, r.Logger) } else { // The file doesn't exist. Update the request metadata. r.MediaMetadata.FileSizeBytes = bytesWritten r.MediaMetadata.Base64Hash = hash r.MediaMetadata.MediaID, err = r.generateMediaID(ctx, db) if err != nil { + fileutils.RemoveDir(tmpDir, r.Logger) r.Logger.WithError(err).Error("Failed to generate media ID for new upload") resErr := jsonerror.InternalServerError() return &resErr diff --git a/roomserver/internal/perform_join.go b/roomserver/internal/perform_join.go index b92a6663b..3b9b1b3ca 100644 --- a/roomserver/internal/perform_join.go +++ b/roomserver/internal/perform_join.go @@ -161,8 +161,9 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( // where we might think we know about a room in the following // section but don't know the latest state as all of our users // have left. + serverInRoom, _ := r.isServerCurrentlyInRoom(ctx, r.ServerName, req.RoomIDOrAlias) isInvitePending, inviteSender, _, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID) - if err == nil && isInvitePending { + if err == nil && isInvitePending && !serverInRoom { // Check if there's an invite pending. _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) if ierr != nil { diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 6f7efefb7..257abf080 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -468,7 +468,7 @@ func (d *Database) addPDUDeltaToResponse( wantFullState bool, res *types.Response, ) (joinedRoomIDs []string, err error) { - txn, err := d.DB.BeginTx(ctx, &txReadOnlySnapshot) + txn, err := d.DB.BeginTx(context.TODO(), &txReadOnlySnapshot) // TODO: check mattn/go-sqlite3#764 if err != nil { return nil, err } @@ -594,20 +594,23 @@ func (d *Database) IncrementalSync( joinedRoomIDs, err = d.addPDUDeltaToResponse( ctx, device, r, numRecentEventsPerRoom, wantFullState, res, ) + if err != nil { + return nil, fmt.Errorf("d.addPDUDeltaToResponse: %w", err) + } } else { joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership( ctx, nil, device.UserID, gomatrixserverlib.Join, ) - } - if err != nil { - return nil, err + if err != nil { + return nil, fmt.Errorf("d.CurrentRoomState.SelectRoomIDsWithMembership: %w", err) + } } err = d.addEDUDeltaToResponse( fromPos, toPos, joinedRoomIDs, res, ) if err != nil { - return nil, err + return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err) } return res, nil @@ -649,7 +652,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync( // a consistent view of the database throughout. This includes extracting the sync position. // This does have the unfortunate side-effect that all the matrixy logic resides in this function, // but it's better to not hide the fact that this is being done in a transaction. - txn, err := d.DB.BeginTx(ctx, &txReadOnlySnapshot) + txn, err := d.DB.BeginTx(context.TODO(), &txReadOnlySnapshot) // TODO: check mattn/go-sqlite3#764 if err != nil { return } @@ -736,7 +739,7 @@ func (d *Database) CompleteSync( ctx, res, device.UserID, numRecentEventsPerRoom, ) if err != nil { - return nil, err + return nil, fmt.Errorf("d.getResponseWithPDUsForCompleteSync: %w", err) } // Use a zero value SyncPosition for fromPos so all EDU states are added. @@ -744,7 +747,7 @@ func (d *Database) CompleteSync( types.NewStreamToken(0, 0, nil), toPos, joinedRoomIDs, res, ) if err != nil { - return nil, err + return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err) } return res, nil @@ -770,7 +773,7 @@ func (d *Database) addInvitesToResponse( ctx, txn, userID, r, ) if err != nil { - return err + return fmt.Errorf("d.Invites.SelectInviteEventsInRange: %w", err) } for roomID, inviteEvent := range invites { ir := types.NewInviteResponse(inviteEvent) diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 12c597bbe..357df240e 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -18,6 +18,7 @@ package sync import ( "context" + "fmt" "net/http" "time" @@ -204,31 +205,34 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea // See if we have any new tasks to do for the send-to-device messaging. events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, since) if err != nil { - return nil, err + return nil, fmt.Errorf("rp.db.SendToDeviceUpdatesForSync: %w", err) } // TODO: handle ignored users if req.since == nil { res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit) + if err != nil { + return res, fmt.Errorf("rp.db.CompleteSync: %w", err) + } } else { res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) - } - if err != nil { - return res, err + if err != nil { + return res, fmt.Errorf("rp.db.IncrementalSync: %w", err) + } } accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) if err != nil { - return res, err + return res, fmt.Errorf("rp.appendAccountData: %w", err) } res, err = rp.appendDeviceLists(res, req.device.UserID, since, latestPos) if err != nil { - return res, err + return res, fmt.Errorf("rp.appendDeviceLists: %w", err) } err = internal.DeviceOTKCounts(req.ctx, rp.keyAPI, req.device.UserID, req.device.ID, res) if err != nil { - return res, err + return res, fmt.Errorf("internal.DeviceOTKCounts: %w", err) } // Before we return the sync response, make sure that we take action on @@ -238,7 +242,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea // Handle the updates and deletions in the database. err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since) if err != nil { - return res, err + return res, fmt.Errorf("rp.db.CleanSendToDeviceUpdates: %w", err) } } if len(events) > 0 { @@ -263,7 +267,7 @@ func (rp *RequestPool) appendDeviceLists( ) (*types.Response, error) { _, err := internal.DeviceListCatchup(context.Background(), rp.keyAPI, rp.stateAPI, userID, data, since, to) if err != nil { - return nil, err + return nil, fmt.Errorf("internal.DeviceListCatchup: %w", err) } return data, nil @@ -329,7 +333,7 @@ func (rp *RequestPool) appendAccountData( req.ctx, userID, r, accountDataFilter, ) if err != nil { - return nil, err + return nil, fmt.Errorf("rp.db.GetAccountDataInRange: %w", err) } if len(dataTypes) == 0 { diff --git a/userapi/api/api.go b/userapi/api/api.go index 84338dbf2..e6d05c335 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -61,7 +61,7 @@ type PerformDeviceUpdateResponse struct { type PerformDeviceDeletionRequest struct { UserID string - // The devices to delete + // The devices to delete. An empty slice means delete all devices. DeviceIDs []string } @@ -192,8 +192,7 @@ type Device struct { // The unique ID of the session identified by the access token. // Can be used as a secure substitution in places where data needs to be // associated with access tokens. - SessionID int64 - // TODO: display name, last used timestamp, keys, etc + SessionID int64 DisplayName string } diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 05cecc1bc..b97f148e0 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -123,12 +123,21 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe if domain != a.ServerName { return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) } - err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + deletedDeviceIDs := req.DeviceIDs + if len(req.DeviceIDs) == 0 { + var devices []api.Device + devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) + for _, d := range devices { + deletedDeviceIDs = append(deletedDeviceIDs, d.ID) + } + } else { + err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + } if err != nil { return err } // create empty device keys and upload them to delete what was once there and trigger device list changes - return a.deviceListUpdate(req.UserID, req.DeviceIDs) + return a.deviceListUpdate(req.UserID, deletedDeviceIDs) } func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 3c9ec934a..9b4261c9d 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -35,5 +35,6 @@ type Database interface { UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error - RemoveAllDevices(ctx context.Context, localpart string) error + // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. + RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) } diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 03bf7c722..282466f8d 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -251,11 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, ) ([]api.Device, error) { devices := []api.Device{} - - rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) if err != nil { return devices, err diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 4a7c7f975..04dae9864 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -176,11 +176,16 @@ func (d *Database) RemoveDevices( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { +) (devices []api.Device, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + if err != nil { + return err + } if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } return nil }) + return } diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index c93e8b772..ecf43524a 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -231,11 +231,10 @@ func (s *devicesStatements) selectDeviceByID( } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, ) ([]api.Device, error) { devices := []api.Device{} - - rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) if err != nil { return devices, err diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 4f426c6ed..f775fb664 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -180,11 +180,16 @@ func (d *Database) RemoveDevices( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { +) (devices []api.Device, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + if err != nil { + return err + } if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } return nil }) + return } diff --git a/userapi/userapi.go b/userapi/userapi.go index c4ab90bac..132491429 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions @@ -34,13 +34,19 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI(accountDB accounts.Database, deviceDB devices.Database, - serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI) api.UserInternalAPI { +func NewInternalAPI( + accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, +) api.UserInternalAPI { + + deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to device db") + } return &internal.UserInternalAPI{ AccountDB: accountDB, DeviceDB: deviceDB, - ServerName: serverName, + ServerName: cfg.Matrix.ServerName, AppServices: appServices, KeyAPI: keyAPI, } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 548148f27..3fc97d06a 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -15,7 +15,6 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) @@ -23,27 +22,31 @@ const ( serverName = gomatrixserverlib.ServerName("example.com") ) -func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database, devices.Database) { +func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ ConnectionString: "file::memory:", }, serverName) if err != nil { t.Fatalf("failed to create account DB: %s", err) } - deviceDB, err := devices.NewDatabase(&config.DatabaseOptions{ - ConnectionString: "file::memory:", - }, serverName) - if err != nil { - t.Fatalf("failed to create device DB: %s", err) + cfg := &config.UserAPI{ + DeviceDatabase: config.DatabaseOptions{ + ConnectionString: "file::memory:", + MaxOpenConnections: 1, + MaxIdleConnections: 1, + }, + Matrix: &config.Global{ + ServerName: serverName, + }, } - return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil, nil), accountDB, deviceDB + return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - userAPI, accountDB, _ := MustMakeInternalAPI(t) + userAPI, accountDB := MustMakeInternalAPI(t) _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") if err != nil { t.Fatalf("failed to make account: %s", err)