Remove BaseDendrite from relay API

This commit is contained in:
Till Faelligen 2023-03-21 08:53:08 +01:00
parent d4db4ed40d
commit 9b06041052
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
2 changed files with 35 additions and 25 deletions

View file

@ -57,6 +57,8 @@ func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB,
if c.processContext == nil { if c.processContext == nil {
return return
} }
// If we have a ProcessContext, start a component and wait for
// Dendrite to shut down to cleanly close the database connection.
c.processContext.ComponentStarted() c.processContext.ComponentStarted()
<-c.processContext.WaitForShutdown() <-c.processContext.WaitForShutdown()
_ = c.db.Close() _ = c.db.Close()

View file

@ -26,6 +26,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi" "github.com/matrix-org/dendrite/relayapi"
"github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
@ -35,38 +37,38 @@ import (
func TestCreateNewRelayInternalAPI(t *testing.T) { func TestCreateNewRelayInternalAPI(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType) cfg, processCtx, close := testrig.CreateConfig(t, dbType)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
defer close() defer close()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
relayAPI := relayapi.NewRelayInternalAPI(base.Cfg, base.ConnectionManager, nil, nil, nil, nil, true, caches) relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, caches)
assert.NotNil(t, relayAPI) assert.NotNil(t, relayAPI)
}) })
} }
func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) { func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType) cfg, processCtx, close := testrig.CreateConfig(t, dbType)
if dbType == test.DBTypeSQLite { if dbType == test.DBTypeSQLite {
base.Cfg.RelayAPI.Database.ConnectionString = "file:" cfg.RelayAPI.Database.ConnectionString = "file:"
} else { } else {
base.Cfg.RelayAPI.Database.ConnectionString = "test" cfg.RelayAPI.Database.ConnectionString = "test"
} }
defer close() defer close()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
assert.Panics(t, func() { assert.Panics(t, func() {
relayapi.NewRelayInternalAPI(base.Cfg, base.ConnectionManager, nil, nil, nil, nil, true, nil) relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, nil)
}) })
}) })
} }
func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) { func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType) cfg, _, close := testrig.CreateConfig(t, dbType)
defer close() defer close()
routers := httputil.NewRouters()
assert.Panics(t, func() { assert.Panics(t, func() {
relayapi.AddPublicRoutes(base.Routers, base.Cfg, nil, nil) relayapi.AddPublicRoutes(routers, cfg, nil, nil)
}) })
}) })
} }
@ -108,16 +110,19 @@ func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnI
func TestCreateRelayPublicRoutes(t *testing.T) { func TestCreateRelayPublicRoutes(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType) cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close() defer close()
routers := httputil.NewRouters()
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
relayAPI := relayapi.NewRelayInternalAPI(base.Cfg, base.ConnectionManager, nil, nil, nil, nil, true, caches) cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, caches)
assert.NotNil(t, relayAPI) assert.NotNil(t, relayAPI)
serverKeyAPI := &signing.YggdrasilKeys{} serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing() keyRing := serverKeyAPI.KeyRing()
relayapi.AddPublicRoutes(base.Routers, base.Cfg, keyRing, relayAPI) relayapi.AddPublicRoutes(routers, cfg, keyRing, relayAPI)
testCases := []struct { testCases := []struct {
name string name string
@ -126,29 +131,29 @@ func TestCreateRelayPublicRoutes(t *testing.T) {
}{ }{
{ {
name: "relay_txn invalid user id", name: "relay_txn invalid user id",
req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"), req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "user:local"),
wantCode: 400, wantCode: 400,
}, },
{ {
name: "relay_txn valid user id", name: "relay_txn valid user id",
req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "@user:local"),
wantCode: 200, wantCode: 200,
}, },
{ {
name: "send_relay invalid user id", name: "send_relay invalid user id",
req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"), req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "user:local"),
wantCode: 400, wantCode: 400,
}, },
{ {
name: "send_relay valid user id", name: "send_relay valid user id",
req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "@user:local"),
wantCode: 200, wantCode: 200,
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
w := httptest.NewRecorder() w := httptest.NewRecorder()
base.Routers.Federation.ServeHTTP(w, tc.req) routers.Federation.ServeHTTP(w, tc.req)
if w.Code != tc.wantCode { if w.Code != tc.wantCode {
t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
} }
@ -158,16 +163,19 @@ func TestCreateRelayPublicRoutes(t *testing.T) {
func TestDisableRelayPublicRoutes(t *testing.T) { func TestDisableRelayPublicRoutes(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType) cfg, processCtx, close := testrig.CreateConfig(t, dbType)
defer close() defer close()
routers := httputil.NewRouters()
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
relayAPI := relayapi.NewRelayInternalAPI(base.Cfg, base.ConnectionManager, nil, nil, nil, nil, false, caches) cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, false, caches)
assert.NotNil(t, relayAPI) assert.NotNil(t, relayAPI)
serverKeyAPI := &signing.YggdrasilKeys{} serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing() keyRing := serverKeyAPI.KeyRing()
relayapi.AddPublicRoutes(base.Routers, base.Cfg, keyRing, relayAPI) relayapi.AddPublicRoutes(routers, cfg, keyRing, relayAPI)
testCases := []struct { testCases := []struct {
name string name string
@ -176,19 +184,19 @@ func TestDisableRelayPublicRoutes(t *testing.T) {
}{ }{
{ {
name: "relay_txn valid user id", name: "relay_txn valid user id",
req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "@user:local"),
wantCode: 404, wantCode: 404,
}, },
{ {
name: "send_relay valid user id", name: "send_relay valid user id",
req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "@user:local"),
wantCode: 404, wantCode: 404,
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
w := httptest.NewRecorder() w := httptest.NewRecorder()
base.Routers.Federation.ServeHTTP(w, tc.req) routers.Federation.ServeHTTP(w, tc.req)
if w.Code != tc.wantCode { if w.Code != tc.wantCode {
t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
} }