mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
Remove BaseDendrite from relay API
This commit is contained in:
parent
d4db4ed40d
commit
9b06041052
|
|
@ -57,6 +57,8 @@ func (c *Connections) Connection(dbProperties *config.DatabaseOptions) (*sql.DB,
|
||||||
if c.processContext == nil {
|
if c.processContext == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// If we have a ProcessContext, start a component and wait for
|
||||||
|
// Dendrite to shut down to cleanly close the database connection.
|
||||||
c.processContext.ComponentStarted()
|
c.processContext.ComponentStarted()
|
||||||
<-c.processContext.WaitForShutdown()
|
<-c.processContext.WaitForShutdown()
|
||||||
_ = c.db.Close()
|
_ = c.db.Close()
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@ import (
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
|
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/relayapi"
|
"github.com/matrix-org/dendrite/relayapi"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/test/testrig"
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
|
@ -35,38 +37,38 @@ import (
|
||||||
|
|
||||||
func TestCreateNewRelayInternalAPI(t *testing.T) {
|
func TestCreateNewRelayInternalAPI(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
defer close()
|
defer close()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
relayAPI := relayapi.NewRelayInternalAPI(base.Cfg, base.ConnectionManager, nil, nil, nil, nil, true, caches)
|
relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, caches)
|
||||||
assert.NotNil(t, relayAPI)
|
assert.NotNil(t, relayAPI)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) {
|
func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
if dbType == test.DBTypeSQLite {
|
if dbType == test.DBTypeSQLite {
|
||||||
base.Cfg.RelayAPI.Database.ConnectionString = "file:"
|
cfg.RelayAPI.Database.ConnectionString = "file:"
|
||||||
} else {
|
} else {
|
||||||
base.Cfg.RelayAPI.Database.ConnectionString = "test"
|
cfg.RelayAPI.Database.ConnectionString = "test"
|
||||||
}
|
}
|
||||||
defer close()
|
defer close()
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
assert.Panics(t, func() {
|
assert.Panics(t, func() {
|
||||||
relayapi.NewRelayInternalAPI(base.Cfg, base.ConnectionManager, nil, nil, nil, nil, true, nil)
|
relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, nil)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) {
|
func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
cfg, _, close := testrig.CreateConfig(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
routers := httputil.NewRouters()
|
||||||
assert.Panics(t, func() {
|
assert.Panics(t, func() {
|
||||||
relayapi.AddPublicRoutes(base.Routers, base.Cfg, nil, nil)
|
relayapi.AddPublicRoutes(routers, cfg, nil, nil)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -108,16 +110,19 @@ func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnI
|
||||||
|
|
||||||
func TestCreateRelayPublicRoutes(t *testing.T) {
|
func TestCreateRelayPublicRoutes(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
routers := httputil.NewRouters()
|
||||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
|
||||||
relayAPI := relayapi.NewRelayInternalAPI(base.Cfg, base.ConnectionManager, nil, nil, nil, nil, true, caches)
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
|
||||||
|
relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, true, caches)
|
||||||
assert.NotNil(t, relayAPI)
|
assert.NotNil(t, relayAPI)
|
||||||
|
|
||||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
relayapi.AddPublicRoutes(base.Routers, base.Cfg, keyRing, relayAPI)
|
relayapi.AddPublicRoutes(routers, cfg, keyRing, relayAPI)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
@ -126,29 +131,29 @@ func TestCreateRelayPublicRoutes(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "relay_txn invalid user id",
|
name: "relay_txn invalid user id",
|
||||||
req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"),
|
req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "user:local"),
|
||||||
wantCode: 400,
|
wantCode: 400,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "relay_txn valid user id",
|
name: "relay_txn valid user id",
|
||||||
req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"),
|
req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "@user:local"),
|
||||||
wantCode: 200,
|
wantCode: 200,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "send_relay invalid user id",
|
name: "send_relay invalid user id",
|
||||||
req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"),
|
req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "user:local"),
|
||||||
wantCode: 400,
|
wantCode: 400,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "send_relay valid user id",
|
name: "send_relay valid user id",
|
||||||
req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"),
|
req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "@user:local"),
|
||||||
wantCode: 200,
|
wantCode: 200,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
base.Routers.Federation.ServeHTTP(w, tc.req)
|
routers.Federation.ServeHTTP(w, tc.req)
|
||||||
if w.Code != tc.wantCode {
|
if w.Code != tc.wantCode {
|
||||||
t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
|
t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
|
||||||
}
|
}
|
||||||
|
|
@ -158,16 +163,19 @@ func TestCreateRelayPublicRoutes(t *testing.T) {
|
||||||
|
|
||||||
func TestDisableRelayPublicRoutes(t *testing.T) {
|
func TestDisableRelayPublicRoutes(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
routers := httputil.NewRouters()
|
||||||
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
|
||||||
relayAPI := relayapi.NewRelayInternalAPI(base.Cfg, base.ConnectionManager, nil, nil, nil, nil, false, caches)
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
|
||||||
|
relayAPI := relayapi.NewRelayInternalAPI(cfg, cm, nil, nil, nil, nil, false, caches)
|
||||||
assert.NotNil(t, relayAPI)
|
assert.NotNil(t, relayAPI)
|
||||||
|
|
||||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
relayapi.AddPublicRoutes(base.Routers, base.Cfg, keyRing, relayAPI)
|
relayapi.AddPublicRoutes(routers, cfg, keyRing, relayAPI)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
@ -176,19 +184,19 @@ func TestDisableRelayPublicRoutes(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "relay_txn valid user id",
|
name: "relay_txn valid user id",
|
||||||
req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"),
|
req: createGetRelayTxnHTTPRequest(cfg.Global.ServerName, "@user:local"),
|
||||||
wantCode: 404,
|
wantCode: 404,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "send_relay valid user id",
|
name: "send_relay valid user id",
|
||||||
req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"),
|
req: createSendRelayTxnHTTPRequest(cfg.Global.ServerName, "123", "@user:local"),
|
||||||
wantCode: 404,
|
wantCode: 404,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
base.Routers.Federation.ServeHTTP(w, tc.req)
|
routers.Federation.ServeHTTP(w, tc.req)
|
||||||
if w.Code != tc.wantCode {
|
if w.Code != tc.wantCode {
|
||||||
t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
|
t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue