dendrite/federationapi/storage/tables/relay_servers_table_test.go
2023-01-12 16:11:15 -07:00

171 lines
5.2 KiB
Go

package tables_test
import (
"context"
"database/sql"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
server1 = "server1"
server2 = "server2"
server3 = "server3"
)
type RelayServersDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationRelayServers
}
func mustCreateRelayServersTable(
t *testing.T,
dbType test.DBType,
) (database RelayServersDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationRelayServers
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresRelayServersTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteRelayServersTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = RelayServersDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func Equal(a, b []gomatrixserverlib.ServerName) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func TestShouldInsertRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldDeleteCorrectRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2})
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
}
expectedRelayServers1 := []gomatrixserverlib.ServerName{server3}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers)
}
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldDeleteAllRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteAllRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
}
expectedRelayServers1 := []gomatrixserverlib.ServerName{}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers)
}
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}