diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index a9cc80cb7..099daaa9a 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -55,7 +55,7 @@ func main() { fmt.Println("Fetching", len(snapshotNIDs), "snapshot NIDs") roomserverDB, err := storage.Open( - base, &cfg.RoomServer.Database, + base.ProcessContext.Context(), base.ConnectionManager, &cfg.RoomServer.Database, caching.NewRistrettoCache(128*1024*1024, time.Hour, true), ) if err != nil { diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index b47527b11..8a4237bae 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -94,7 +94,7 @@ func NewInternalAPI( ) api.FederationInternalAPI { cfg := &base.Cfg.FederationAPI - federationDB, err := storage.NewDatabase(base, &cfg.Database, caches, base.Cfg.Global.IsLocalServerName) + federationDB, err := storage.NewDatabase(base.ProcessContext.Context(), base.ConnectionManager, &cfg.Database, caches, base.Cfg.Global.IsLocalServerName) if err != nil { logrus.WithError(err).Panic("failed to connect to federation sender db") } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 8cc3a9a18..c3fb10561 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/test/testrig" "go.uber.org/atomic" "gotest.tools/v3/poll" @@ -35,7 +36,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) { @@ -44,7 +44,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase b, baseClose := testrig.CreateBaseDendrite(t, dbType) caches := caching.NewRistrettoCache(b.Cfg.Global.Cache.EstimatedMaxSize, b.Cfg.Global.Cache.MaxAge, caching.DisableMetrics) connStr, dbClose := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + db, err := storage.NewDatabase(b.ProcessContext.Context(), b.ConnectionManager, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, caches, b.Cfg.Global.IsLocalServerName) if err != nil { @@ -57,10 +57,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } else { // Fake Database db := test.NewInMemoryFederationDatabase() - b := struct { - ProcessContext *process.ProcessContext - }{ProcessContext: process.NewProcessContext()} - return db, b.ProcessContext, func() {} + return db, process.NewProcessContext(), func() {} } } diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index b81f128e7..09fd8bcfe 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -16,6 +16,7 @@ package postgres import ( + "context" "database/sql" "fmt" @@ -23,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -36,10 +36,10 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { var d Database var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } blacklist, err := NewPostgresBlacklistTable(d.db) @@ -95,7 +95,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Version: "federationsender: drop federationsender_rooms", Up: deltas.UpRemoveRoomsTable, }) - err = m.Up(base.Context()) + err = m.Up(ctx) if err != nil { return nil, err } diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index 1e7e41a2c..06292fa43 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -15,13 +15,13 @@ package sqlite3 import ( + "context" "database/sql" "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -34,10 +34,10 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (*Database, error) { var d Database var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } blacklist, err := NewSQLiteBlacklistTable(d.db) @@ -93,7 +93,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Version: "federationsender: drop federationsender_rooms", Up: deltas.UpRemoveRoomsTable, }) - err = m.Up(base.Context()) + err = m.Up(ctx) if err != nil { return nil, err } diff --git a/federationapi/storage/storage.go b/federationapi/storage/storage.go index 142e281ea..bc63ab286 100644 --- a/federationapi/storage/storage.go +++ b/federationapi/storage/storage.go @@ -18,23 +18,24 @@ package storage import ( + "context" "fmt" "github.com/matrix-org/dendrite/federationapi/storage/postgres" "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) + return sqlite3.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) + return postgres.NewDatabase(ctx, conMan, dbProperties, cache, isLocalServerName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 9e6f27420..5f69a0b5d 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -8,19 +8,20 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - b, baseClose := testrig.CreateBaseDendrite(t, dbType) - caches := caching.NewRistrettoCache(b.Cfg.Global.Cache.EstimatedMaxSize, b.Cfg.Global.Cache.MaxAge, caching.DisableMetrics) + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) connStr, dbClose := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + ctx := context.Background() + cm := sqlutil.NewConnectionManager() + db, err := storage.NewDatabase(ctx, &cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, caches, func(server gomatrixserverlib.ServerName) bool { return server == "localhost" }) if err != nil { @@ -28,7 +29,6 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Dat } return db, func() { dbClose() - baseClose() } } diff --git a/internal/sqlutil/connection_manager.go b/internal/sqlutil/connection_manager.go new file mode 100644 index 000000000..d902cb2c4 --- /dev/null +++ b/internal/sqlutil/connection_manager.go @@ -0,0 +1,58 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlutil + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/setup/config" +) + +type ConnectionManager interface { + Connection(dbProperties *config.DatabaseOptions) (*sql.DB, Writer, error) +} + +type Connections struct { + db *sql.DB + writer Writer +} + +func NewConnectionManager() Connections { + return Connections{} +} + +func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB, Writer, error) { + writer := NewDummyWriter() + if dbProperties.ConnectionString.IsSQLite() { + writer = NewExclusiveWriter() + } + if dbProperties.ConnectionString != "" || c.db == nil { + var err error + // Open a new database connection using the supplied config. + c.db, err = Open(dbProperties, writer) + if err != nil { + return nil, nil, err + } + c.writer = writer + return c.db, c.writer, nil + } + if c.db != nil && c.writer != nil { + // Ignore the supplied config and return the global pool and + // writer. + return c.db, c.writer, nil + } + return nil, nil, fmt.Errorf("no database connections configured") +} diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index bf76e9ddc..42e0ea88f 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -32,7 +32,7 @@ func AddPublicRoutes( cfg := &base.Cfg.MediaAPI rateCfg := &base.Cfg.ClientAPI.RateLimiting - mediaDB, err := storage.NewMediaAPIDatasource(base, &cfg.Database) + mediaDB, err := storage.NewMediaAPIDatasource(base.ConnectionManager, &cfg.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to media db") } diff --git a/mediaapi/routing/upload_test.go b/mediaapi/routing/upload_test.go index 420d0eba9..ccfbea624 100644 --- a/mediaapi/routing/upload_test.go +++ b/mediaapi/routing/upload_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" @@ -49,8 +50,8 @@ func Test_uploadRequest_doUpload(t *testing.T) { // create testdata folder and remove when done _ = os.Mkdir(testdataPath, os.ModePerm) defer fileutils.RemoveDir(types.Path(testdataPath), nil) - - db, err := storage.NewMediaAPIDatasource(nil, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager() + db, err := storage.NewMediaAPIDatasource(&cm, &config.DatabaseOptions{ ConnectionString: "file::memory:?cache=shared", MaxOpenConnections: 100, MaxIdleConnections: 2, diff --git a/mediaapi/storage/postgres/mediaapi.go b/mediaapi/storage/postgres/mediaapi.go index 30ec64f84..d3e04b508 100644 --- a/mediaapi/storage/postgres/mediaapi.go +++ b/mediaapi/storage/postgres/mediaapi.go @@ -20,13 +20,12 @@ import ( _ "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewDatabase opens a postgres database. -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) +func NewDatabase(conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } diff --git a/mediaapi/storage/sqlite3/mediaapi.go b/mediaapi/storage/sqlite3/mediaapi.go index c0ab10e9f..b4f40c9da 100644 --- a/mediaapi/storage/sqlite3/mediaapi.go +++ b/mediaapi/storage/sqlite3/mediaapi.go @@ -19,13 +19,12 @@ import ( // Import the postgres database driver. "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewDatabase opens a SQLIte database. -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) +func NewDatabase(conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions) (*shared.Database, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } diff --git a/mediaapi/storage/storage.go b/mediaapi/storage/storage.go index f673ae7e6..86b27df9f 100644 --- a/mediaapi/storage/storage.go +++ b/mediaapi/storage/storage.go @@ -20,19 +20,19 @@ package storage import ( "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage/postgres" "github.com/matrix-org/dendrite/mediaapi/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // NewMediaAPIDatasource opens a database connection. -func NewMediaAPIDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { +func NewMediaAPIDatasource(conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) + return sqlite3.NewDatabase(conMan, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) + return postgres.NewDatabase(conMan, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/mediaapi/storage/storage_test.go b/mediaapi/storage/storage_test.go index 81f0a5d24..603baddb1 100644 --- a/mediaapi/storage/storage_test.go +++ b/mediaapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "reflect" "testing" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/setup/config" @@ -13,7 +14,8 @@ import ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewMediaAPIDatasource(nil, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager() + db, err := storage.NewMediaAPIDatasource(&cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go index 902c1a013..925fc031d 100644 --- a/relayapi/relayapi.go +++ b/relayapi/relayapi.go @@ -59,7 +59,7 @@ func NewRelayInternalAPI( caches caching.FederationCache, ) api.RelayInternalAPI { cfg := &base.Cfg.RelayAPI - relayDB, err := storage.NewDatabase(base, &cfg.Database, caches, base.Cfg.Global.IsLocalServerName) + relayDB, err := storage.NewDatabase(base.ConnectionManager, &cfg.Database, caches, base.Cfg.Global.IsLocalServerName) if err != nil { logrus.WithError(err).Panic("failed to connect to relay db") } diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go index 1042beba7..03fae63bc 100644 --- a/relayapi/storage/postgres/storage.go +++ b/relayapi/storage/postgres/storage.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -34,14 +33,14 @@ type Database struct { // NewDatabase opens a new database func NewDatabase( - base *base.BaseDendrite, + conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool, ) (*Database, error) { var d Database var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } queue, err := NewPostgresRelayQueueTable(d.db) diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go index 3ed4ab046..ca6bcff9d 100644 --- a/relayapi/storage/sqlite3/storage.go +++ b/relayapi/storage/sqlite3/storage.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -34,14 +33,14 @@ type Database struct { // NewDatabase opens a new database func NewDatabase( - base *base.BaseDendrite, + conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool, ) (*Database, error) { var d Database var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } queue, err := NewSQLiteRelayQueueTable(d.db) diff --git a/relayapi/storage/storage.go b/relayapi/storage/storage.go index 16ecbcfb7..c1a876d13 100644 --- a/relayapi/storage/storage.go +++ b/relayapi/storage/storage.go @@ -21,25 +21,25 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/relayapi/storage/postgres" "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new database func NewDatabase( - base *base.BaseDendrite, + conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, cache caching.FederationCache, isLocalServerName func(gomatrixserverlib.ServerName) bool, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) + return sqlite3.NewDatabase(conMan, dbProperties, cache, isLocalServerName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) + return postgres.NewDatabase(conMan, dbProperties, cache, isLocalServerName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index e8df857f0..3fbf72b21 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -3,28 +3,28 @@ package helpers import ( "context" "testing" + "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/setup/base" - - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/test" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { - base, close := testrig.CreateBaseDendrite(t, dbType) - caches := caching.NewRistrettoCache(base.Cfg.Global.Cache.EstimatedMaxSize, base.Cfg.Global.Cache.MaxAge, caching.DisableMetrics) - db, err := storage.Open(base, &base.Cfg.RoomServer.Database, caches) +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) + cm := sqlutil.NewConnectionManager() + db, err := storage.Open(context.Background(), &cm, &config.DatabaseOptions{ConnectionString: config.DataSource(conStr)}, caches) if err != nil { t.Fatalf("failed to create Database: %v", err) } - return base, db, close + return db, close } func TestIsInvitePendingWithoutNID(t *testing.T) { @@ -34,7 +34,7 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) _ = bob test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - _, db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateDatabase(t, dbType) defer close() // store all events diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go index 4708560ac..0aba046c2 100644 --- a/roomserver/internal/input/input_test.go +++ b/roomserver/internal/input/input_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" @@ -48,14 +49,15 @@ func TestSingleTransactionOnInput(t *testing.T) { Kind: api.KindOutlier, // don't panic if we generate an output event Event: event.Headered(gomatrixserverlib.RoomVersionV6), } + cm := sqlutil.NewConnectionManager() db, err := storage.Open( - nil, + context.Background(), &cm, &config.DatabaseOptions{ ConnectionString: "", MaxOpenConnections: 1, MaxIdleConnections: 1, }, - caching.NewRistrettoCache(8*1024*1024, time.Hour, false), + caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics), ) if err != nil { t.Logf("PostgreSQL not available (%s), skipping", err) diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 83928a006..1c55423ef 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -31,7 +31,7 @@ func NewInternalAPI( ) api.RoomserverInternalAPI { cfg := &base.Cfg.RoomServer - roomserverDB, err := storage.Open(base, &cfg.Database, caches) + roomserverDB, err := storage.Open(base.ProcessContext.Context(), base.ConnectionManager, &cfg.Database, caches) if err != nil { logrus.WithError(err).Panicf("failed to connect to room server db") } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 39415f63d..1b0b3155d 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -34,7 +34,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s t.Helper() base, close := testrig.CreateBaseDendrite(t, dbType) caches := caching.NewRistrettoCache(base.Cfg.Global.Cache.EstimatedMaxSize, base.Cfg.Global.Cache.MaxAge, caching.DisableMetrics) - db, err := storage.Open(base, &base.Cfg.RoomServer.Database, caches) + db, err := storage.Open(base.ProcessContext.Context(), base.ConnectionManager, &base.Cfg.RoomServer.Database, caches) if err != nil { t.Fatalf("failed to create Database: %v", err) } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index d98a5cf97..340e1764a 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) @@ -38,10 +37,10 @@ type Database struct { } // Open a postgres database. -func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { +func Open(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database var err error - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, fmt.Errorf("sqlutil.Open: %w", err) } @@ -53,7 +52,7 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c // Special case, since this migration uses several tables, so it needs to // be sure that all tables are created first. - if err = executeMigration(base.Context(), db); err != nil { + if err = executeMigration(ctx, db); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 2adedd2d8..05c36b494 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) @@ -38,10 +37,10 @@ type Database struct { } // Open a sqlite database. -func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { +func Open(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { var d Database var err error - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, fmt.Errorf("sqlutil.Open: %w", err) } @@ -62,7 +61,7 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c // Special case, since this migration uses several tables, so it needs to // be sure that all tables are created first. - if err = executeMigration(base.Context(), db); err != nil { + if err = executeMigration(ctx, db); err != nil { return nil, err } diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 8a87b7d7c..c93268de5 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -18,22 +18,23 @@ package storage import ( + "context" "fmt" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" ) // Open opens a database connection. -func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { +func Open(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.Open(base, dbProperties, cache) + return sqlite3.Open(ctx, conMan, dbProperties, cache) case dbProperties.ConnectionString.IsPostgres(): - return postgres.Open(base, dbProperties, cache) + return postgres.Open(ctx, conMan, dbProperties, cache) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/setup/base/base.go b/setup/base/base.go index ed3445019..b7dddc87c 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -64,15 +64,14 @@ var staticContent embed.FS // Must be closed when shutting down. type BaseDendrite struct { *process.ProcessContext - tracerCloser io.Closer - Routers httputil.Routers - NATS *jetstream.NATSInstance - Cfg *config.Dendrite - DNSCache *gomatrixserverlib.DNSCache - Database *sql.DB - DatabaseWriter sqlutil.Writer - EnableMetrics bool - startupLock sync.Mutex + tracerCloser io.Closer + Routers httputil.Routers + NATS *jetstream.NATSInstance + Cfg *config.Dendrite + DNSCache *gomatrixserverlib.DNSCache + ConnectionManager sqlutil.ConnectionManager + EnableMetrics bool + startupLock sync.Mutex } const HTTPServerTimeout = time.Minute * 5 @@ -149,14 +148,13 @@ func NewBaseDendrite(cfg *config.Dendrite, options ...BaseDendriteOptions) *Base // If we're in monolith mode, we'll set up a global pool of database // connections. A component is welcome to use this pool if they don't // have a separate database config of their own. - var db *sql.DB - var writer sqlutil.Writer + cm := sqlutil.NewConnectionManager() if cfg.Global.DatabaseOptions.ConnectionString != "" { if cfg.Global.DatabaseOptions.ConnectionString.IsSQLite() { logrus.Panic("Using a global database connection pool is not supported with SQLite databases") } - writer = sqlutil.NewDummyWriter() - if db, err = sqlutil.Open(&cfg.Global.DatabaseOptions, writer); err != nil { + _, _, err := cm.Connection(&cfg.Global.DatabaseOptions) + if err != nil { logrus.WithError(err).Panic("Failed to set up global database connections") } logrus.Debug("Using global database connection pool") @@ -175,15 +173,14 @@ func NewBaseDendrite(cfg *config.Dendrite, options ...BaseDendriteOptions) *Base // directory traversal attack e.g /../../../etc/passwd return &BaseDendrite{ - ProcessContext: process.NewProcessContext(), - tracerCloser: closer, - Cfg: cfg, - DNSCache: dnsCache, - Routers: httputil.NewRouters(), - NATS: &jetstream.NATSInstance{}, - Database: db, // set if monolith with global connection pool only - DatabaseWriter: writer, // set if monolith with global connection pool only - EnableMetrics: enableMetrics, + ProcessContext: process.NewProcessContext(), + tracerCloser: closer, + Cfg: cfg, + DNSCache: dnsCache, + Routers: httputil.NewRouters(), + NATS: &jetstream.NATSInstance{}, + ConnectionManager: &cm, + EnableMetrics: enableMetrics, } } @@ -204,17 +201,12 @@ func (b *BaseDendrite) Close() error { // have a BaseDendrite and just want a connection with the supplied config // without any pooling stuff. func (b *BaseDendrite) DatabaseConnection(dbProperties *config.DatabaseOptions, writer sqlutil.Writer) (*sql.DB, sqlutil.Writer, error) { + logrus.Infof("%#v", dbProperties) if dbProperties.ConnectionString != "" || b == nil { - // Open a new database connection using the supplied config. - db, err := sqlutil.Open(dbProperties, writer) - return db, writer, err + cm := sqlutil.NewConnectionManager() + return cm.Connection(dbProperties) } - if b.Database != nil && b.DatabaseWriter != nil { - // Ignore the supplied config and return the global pool and - // writer. - return b.Database, b.DatabaseWriter, nil - } - return nil, nil, fmt.Errorf("no database connections configured") + return b.ConnectionManager.Connection(dbProperties) } // CreateClient creates a new client (normally used for media fetch requests). diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index cab4c4934..7c1e0fc61 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -102,7 +102,7 @@ func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, userAPI userapi.UserInternalAPI, keyRing gomatrixserverlib.JSONVerifier, ) error { - db, err := NewDatabase(base, &base.Cfg.MSCs.Database) + db, err := NewDatabase(base.ConnectionManager, &base.Cfg.MSCs.Database) if err != nil { return fmt.Errorf("cannot enable MSC2836: %w", err) } diff --git a/setup/mscs/msc2836/storage.go b/setup/mscs/msc2836/storage.go index 827e82f70..33e6ce734 100644 --- a/setup/mscs/msc2836/storage.go +++ b/setup/mscs/msc2836/storage.go @@ -8,7 +8,6 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -59,17 +58,17 @@ type DB struct { } // NewDatabase loads the database for msc2836 -func NewDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) { +func NewDatabase(conMan sqlutil.ConnectionManager, dbOpts *config.DatabaseOptions) (Database, error) { if dbOpts.ConnectionString.IsPostgres() { - return newPostgresDatabase(base, dbOpts) + return newPostgresDatabase(conMan, dbOpts) } - return newSQLiteDatabase(base, dbOpts) + return newSQLiteDatabase(conMan, dbOpts) } -func newPostgresDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) { +func newPostgresDatabase(conMan sqlutil.ConnectionManager, dbOpts *config.DatabaseOptions) (Database, error) { d := DB{} var err error - if d.db, d.writer, err = base.DatabaseConnection(dbOpts, sqlutil.NewDummyWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil { return nil, err } _, err = d.db.Exec(` @@ -144,10 +143,10 @@ func newPostgresDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions return &d, err } -func newSQLiteDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) { +func newSQLiteDatabase(conMan sqlutil.ConnectionManager, dbOpts *config.DatabaseOptions) (Database, error) { d := DB{} var err error - if d.db, d.writer, err = base.DatabaseConnection(dbOpts, sqlutil.NewExclusiveWriter()); err != nil { + if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil { return nil, err } _, err = d.db.Exec(` diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 850d24a07..4b21ae43d 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -16,12 +16,12 @@ package postgres import ( + "context" "database/sql" // Import the postgres database driver. _ "github.com/lib/pq" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/shared" @@ -36,10 +36,10 @@ type SyncServerDatasource struct { } // NewDatabase creates a new sync server database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { +func NewDatabase(ctx context.Context, cm sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + if d.db, d.writer, err = cm.Connection(dbProperties); err != nil { return nil, err } accountData, err := NewPostgresAccountDataTable(d.db) @@ -111,7 +111,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created. }, ) - err = m.Up(base.Context()) + err = m.Up(ctx) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 510546909..70cd17f0e 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" @@ -37,13 +36,14 @@ type SyncServerDatasource struct { // NewDatabase creates a new sync server database // nolint: gocyclo -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { +func NewDatabase(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error - if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + + if d.db, d.writer, err = conMan.Connection(dbProperties); err != nil { return nil, err } - if err = d.prepare(base.Context()); err != nil { + if err = d.prepare(ctx); err != nil { return nil, err } return &d, nil diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index 5b20c6cc2..80b780585 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -18,21 +18,22 @@ package storage import ( + "context" "fmt" - "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/postgres" "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" ) // NewSyncServerDatasource opens a database connection. -func NewSyncServerDatasource(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { +func NewSyncServerDatasource(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) + return sqlite3.NewDatabase(ctx, conMan, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) + return postgres.NewDatabase(ctx, conMan, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 05d498bc2..6dff0a689 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -9,27 +9,27 @@ import ( "reflect" "testing" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" ) var ctx = context.Background() -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func(), func()) { +func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - base, closeBase := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewSyncServerDatasource(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager() + db, err := storage.NewSyncServerDatasource(context.Background(), &cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("NewSyncServerDatasource returned %s", err) } - return db, close, closeBase + return db, close } func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { @@ -55,9 +55,8 @@ func TestWriteEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { alice := test.NewUser(t) r := test.NewRoom(t, alice) - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() MustWriteEvents(t, db, r.Events()) }) } @@ -76,9 +75,8 @@ func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.Dat // These tests assert basic functionality of RecentEvents for PDUs func TestRecentEventsPDU(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() alice := test.NewUser(t) // dummy room to make sure SQL queries are filtering on room ID MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) @@ -191,9 +189,8 @@ func TestRecentEventsPDU(t *testing.T) { // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token func TestGetEventsInRangeWithTopologyToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() alice := test.NewUser(t) r := test.NewRoom(t, alice) for i := 0; i < 10; i++ { @@ -276,9 +273,8 @@ func TestStreamToTopologicalPosition(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() txn, err := db.NewDatabaseTransaction(ctx) if err != nil { @@ -514,9 +510,8 @@ func TestSendToDeviceBehaviour(t *testing.T) { bob := test.NewUser(t) deviceID := "one" test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() // At this point there should be no messages. We haven't sent anything // yet. @@ -899,9 +894,8 @@ func TestRoomSummary(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close, closeBase := MustCreateDatabase(t, dbType) + db, close := MustCreateDatabase(t, dbType) defer close() - defer closeBase() for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -939,11 +933,8 @@ func TestRecentEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { filter := gomatrixserverlib.DefaultRoomEventFilter() - db, close, closeBase := MustCreateDatabase(t, dbType) - t.Cleanup(func() { - close() - closeBase() - }) + db, close := MustCreateDatabase(t, dbType) + t.Cleanup(close) MustWriteEvents(t, db, room1.Events()) MustWriteEvents(t, db, room2.Events()) diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 51fbabe0a..e0cc8462e 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -48,7 +48,7 @@ func AddPublicRoutes( js, natsClient := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - syncDB, err := storage.NewSyncServerDatasource(base, &cfg.Database) + syncDB, err := storage.NewSyncServerDatasource(base.Context(), base.ConnectionManager, &cfg.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to sync db") } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 021b6e845..13a078659 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -1145,7 +1145,7 @@ func TestUpdateRelations(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { base, shutdownBase := testrig.CreateBaseDendrite(t, dbType) t.Cleanup(shutdownBase) - db, err := storage.NewSyncServerDatasource(base, &base.Cfg.SyncAPI.Database) + db, err := storage.NewSyncServerDatasource(base.Context(), base.ConnectionManager, &base.Cfg.SyncAPI.Database) if err != nil { t.Fatal(err) } diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index bc5ae652d..57eabfe50 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -7,22 +7,22 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/storage" userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { - base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager() + db, err := storage.NewUserDatabase(context.Background(), &cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { @@ -30,7 +30,6 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, } return db, func() { close() - baseclose() } } diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go index 868fc9be8..e42133d46 100644 --- a/userapi/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -27,13 +27,13 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" roomserver "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -363,9 +363,9 @@ func TestDebounce(t *testing.T) { func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() - base, _, _ := testrig.Base(nil) connStr, clearDB := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + cm := sqlutil.NewConnectionManager() + db, err := storage.NewKeyDatabase(&cm, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) if err != nil { t.Fatal(err) } diff --git a/userapi/internal/key_api_test.go b/userapi/internal/key_api_test.go index fc7e7e0df..8af835d21 100644 --- a/userapi/internal/key_api_test.go +++ b/userapi/internal/key_api_test.go @@ -5,9 +5,9 @@ import ( "reflect" "testing" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/storage" @@ -16,15 +16,14 @@ import ( func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - base, _, _ := testrig.Base(nil) - db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager() + db, err := storage.NewKeyDatabase(&cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("failed to create new user db: %v", err) } return db, func() { - base.Close() close() } } diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 673d123ba..39cabf0b0 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/shared" @@ -33,8 +32,8 @@ import ( ) // NewDatabase creates a new accounts and profiles database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) +func NewDatabase(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } @@ -51,7 +50,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return deltas.UpServerNames(ctx, txn, serverName) }, }) - if err = m.Up(base.Context()); err != nil { + if err = m.Up(ctx); err != nil { return nil, err } @@ -111,7 +110,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, return deltas.UpServerNamesPopulate(ctx, txn, serverName) }, }) - if err = m.Up(base.Context()); err != nil { + if err = m.Up(ctx); err != nil { return nil, err } @@ -137,8 +136,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, }, nil } -func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) +func NewKeyDatabase(conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 0f3eeed1b..f377b6490 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/shared" @@ -31,8 +30,8 @@ import ( ) // NewUserDatabase creates a new accounts and profiles database -func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) +func NewUserDatabase(ctx context.Context, conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } @@ -49,7 +48,7 @@ func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptio return deltas.UpServerNames(ctx, txn, serverName) }, }) - if err = m.Up(base.Context()); err != nil { + if err = m.Up(ctx); err != nil { return nil, err } @@ -109,7 +108,7 @@ func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptio return deltas.UpServerNamesPopulate(ctx, txn, serverName) }, }) - if err = m.Up(base.Context()); err != nil { + if err = m.Up(ctx); err != nil { return nil, err } @@ -135,8 +134,8 @@ func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptio }, nil } -func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) +func NewKeyDatabase(conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := conMan.Connection(dbProperties) if err != nil { return nil, err } diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index 0329fb46a..cdf79f327 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -18,12 +18,13 @@ package storage import ( + "context" "fmt" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/postgres" "github.com/matrix-org/dendrite/userapi/storage/sqlite3" @@ -32,7 +33,8 @@ import ( // NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters func NewUserDatabase( - base *base.BaseDendrite, + ctx context.Context, + conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, @@ -42,9 +44,9 @@ func NewUserDatabase( ) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(ctx, conMan, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return postgres.NewDatabase(ctx, conMan, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) default: return nil, fmt.Errorf("unexpected database type") } @@ -52,12 +54,12 @@ func NewUserDatabase( // NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme) // and sets postgres connection parameters. -func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { +func NewKeyDatabase(conMan sqlutil.ConnectionManager, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewKeyDatabase(base, dbProperties) + return sqlite3.NewKeyDatabase(conMan, dbProperties) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewKeyDatabase(base, dbProperties) + return postgres.NewKeyDatabase(conMan, dbProperties) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index f52e7e17d..1e492c703 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -33,9 +34,9 @@ var ( ) func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { - base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager() + db, err := storage.NewUserDatabase(context.Background(), &cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { @@ -43,7 +44,6 @@ func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatab } return db, func() { close() - baseclose() } } @@ -577,7 +577,7 @@ func Test_Notification(t *testing.T) { func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + db, err := storage.NewKeyDatabase(base.ConnectionManager, &base.Cfg.KeyServer.Database) if err != nil { t.Fatalf("failed to create new database: %v", err) } diff --git a/userapi/userapi.go b/userapi/userapi.go index bb8c34a1a..3ada8020c 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -46,7 +46,8 @@ func NewInternalAPI( pgClient := pushgateway.NewHTTPClient(cfg.PushGatewayDisableTLSValidation) db, err := storage.NewUserDatabase( - base, + base.ProcessContext.Context(), + base.ConnectionManager, &cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, @@ -58,7 +59,7 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to accounts db") } - keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + keyDB, err := storage.NewKeyDatabase(base.ConnectionManager, &base.Cfg.KeyServer.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to key db") } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 01e491cb6..74363637f 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -73,14 +74,16 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, pub if opts.serverName != "" { sName = gomatrixserverlib.ServerName(opts.serverName) } - accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager() + ctx := context.Background() + accountDB, err := storage.NewUserDatabase(ctx, &cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } - keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + keyDB, err := storage.NewKeyDatabase(base.ConnectionManager, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 421852d3f..74c8befbc 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "golang.org/x/crypto/bcrypt" @@ -15,7 +16,6 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" userUtil "github.com/matrix-org/dendrite/userapi/util" @@ -77,9 +77,8 @@ func TestNotifyUserCountsAsync(t *testing.T) { // Create DB and Dendrite base connStr, close := test.PrepareDBConnectionString(t, dbType) defer close() - base, _, _ := testrig.Base(nil) - defer base.Close() - db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ + cm := sqlutil.NewConnectionManager() + db, err := storage.NewUserDatabase(ctx, &cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "test", bcrypt.MinCost, 0, 0, "") if err != nil { diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go index 5f626b5bc..4e931a1b7 100644 --- a/userapi/util/phonehomestats_test.go +++ b/userapi/util/phonehomestats_test.go @@ -21,7 +21,7 @@ func TestCollect(t *testing.T) { b, _, _ := testrig.Base(nil) connStr, closeDB := test.PrepareDBConnectionString(t, dbType) defer closeDB() - db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(b.Context(), b.ConnectionManager, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, 1000, 1000, "") if err != nil {