diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 2597f0830..c7dee9fd4 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -62,6 +62,11 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) + AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error + GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + RemoveMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error + RemoveAllMailserversForServer(serverName gomatrixserverlib.ServerName) error + AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) diff --git a/federationapi/storage/postgres/mailservers_table.go b/federationapi/storage/postgres/mailservers_table.go new file mode 100644 index 000000000..b8f52c5fe --- /dev/null +++ b/federationapi/storage/postgres/mailservers_table.go @@ -0,0 +1,147 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const mailserversSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_mailservers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The mailserver name for a given destination + mailserver_name TEXT NOT NULL, + UNIQUE (server_name, mailserver_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_mailservers_server_name_idx + ON federationsender_mailservers (server_name); +` + +const insertMailserversSQL = "" + + "INSERT INTO federationsender_mailservers (server_name, mailserver_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectMailserversSQL = "" + + "SELECT mailserver_name FROM federationsender_mailservers WHERE server_name = $1" + +const deleteMailserversSQL = "" + + "DELETE FROM federationsender_mailservers WHERE server_name = $1 AND mailserver_name IN ($2)" + +const deleteAllMailserversSQL = "" + + "DELETE FROM federationsender_mailservers WHERE server_name = $1" + +type mailserversStatements struct { + db *sql.DB + insertMailserversStmt *sql.Stmt + selectMailserversStmt *sql.Stmt + deleteMailserversStmt *sql.Stmt + deleteAllMailserversStmt *sql.Stmt +} + +func NewPostgresMailserversTable(db *sql.DB) (s *mailserversStatements, err error) { + s = &mailserversStatements{ + db: db, + } + _, err = db.Exec(mailserversSchema) + if err != nil { + return + } + + if s.insertMailserversStmt, err = db.Prepare(insertMailserversSQL); err != nil { + return + } + if s.selectMailserversStmt, err = db.Prepare(selectMailserversSQL); err != nil { + return + } + if s.deleteMailserversStmt, err = db.Prepare(deleteMailserversSQL); err != nil { + return + } + if s.deleteAllMailserversStmt, err = db.Prepare(deleteAllMailserversSQL); err != nil { + return + } + return +} + +func (s *mailserversStatements) InsertMailservers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + mailservers []gomatrixserverlib.ServerName, +) error { + for _, mailserver := range mailservers { + stmt := sqlutil.TxStmt(txn, s.insertMailserversStmt) + if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil { + return err + } + } + return nil +} + +func (s *mailserversStatements) SelectMailservers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectMailserversStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectMailservers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var mailserver string + if err = rows.Scan(&mailserver); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(mailserver)) + } + return result, nil +} + +func (s *mailserversStatements) DeleteMailservers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + mailservers []gomatrixserverlib.ServerName, +) error { + for _, mailserver := range mailservers { + stmt := sqlutil.TxStmt(txn, s.deleteMailserversStmt) + if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil { + return err + } + } + return nil +} + +func (s *mailserversStatements) DeleteAllMailservers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllMailserversStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 8e38603ed..d32528f56 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -70,6 +70,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + mailservers, err := NewPostgresMailserversTable(d.db) + if err != nil { + return nil, err + } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err @@ -114,6 +118,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueTransactions: queueTransactions, FederationTransactionJSON: transactionJSON, FederationBlacklist: blacklist, + FederationMailservers: mailservers, FederationInboundPeeks: inboundPeeks, FederationOutboundPeeks: outboundPeeks, NotaryServerKeysJSON: notaryJSON, diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index dae25b197..a85dcff0b 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -38,6 +38,7 @@ type Database struct { FederationQueueJSON tables.FederationQueueJSON FederationJoinedHosts tables.FederationJoinedHosts FederationBlacklist tables.FederationBlacklist + FederationMailservers tables.FederationMailservers FederationOutboundPeeks tables.FederationOutboundPeeks FederationInboundPeeks tables.FederationInboundPeeks NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON @@ -177,6 +178,28 @@ func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } +func (d *Database) AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationMailservers.InsertMailservers(context.TODO(), txn, serverName, mailservers) + }) +} + +func (d *Database) GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) { + return d.FederationMailservers.SelectMailservers(context.TODO(), nil, serverName) +} + +func (d *Database) RemoveMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationMailservers.DeleteMailservers(context.TODO(), txn, serverName, mailservers) + }) +} + +func (d *Database) RemoveAllMailserversForServer(serverName gomatrixserverlib.ServerName) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationMailservers.DeleteAllMailservers(context.TODO(), txn, serverName) + }) +} + func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) diff --git a/federationapi/storage/sqlite3/mailservers_table.go b/federationapi/storage/sqlite3/mailservers_table.go new file mode 100644 index 000000000..e40738476 --- /dev/null +++ b/federationapi/storage/sqlite3/mailservers_table.go @@ -0,0 +1,147 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const mailserversSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_mailservers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The mailserver name for a given destination + mailserver_name TEXT NOT NULL, + UNIQUE (server_name, mailserver_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_mailservers_server_name_idx + ON federationsender_mailservers (server_name); +` + +const insertMailserversSQL = "" + + "INSERT INTO federationsender_mailservers (server_name, mailserver_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectMailserversSQL = "" + + "SELECT mailserver_name FROM federationsender_mailservers WHERE server_name = $1" + +const deleteMailserversSQL = "" + + "DELETE FROM federationsender_mailservers WHERE server_name = $1 AND mailserver_name IN ($2)" + +const deleteAllMailserversSQL = "" + + "DELETE FROM federationsender_mailservers WHERE server_name = $1" + +type mailserversStatements struct { + db *sql.DB + insertMailserversStmt *sql.Stmt + selectMailserversStmt *sql.Stmt + deleteMailserversStmt *sql.Stmt + deleteAllMailserversStmt *sql.Stmt +} + +func NewSQLiteMailserversTable(db *sql.DB) (s *mailserversStatements, err error) { + s = &mailserversStatements{ + db: db, + } + _, err = db.Exec(mailserversSchema) + if err != nil { + return + } + + if s.insertMailserversStmt, err = db.Prepare(insertMailserversSQL); err != nil { + return + } + if s.selectMailserversStmt, err = db.Prepare(selectMailserversSQL); err != nil { + return + } + if s.deleteMailserversStmt, err = db.Prepare(deleteMailserversSQL); err != nil { + return + } + if s.deleteAllMailserversStmt, err = db.Prepare(deleteAllMailserversSQL); err != nil { + return + } + return +} + +func (s *mailserversStatements) InsertMailservers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + mailservers []gomatrixserverlib.ServerName, +) error { + for _, mailserver := range mailservers { + stmt := sqlutil.TxStmt(txn, s.insertMailserversStmt) + if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil { + return err + } + } + return nil +} + +func (s *mailserversStatements) SelectMailservers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectMailserversStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectMailservers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var mailserver string + if err = rows.Scan(&mailserver); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(mailserver)) + } + return result, nil +} + +func (s *mailserversStatements) DeleteMailservers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + mailservers []gomatrixserverlib.ServerName, +) error { + for _, mailserver := range mailservers { + stmt := sqlutil.TxStmt(txn, s.deleteMailserversStmt) + if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil { + return err + } + } + return nil +} + +func (s *mailserversStatements) DeleteAllMailservers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllMailserversStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index e8fa9a0b6..e1120c3a5 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -63,6 +63,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + mailservers, err := NewSQLiteMailserversTable(d.db) + if err != nil { + return nil, err + } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err @@ -107,6 +111,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueTransactions: queueTransactions, FederationTransactionJSON: transactionJSON, FederationBlacklist: blacklist, + FederationMailservers: mailservers, FederationOutboundPeeks: outboundPeeks, FederationInboundPeeks: inboundPeeks, NotaryServerKeysJSON: notaryKeys, diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 37c7bb299..e77163221 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -81,6 +81,13 @@ type FederationBlacklist interface { DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } +type FederationMailservers interface { + InsertMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error + SelectMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + DeleteMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error + DeleteAllMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error +} + type FederationOutboundPeeks interface { InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) diff --git a/federationapi/storage/tables/mailservers_table_test.go b/federationapi/storage/tables/mailservers_table_test.go new file mode 100644 index 000000000..2def88d97 --- /dev/null +++ b/federationapi/storage/tables/mailservers_table_test.go @@ -0,0 +1,167 @@ +package tables_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + server1 = "server1" + server2 = "server2" + server3 = "server3" +) + +type MailserversDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.FederationMailservers +} + +func mustCreateMailserversTable(t *testing.T, dbType test.DBType) (database MailserversDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.FederationMailservers + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresMailserversTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteMailserversTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = MailserversDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func Equal(a, b []gomatrixserverlib.ServerName) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func TestShouldInsertMailservers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateMailserversTable(t, dbType) + defer close() + expectedMailservers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + mailservers, err := db.Table.SelectMailservers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error()) + } + + if !Equal(mailservers, expectedMailservers) { + t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers) + } + }) +} + +func TestShouldDeleteCorrectMailservers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateMailserversTable(t, dbType) + defer close() + expectedMailservers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertMailservers(ctx, nil, server2, expectedMailservers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteMailservers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2}) + if err != nil { + t.Fatalf("Failed deleting mailservers for %s: %s", server1, err.Error()) + } + + expectedMailservers1 := []gomatrixserverlib.ServerName{server3} + mailservers, err := db.Table.SelectMailservers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error()) + } + if !Equal(mailservers, expectedMailservers1) { + t.Fatalf("Expected: %v \nActual: %v", expectedMailservers1, mailservers) + } + mailservers, err = db.Table.SelectMailservers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error()) + } + if !Equal(mailservers, expectedMailservers) { + t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers) + } + }) +} + +func TestShouldDeleteAllMailservers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateMailserversTable(t, dbType) + defer close() + expectedMailservers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertMailservers(ctx, nil, server2, expectedMailservers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteAllMailservers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed deleting mailservers for %s: %s", server1, err.Error()) + } + + expectedMailservers1 := []gomatrixserverlib.ServerName{} + mailservers, err := db.Table.SelectMailservers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error()) + } + if !Equal(mailservers, expectedMailservers1) { + t.Fatalf("Expected: %v \nActual: %v", expectedMailservers1, mailservers) + } + mailservers, err = db.Table.SelectMailservers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error()) + } + if !Equal(mailservers, expectedMailservers) { + t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers) + } + }) +}