diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 679b132df..6c8a07b5c 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -134,7 +134,7 @@ func TestAppserviceInternalAPI(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) // Create required internal APIs natsInstance := jetstream.NATSInstance{} - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil) asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI) diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 84cdba7ce..4d2bf67b2 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -43,7 +43,7 @@ func TestAdminResetPassword(t *testing.T) { }) routers := httputil.NewRouters() - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) // Needed for changing the password/login @@ -161,7 +161,7 @@ func TestPurgeRoom(t *testing.T) { defer close() routers := httputil.NewRouters() - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go index b98b8558f..fd58ff5d5 100644 --- a/clientapi/routing/joinroom_test.go +++ b/clientapi/routing/joinroom_test.go @@ -30,7 +30,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) { cfg, processCtx, close := testrig.CreateConfig(t, dbType) defer close() - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go index 3c0f2afe4..b27730767 100644 --- a/clientapi/routing/login_test.go +++ b/clientapi/routing/login_test.go @@ -42,7 +42,7 @@ func TestLogin(t *testing.T) { SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, }) - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) routers := httputil.NewRouters() caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index c866f0dfb..46cd8b2b9 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -413,7 +413,7 @@ func Test_register(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) @@ -587,7 +587,7 @@ func TestRegisterUserWithDisplayName(t *testing.T) { caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) natsInstance := jetstream.NATSInstance{} - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) deviceName, deviceID := "deviceName", "deviceID" @@ -628,7 +628,7 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) { sharedSecret := "dendritetest" cfg.ClientAPI.RegistrationSharedSecret = sharedSecret - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil) diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 321f9b758..74863c07c 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -20,7 +20,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) connStr, dbClose := test.PrepareDBConnectionString(t, dbType) ctx := context.Background() - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewDatabase(ctx, cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, caches, func(server gomatrixserverlib.ServerName) bool { return server == "localhost" }) diff --git a/internal/sqlutil/connection_manager.go b/internal/sqlutil/connection_manager.go index 8176a7592..86b62a9b2 100644 --- a/internal/sqlutil/connection_manager.go +++ b/internal/sqlutil/connection_manager.go @@ -19,17 +19,20 @@ import ( "fmt" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" ) type Connections struct { - db *sql.DB - writer Writer - globalConfig config.DatabaseOptions + db *sql.DB + writer Writer + globalConfig config.DatabaseOptions + processContext *process.ProcessContext } -func NewConnectionManager(globalConfig config.DatabaseOptions) Connections { +func NewConnectionManager(processCtx *process.ProcessContext, globalConfig config.DatabaseOptions) Connections { return Connections{ - globalConfig: globalConfig, + globalConfig: globalConfig, + processContext: processCtx, } } @@ -50,6 +53,15 @@ func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB, return nil, nil, err } c.writer = writer + go func() { + if c.processContext == nil { + return + } + c.processContext.ComponentStarted() + <-c.processContext.WaitForShutdown() + _ = c.db.Close() + c.processContext.ComponentFinished() + }() return c.db, c.writer, nil } if c.db != nil && c.writer != nil { diff --git a/internal/sqlutil/connection_manager_test.go b/internal/sqlutil/connection_manager_test.go index 1c26041dc..723f2967e 100644 --- a/internal/sqlutil/connection_manager_test.go +++ b/internal/sqlutil/connection_manager_test.go @@ -13,7 +13,7 @@ func TestConnectionManager(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { conStr, close := test.PrepareDBConnectionString(t, dbType) t.Cleanup(close) - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) dbProps := &config.DatabaseOptions{ConnectionString: config.DataSource(string(conStr))} db, writer, err := cm.Connection(dbProps) @@ -47,7 +47,7 @@ func TestConnectionManager(t *testing.T) { } // test invalid connection string configured - cm = sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm = sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) _, _, err = cm.Connection(&config.DatabaseOptions{ConnectionString: "http://"}) if err == nil { t.Fatal("expected an error but got none") diff --git a/mediaapi/routing/upload_test.go b/mediaapi/routing/upload_test.go index bc58e5d1f..d088950ca 100644 --- a/mediaapi/routing/upload_test.go +++ b/mediaapi/routing/upload_test.go @@ -50,7 +50,7 @@ func Test_uploadRequest_doUpload(t *testing.T) { // create testdata folder and remove when done _ = os.Mkdir(testdataPath, os.ModePerm) defer fileutils.RemoveDir(types.Path(testdataPath), nil) - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewMediaAPIDatasource(cm, &config.DatabaseOptions{ ConnectionString: "file::memory:?cache=shared", MaxOpenConnections: 100, diff --git a/mediaapi/storage/storage_test.go b/mediaapi/storage/storage_test.go index 4731d7701..8cd29a54d 100644 --- a/mediaapi/storage/storage_test.go +++ b/mediaapi/storage/storage_test.go @@ -14,7 +14,7 @@ import ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewMediaAPIDatasource(cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index 0d1086359..dd74b844a 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -19,7 +19,7 @@ import ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { conStr, close := test.PrepareDBConnectionString(t, dbType) caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.Open(context.Background(), cm, &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}, caches) if err != nil { t.Fatalf("failed to create Database: %v", err) diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go index 1bfd8f50a..186151a4f 100644 --- a/roomserver/internal/input/input_test.go +++ b/roomserver/internal/input/input_test.go @@ -49,7 +49,7 @@ func TestSingleTransactionOnInput(t *testing.T) { Kind: api.KindOutlier, // don't panic if we generate an output event Event: event.Headered(gomatrixserverlib.RoomVersionV6), } - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.Open( context.Background(), cm, &config.DatabaseOptions{ diff --git a/setup/base/base.go b/setup/base/base.go index dbcdac6e4..ed5d9a900 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -145,7 +145,8 @@ func NewBaseDendrite(cfg *config.Dendrite, options ...BaseDendriteOptions) *Base // If we're in monolith mode, we'll set up a global pool of database // connections. A component is welcome to use this pool if they don't // have a separate database config of their own. - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + pCtx := process.NewProcessContext() + cm := sqlutil.NewConnectionManager(pCtx, cfg.Global.DatabaseOptions) // Ideally we would only use SkipClean on routes which we know can allow '/' but due to // https://github.com/gorilla/mux/issues/460 we have to attach this at the top router. @@ -160,7 +161,7 @@ func NewBaseDendrite(cfg *config.Dendrite, options ...BaseDendriteOptions) *Base // directory traversal attack e.g /../../../etc/passwd return &BaseDendrite{ - ProcessContext: process.NewProcessContext(), + ProcessContext: pCtx, tracerCloser: closer, Cfg: cfg, DNSCache: dnsCache, diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index fc2ff60d8..ef2dde73c 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -555,7 +555,7 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve cfg.Global.ServerName = "localhost" cfg.MSCs.Database.ConnectionString = "file:msc2836_test.db" cfg.MSCs.MSCs = []string{"msc2836"} - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) base := &base.BaseDendrite{ Cfg: cfg, Routers: httputil.NewRouters(), diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index f5aa68096..e81a341f1 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -22,7 +22,7 @@ var ctx = context.Background() func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewSyncServerDatasource(context.Background(), cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index c48821fca..4827ad47c 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -21,7 +21,7 @@ import ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewUserDatabase(context.Background(), cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go index 9b826908b..c0965a2c2 100644 --- a/userapi/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -364,7 +364,7 @@ func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabas t.Helper() connStr, clearDB := test.PrepareDBConnectionString(t, dbType) - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewKeyDatabase(cm, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) if err != nil { t.Fatal(err) diff --git a/userapi/internal/key_api_test.go b/userapi/internal/key_api_test.go index e2b01bd56..de2a6d2c8 100644 --- a/userapi/internal/key_api_test.go +++ b/userapi/internal/key_api_test.go @@ -16,7 +16,7 @@ import ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewKeyDatabase(cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 251bd02eb..cf7c5144e 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -35,7 +35,7 @@ var ( func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewUserDatabase(context.Background(), cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") @@ -576,8 +576,8 @@ func Test_Notification(t *testing.T) { } func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { - cfg, _, close := testrig.CreateConfig(t, dbType) - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) db, err := storage.NewKeyDatabase(cm, &cfg.KeyServer.Database) if err != nil { t.Fatalf("failed to create new database: %v", err) diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 1e69e83a8..03e656354 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -73,7 +73,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, pub if opts.serverName != "" { sName = gomatrixserverlib.ServerName(opts.serverName) } - cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions) + cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions) accountDB, err := storage.NewUserDatabase(ctx.Context(), cm, &cfg.UserAPI.AccountDatabase, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index ab39187b1..69461ddd1 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -77,7 +77,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { // Create DB and Dendrite base connStr, close := test.PrepareDBConnectionString(t, dbType) defer close() - cm := sqlutil.NewConnectionManager(config.DatabaseOptions{}) + cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewUserDatabase(ctx, cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "test", bcrypt.MinCost, 0, 0, "")