Add database table for mailservers

This commit is contained in:
Devon Hudson 2022-11-29 17:43:09 -07:00
parent 2df4b0750e
commit 0520a9b0ed
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
8 changed files with 506 additions and 0 deletions

View file

@ -62,6 +62,11 @@ type Database interface {
RemoveAllServersFromBlacklist() error
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error
GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
RemoveMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error
RemoveAllMailserversForServer(serverName gomatrixserverlib.ServerName) error
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error)

View file

@ -0,0 +1,147 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const mailserversSchema = `
CREATE TABLE IF NOT EXISTS federationsender_mailservers (
-- The destination server name
server_name TEXT NOT NULL,
-- The mailserver name for a given destination
mailserver_name TEXT NOT NULL,
UNIQUE (server_name, mailserver_name)
);
CREATE INDEX IF NOT EXISTS federationsender_mailservers_server_name_idx
ON federationsender_mailservers (server_name);
`
const insertMailserversSQL = "" +
"INSERT INTO federationsender_mailservers (server_name, mailserver_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectMailserversSQL = "" +
"SELECT mailserver_name FROM federationsender_mailservers WHERE server_name = $1"
const deleteMailserversSQL = "" +
"DELETE FROM federationsender_mailservers WHERE server_name = $1 AND mailserver_name IN ($2)"
const deleteAllMailserversSQL = "" +
"DELETE FROM federationsender_mailservers WHERE server_name = $1"
type mailserversStatements struct {
db *sql.DB
insertMailserversStmt *sql.Stmt
selectMailserversStmt *sql.Stmt
deleteMailserversStmt *sql.Stmt
deleteAllMailserversStmt *sql.Stmt
}
func NewPostgresMailserversTable(db *sql.DB) (s *mailserversStatements, err error) {
s = &mailserversStatements{
db: db,
}
_, err = db.Exec(mailserversSchema)
if err != nil {
return
}
if s.insertMailserversStmt, err = db.Prepare(insertMailserversSQL); err != nil {
return
}
if s.selectMailserversStmt, err = db.Prepare(selectMailserversSQL); err != nil {
return
}
if s.deleteMailserversStmt, err = db.Prepare(deleteMailserversSQL); err != nil {
return
}
if s.deleteAllMailserversStmt, err = db.Prepare(deleteAllMailserversSQL); err != nil {
return
}
return
}
func (s *mailserversStatements) InsertMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
mailservers []gomatrixserverlib.ServerName,
) error {
for _, mailserver := range mailservers {
stmt := sqlutil.TxStmt(txn, s.insertMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
return err
}
}
return nil
}
func (s *mailserversStatements) SelectMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectMailserversStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectMailservers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var mailserver string
if err = rows.Scan(&mailserver); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(mailserver))
}
return result, nil
}
func (s *mailserversStatements) DeleteMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
mailservers []gomatrixserverlib.ServerName,
) error {
for _, mailserver := range mailservers {
stmt := sqlutil.TxStmt(txn, s.deleteMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
return err
}
}
return nil
}
func (s *mailserversStatements) DeleteAllMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View file

@ -70,6 +70,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
mailservers, err := NewPostgresMailserversTable(d.db)
if err != nil {
return nil, err
}
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
if err != nil {
return nil, err
@ -114,6 +118,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationQueueTransactions: queueTransactions,
FederationTransactionJSON: transactionJSON,
FederationBlacklist: blacklist,
FederationMailservers: mailservers,
FederationInboundPeeks: inboundPeeks,
FederationOutboundPeeks: outboundPeeks,
NotaryServerKeysJSON: notaryJSON,

View file

@ -38,6 +38,7 @@ type Database struct {
FederationQueueJSON tables.FederationQueueJSON
FederationJoinedHosts tables.FederationJoinedHosts
FederationBlacklist tables.FederationBlacklist
FederationMailservers tables.FederationMailservers
FederationOutboundPeeks tables.FederationOutboundPeeks
FederationInboundPeeks tables.FederationInboundPeeks
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
@ -177,6 +178,28 @@ func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName)
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
}
func (d *Database) AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationMailservers.InsertMailservers(context.TODO(), txn, serverName, mailservers)
})
}
func (d *Database) GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) {
return d.FederationMailservers.SelectMailservers(context.TODO(), nil, serverName)
}
func (d *Database) RemoveMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationMailservers.DeleteMailservers(context.TODO(), txn, serverName, mailservers)
})
}
func (d *Database) RemoveAllMailserversForServer(serverName gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationMailservers.DeleteAllMailservers(context.TODO(), txn, serverName)
})
}
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)

View file

@ -0,0 +1,147 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const mailserversSchema = `
CREATE TABLE IF NOT EXISTS federationsender_mailservers (
-- The destination server name
server_name TEXT NOT NULL,
-- The mailserver name for a given destination
mailserver_name TEXT NOT NULL,
UNIQUE (server_name, mailserver_name)
);
CREATE INDEX IF NOT EXISTS federationsender_mailservers_server_name_idx
ON federationsender_mailservers (server_name);
`
const insertMailserversSQL = "" +
"INSERT INTO federationsender_mailservers (server_name, mailserver_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectMailserversSQL = "" +
"SELECT mailserver_name FROM federationsender_mailservers WHERE server_name = $1"
const deleteMailserversSQL = "" +
"DELETE FROM federationsender_mailservers WHERE server_name = $1 AND mailserver_name IN ($2)"
const deleteAllMailserversSQL = "" +
"DELETE FROM federationsender_mailservers WHERE server_name = $1"
type mailserversStatements struct {
db *sql.DB
insertMailserversStmt *sql.Stmt
selectMailserversStmt *sql.Stmt
deleteMailserversStmt *sql.Stmt
deleteAllMailserversStmt *sql.Stmt
}
func NewSQLiteMailserversTable(db *sql.DB) (s *mailserversStatements, err error) {
s = &mailserversStatements{
db: db,
}
_, err = db.Exec(mailserversSchema)
if err != nil {
return
}
if s.insertMailserversStmt, err = db.Prepare(insertMailserversSQL); err != nil {
return
}
if s.selectMailserversStmt, err = db.Prepare(selectMailserversSQL); err != nil {
return
}
if s.deleteMailserversStmt, err = db.Prepare(deleteMailserversSQL); err != nil {
return
}
if s.deleteAllMailserversStmt, err = db.Prepare(deleteAllMailserversSQL); err != nil {
return
}
return
}
func (s *mailserversStatements) InsertMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
mailservers []gomatrixserverlib.ServerName,
) error {
for _, mailserver := range mailservers {
stmt := sqlutil.TxStmt(txn, s.insertMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
return err
}
}
return nil
}
func (s *mailserversStatements) SelectMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectMailserversStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectMailservers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var mailserver string
if err = rows.Scan(&mailserver); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(mailserver))
}
return result, nil
}
func (s *mailserversStatements) DeleteMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
mailservers []gomatrixserverlib.ServerName,
) error {
for _, mailserver := range mailservers {
stmt := sqlutil.TxStmt(txn, s.deleteMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
return err
}
}
return nil
}
func (s *mailserversStatements) DeleteAllMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View file

@ -63,6 +63,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
mailservers, err := NewSQLiteMailserversTable(d.db)
if err != nil {
return nil, err
}
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
if err != nil {
return nil, err
@ -107,6 +111,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationQueueTransactions: queueTransactions,
FederationTransactionJSON: transactionJSON,
FederationBlacklist: blacklist,
FederationMailservers: mailservers,
FederationOutboundPeeks: outboundPeeks,
FederationInboundPeeks: inboundPeeks,
NotaryServerKeysJSON: notaryKeys,

View file

@ -81,6 +81,13 @@ type FederationBlacklist interface {
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
}
type FederationMailservers interface {
InsertMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error
SelectMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
DeleteMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error
DeleteAllMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
}
type FederationOutboundPeeks interface {
InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)

View file

@ -0,0 +1,167 @@
package tables_test
import (
"context"
"database/sql"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
server1 = "server1"
server2 = "server2"
server3 = "server3"
)
type MailserversDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationMailservers
}
func mustCreateMailserversTable(t *testing.T, dbType test.DBType) (database MailserversDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationMailservers
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresMailserversTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteMailserversTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = MailserversDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func Equal(a, b []gomatrixserverlib.ServerName) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func TestShouldInsertMailservers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateMailserversTable(t, dbType)
defer close()
expectedMailservers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
mailservers, err := db.Table.SelectMailservers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers)
}
})
}
func TestShouldDeleteCorrectMailservers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateMailserversTable(t, dbType)
defer close()
expectedMailservers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertMailservers(ctx, nil, server2, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteMailservers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2})
if err != nil {
t.Fatalf("Failed deleting mailservers for %s: %s", server1, err.Error())
}
expectedMailservers1 := []gomatrixserverlib.ServerName{server3}
mailservers, err := db.Table.SelectMailservers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers1, mailservers)
}
mailservers, err = db.Table.SelectMailservers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers)
}
})
}
func TestShouldDeleteAllMailservers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateMailserversTable(t, dbType)
defer close()
expectedMailservers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertMailservers(ctx, nil, server2, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteAllMailservers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed deleting mailservers for %s: %s", server1, err.Error())
}
expectedMailservers1 := []gomatrixserverlib.ServerName{}
mailservers, err := db.Table.SelectMailservers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers1, mailservers)
}
mailservers, err = db.Table.SelectMailservers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers)
}
})
}