Add ProcessContext to ConnectionManager to cleanly shut down database

connections
This commit is contained in:
Till Faelligen 2023-03-20 18:14:49 +01:00
parent 4c77ff3b26
commit d4db4ed40d
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
21 changed files with 45 additions and 32 deletions

View file

@ -134,7 +134,7 @@ func TestAppserviceInternalAPI(t *testing.T) {
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
// Create required internal APIs
natsInstance := jetstream.NATSInstance{}
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions)
rsAPI := roomserver.NewInternalAPI(ctx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
usrAPI := userapi.NewInternalAPI(ctx, cfg, cm, &natsInstance, rsAPI, nil)
asAPI := appservice.NewInternalAPI(ctx, cfg, &natsInstance, usrAPI, rsAPI)

View file

@ -43,7 +43,7 @@ func TestAdminResetPassword(t *testing.T) {
})
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
// Needed for changing the password/login
@ -161,7 +161,7 @@ func TestPurgeRoom(t *testing.T) {
defer close()
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)

View file

@ -30,7 +30,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close()
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
natsInstance := jetstream.NATSInstance{}
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)

View file

@ -42,7 +42,7 @@ func TestLogin(t *testing.T) {
SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"},
})
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
routers := httputil.NewRouters()
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)

View file

@ -413,7 +413,7 @@ func Test_register(t *testing.T) {
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
natsInstance := jetstream.NATSInstance{}
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
@ -587,7 +587,7 @@ func TestRegisterUserWithDisplayName(t *testing.T) {
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
natsInstance := jetstream.NATSInstance{}
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
deviceName, deviceID := "deviceName", "deviceID"
@ -628,7 +628,7 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) {
sharedSecret := "dendritetest"
cfg.ClientAPI.RegistrationSharedSecret = sharedSecret
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)

View file

@ -20,7 +20,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat
caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
connStr, dbClose := test.PrepareDBConnectionString(t, dbType)
ctx := context.Background()
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.NewDatabase(ctx, cm, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, caches, func(server gomatrixserverlib.ServerName) bool { return server == "localhost" })

View file

@ -19,17 +19,20 @@ import (
"fmt"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
)
type Connections struct {
db *sql.DB
writer Writer
globalConfig config.DatabaseOptions
processContext *process.ProcessContext
}
func NewConnectionManager(globalConfig config.DatabaseOptions) Connections {
func NewConnectionManager(processCtx *process.ProcessContext, globalConfig config.DatabaseOptions) Connections {
return Connections{
globalConfig: globalConfig,
processContext: processCtx,
}
}
@ -50,6 +53,15 @@ func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB,
return nil, nil, err
}
c.writer = writer
go func() {
if c.processContext == nil {
return
}
c.processContext.ComponentStarted()
<-c.processContext.WaitForShutdown()
_ = c.db.Close()
c.processContext.ComponentFinished()
}()
return c.db, c.writer, nil
}
if c.db != nil && c.writer != nil {

View file

@ -13,7 +13,7 @@ func TestConnectionManager(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
t.Cleanup(close)
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
dbProps := &config.DatabaseOptions{ConnectionString: config.DataSource(string(conStr))}
db, writer, err := cm.Connection(dbProps)
@ -47,7 +47,7 @@ func TestConnectionManager(t *testing.T) {
}
// test invalid connection string configured
cm = sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm = sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
_, _, err = cm.Connection(&config.DatabaseOptions{ConnectionString: "http://"})
if err == nil {
t.Fatal("expected an error but got none")

View file

@ -50,7 +50,7 @@ func Test_uploadRequest_doUpload(t *testing.T) {
// create testdata folder and remove when done
_ = os.Mkdir(testdataPath, os.ModePerm)
defer fileutils.RemoveDir(types.Path(testdataPath), nil)
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.NewMediaAPIDatasource(cm, &config.DatabaseOptions{
ConnectionString: "file::memory:?cache=shared",
MaxOpenConnections: 100,

View file

@ -14,7 +14,7 @@ import (
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.NewMediaAPIDatasource(cm, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})

View file

@ -19,7 +19,7 @@ import (
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics)
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.Open(context.Background(), cm, &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}, caches)
if err != nil {
t.Fatalf("failed to create Database: %v", err)

View file

@ -49,7 +49,7 @@ func TestSingleTransactionOnInput(t *testing.T) {
Kind: api.KindOutlier, // don't panic if we generate an output event
Event: event.Headered(gomatrixserverlib.RoomVersionV6),
}
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.Open(
context.Background(), cm,
&config.DatabaseOptions{

View file

@ -145,7 +145,8 @@ func NewBaseDendrite(cfg *config.Dendrite, options ...BaseDendriteOptions) *Base
// If we're in monolith mode, we'll set up a global pool of database
// connections. A component is welcome to use this pool if they don't
// have a separate database config of their own.
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
pCtx := process.NewProcessContext()
cm := sqlutil.NewConnectionManager(pCtx, cfg.Global.DatabaseOptions)
// Ideally we would only use SkipClean on routes which we know can allow '/' but due to
// https://github.com/gorilla/mux/issues/460 we have to attach this at the top router.
@ -160,7 +161,7 @@ func NewBaseDendrite(cfg *config.Dendrite, options ...BaseDendriteOptions) *Base
// directory traversal attack e.g /../../../etc/passwd
return &BaseDendrite{
ProcessContext: process.NewProcessContext(),
ProcessContext: pCtx,
tracerCloser: closer,
Cfg: cfg,
DNSCache: dnsCache,

View file

@ -555,7 +555,7 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve
cfg.Global.ServerName = "localhost"
cfg.MSCs.Database.ConnectionString = "file:msc2836_test.db"
cfg.MSCs.MSCs = []string{"msc2836"}
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
base := &base.BaseDendrite{
Cfg: cfg,
Routers: httputil.NewRouters(),

View file

@ -22,7 +22,7 @@ var ctx = context.Background()
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.NewSyncServerDatasource(context.Background(), cm, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})

View file

@ -21,7 +21,7 @@ import (
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.NewUserDatabase(context.Background(), cm, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "", 4, 0, 0, "")

View file

@ -364,7 +364,7 @@ func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabas
t.Helper()
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.NewKeyDatabase(cm, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
if err != nil {
t.Fatal(err)

View file

@ -16,7 +16,7 @@ import (
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.NewKeyDatabase(cm, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})

View file

@ -35,7 +35,7 @@ var (
func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.NewUserDatabase(context.Background(), cm, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
@ -576,8 +576,8 @@ func Test_Notification(t *testing.T) {
}
func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
cfg, _, close := testrig.CreateConfig(t, dbType)
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
db, err := storage.NewKeyDatabase(cm, &cfg.KeyServer.Database)
if err != nil {
t.Fatalf("failed to create new database: %v", err)

View file

@ -73,7 +73,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, pub
if opts.serverName != "" {
sName = gomatrixserverlib.ServerName(opts.serverName)
}
cm := sqlutil.NewConnectionManager(cfg.Global.DatabaseOptions)
cm := sqlutil.NewConnectionManager(ctx, cfg.Global.DatabaseOptions)
accountDB, err := storage.NewUserDatabase(ctx.Context(), cm, &cfg.UserAPI.AccountDatabase, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
if err != nil {

View file

@ -77,7 +77,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
// Create DB and Dendrite base
connStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
cm := sqlutil.NewConnectionManager(config.DatabaseOptions{})
cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{})
db, err := storage.NewUserDatabase(ctx, cm, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "test", bcrypt.MinCost, 0, 0, "")