mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-16 18:43:10 -06:00
Add database table for mailservers
This commit is contained in:
parent
2df4b0750e
commit
0520a9b0ed
|
|
@ -62,6 +62,11 @@ type Database interface {
|
||||||
RemoveAllServersFromBlacklist() error
|
RemoveAllServersFromBlacklist() error
|
||||||
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
|
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
|
||||||
|
|
||||||
|
AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error
|
||||||
|
GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
RemoveMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error
|
||||||
|
RemoveAllMailserversForServer(serverName gomatrixserverlib.ServerName) error
|
||||||
|
|
||||||
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
||||||
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
||||||
GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error)
|
GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error)
|
||||||
|
|
|
||||||
147
federationapi/storage/postgres/mailservers_table.go
Normal file
147
federationapi/storage/postgres/mailservers_table.go
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const mailserversSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS federationsender_mailservers (
|
||||||
|
-- The destination server name
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
|
-- The mailserver name for a given destination
|
||||||
|
mailserver_name TEXT NOT NULL,
|
||||||
|
UNIQUE (server_name, mailserver_name)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS federationsender_mailservers_server_name_idx
|
||||||
|
ON federationsender_mailservers (server_name);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertMailserversSQL = "" +
|
||||||
|
"INSERT INTO federationsender_mailservers (server_name, mailserver_name) VALUES ($1, $2)" +
|
||||||
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
|
const selectMailserversSQL = "" +
|
||||||
|
"SELECT mailserver_name FROM federationsender_mailservers WHERE server_name = $1"
|
||||||
|
|
||||||
|
const deleteMailserversSQL = "" +
|
||||||
|
"DELETE FROM federationsender_mailservers WHERE server_name = $1 AND mailserver_name IN ($2)"
|
||||||
|
|
||||||
|
const deleteAllMailserversSQL = "" +
|
||||||
|
"DELETE FROM federationsender_mailservers WHERE server_name = $1"
|
||||||
|
|
||||||
|
type mailserversStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
insertMailserversStmt *sql.Stmt
|
||||||
|
selectMailserversStmt *sql.Stmt
|
||||||
|
deleteMailserversStmt *sql.Stmt
|
||||||
|
deleteAllMailserversStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresMailserversTable(db *sql.DB) (s *mailserversStatements, err error) {
|
||||||
|
s = &mailserversStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err = db.Exec(mailserversSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.insertMailserversStmt, err = db.Prepare(insertMailserversSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectMailserversStmt, err = db.Prepare(selectMailserversSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteMailserversStmt, err = db.Prepare(deleteMailserversSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteAllMailserversStmt, err = db.Prepare(deleteAllMailserversSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mailserversStatements) InsertMailservers(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
mailservers []gomatrixserverlib.ServerName,
|
||||||
|
) error {
|
||||||
|
for _, mailserver := range mailservers {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertMailserversStmt)
|
||||||
|
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mailserversStatements) SelectMailservers(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectMailserversStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, serverName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectMailservers: rows.close() failed")
|
||||||
|
|
||||||
|
var result []gomatrixserverlib.ServerName
|
||||||
|
for rows.Next() {
|
||||||
|
var mailserver string
|
||||||
|
if err = rows.Scan(&mailserver); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, gomatrixserverlib.ServerName(mailserver))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mailserversStatements) DeleteMailservers(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
mailservers []gomatrixserverlib.ServerName,
|
||||||
|
) error {
|
||||||
|
for _, mailserver := range mailservers {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteMailserversStmt)
|
||||||
|
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mailserversStatements) DeleteAllMailservers(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteAllMailserversStmt)
|
||||||
|
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -70,6 +70,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
mailservers, err := NewPostgresMailserversTable(d.db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
|
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -114,6 +118,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
FederationQueueTransactions: queueTransactions,
|
FederationQueueTransactions: queueTransactions,
|
||||||
FederationTransactionJSON: transactionJSON,
|
FederationTransactionJSON: transactionJSON,
|
||||||
FederationBlacklist: blacklist,
|
FederationBlacklist: blacklist,
|
||||||
|
FederationMailservers: mailservers,
|
||||||
FederationInboundPeeks: inboundPeeks,
|
FederationInboundPeeks: inboundPeeks,
|
||||||
FederationOutboundPeeks: outboundPeeks,
|
FederationOutboundPeeks: outboundPeeks,
|
||||||
NotaryServerKeysJSON: notaryJSON,
|
NotaryServerKeysJSON: notaryJSON,
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ type Database struct {
|
||||||
FederationQueueJSON tables.FederationQueueJSON
|
FederationQueueJSON tables.FederationQueueJSON
|
||||||
FederationJoinedHosts tables.FederationJoinedHosts
|
FederationJoinedHosts tables.FederationJoinedHosts
|
||||||
FederationBlacklist tables.FederationBlacklist
|
FederationBlacklist tables.FederationBlacklist
|
||||||
|
FederationMailservers tables.FederationMailservers
|
||||||
FederationOutboundPeeks tables.FederationOutboundPeeks
|
FederationOutboundPeeks tables.FederationOutboundPeeks
|
||||||
FederationInboundPeeks tables.FederationInboundPeeks
|
FederationInboundPeeks tables.FederationInboundPeeks
|
||||||
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
|
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
|
||||||
|
|
@ -177,6 +178,28 @@ func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName)
|
||||||
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
|
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.FederationMailservers.InsertMailservers(context.TODO(), txn, serverName, mailservers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
return d.FederationMailservers.SelectMailservers(context.TODO(), nil, serverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) RemoveMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.FederationMailservers.DeleteMailservers(context.TODO(), txn, serverName, mailservers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) RemoveAllMailserversForServer(serverName gomatrixserverlib.ServerName) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.FederationMailservers.DeleteAllMailservers(context.TODO(), txn, serverName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
||||||
|
|
|
||||||
147
federationapi/storage/sqlite3/mailservers_table.go
Normal file
147
federationapi/storage/sqlite3/mailservers_table.go
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
const mailserversSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS federationsender_mailservers (
|
||||||
|
-- The destination server name
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
|
-- The mailserver name for a given destination
|
||||||
|
mailserver_name TEXT NOT NULL,
|
||||||
|
UNIQUE (server_name, mailserver_name)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS federationsender_mailservers_server_name_idx
|
||||||
|
ON federationsender_mailservers (server_name);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertMailserversSQL = "" +
|
||||||
|
"INSERT INTO federationsender_mailservers (server_name, mailserver_name) VALUES ($1, $2)" +
|
||||||
|
" ON CONFLICT DO NOTHING"
|
||||||
|
|
||||||
|
const selectMailserversSQL = "" +
|
||||||
|
"SELECT mailserver_name FROM federationsender_mailservers WHERE server_name = $1"
|
||||||
|
|
||||||
|
const deleteMailserversSQL = "" +
|
||||||
|
"DELETE FROM federationsender_mailservers WHERE server_name = $1 AND mailserver_name IN ($2)"
|
||||||
|
|
||||||
|
const deleteAllMailserversSQL = "" +
|
||||||
|
"DELETE FROM federationsender_mailservers WHERE server_name = $1"
|
||||||
|
|
||||||
|
type mailserversStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
insertMailserversStmt *sql.Stmt
|
||||||
|
selectMailserversStmt *sql.Stmt
|
||||||
|
deleteMailserversStmt *sql.Stmt
|
||||||
|
deleteAllMailserversStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSQLiteMailserversTable(db *sql.DB) (s *mailserversStatements, err error) {
|
||||||
|
s = &mailserversStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err = db.Exec(mailserversSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.insertMailserversStmt, err = db.Prepare(insertMailserversSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectMailserversStmt, err = db.Prepare(selectMailserversSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteMailserversStmt, err = db.Prepare(deleteMailserversSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.deleteAllMailserversStmt, err = db.Prepare(deleteAllMailserversSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mailserversStatements) InsertMailservers(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
mailservers []gomatrixserverlib.ServerName,
|
||||||
|
) error {
|
||||||
|
for _, mailserver := range mailservers {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertMailserversStmt)
|
||||||
|
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mailserversStatements) SelectMailservers(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectMailserversStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, serverName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectMailservers: rows.close() failed")
|
||||||
|
|
||||||
|
var result []gomatrixserverlib.ServerName
|
||||||
|
for rows.Next() {
|
||||||
|
var mailserver string
|
||||||
|
if err = rows.Scan(&mailserver); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, gomatrixserverlib.ServerName(mailserver))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mailserversStatements) DeleteMailservers(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
mailservers []gomatrixserverlib.ServerName,
|
||||||
|
) error {
|
||||||
|
for _, mailserver := range mailservers {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteMailserversStmt)
|
||||||
|
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mailserversStatements) DeleteAllMailservers(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteAllMailserversStmt)
|
||||||
|
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -63,6 +63,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
mailservers, err := NewSQLiteMailserversTable(d.db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
|
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -107,6 +111,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
FederationQueueTransactions: queueTransactions,
|
FederationQueueTransactions: queueTransactions,
|
||||||
FederationTransactionJSON: transactionJSON,
|
FederationTransactionJSON: transactionJSON,
|
||||||
FederationBlacklist: blacklist,
|
FederationBlacklist: blacklist,
|
||||||
|
FederationMailservers: mailservers,
|
||||||
FederationOutboundPeeks: outboundPeeks,
|
FederationOutboundPeeks: outboundPeeks,
|
||||||
FederationInboundPeeks: inboundPeeks,
|
FederationInboundPeeks: inboundPeeks,
|
||||||
NotaryServerKeysJSON: notaryKeys,
|
NotaryServerKeysJSON: notaryKeys,
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,13 @@ type FederationBlacklist interface {
|
||||||
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
|
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FederationMailservers interface {
|
||||||
|
InsertMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error
|
||||||
|
SelectMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
DeleteMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error
|
||||||
|
DeleteAllMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
|
||||||
|
}
|
||||||
|
|
||||||
type FederationOutboundPeeks interface {
|
type FederationOutboundPeeks interface {
|
||||||
InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
||||||
RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
||||||
|
|
|
||||||
167
federationapi/storage/tables/mailservers_table_test.go
Normal file
167
federationapi/storage/tables/mailservers_table_test.go
Normal file
|
|
@ -0,0 +1,167 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
server1 = "server1"
|
||||||
|
server2 = "server2"
|
||||||
|
server3 = "server3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MailserversDatabase struct {
|
||||||
|
DB *sql.DB
|
||||||
|
Writer sqlutil.Writer
|
||||||
|
Table tables.FederationMailservers
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateMailserversTable(t *testing.T, dbType test.DBType) (database MailserversDatabase, close func()) {
|
||||||
|
t.Helper()
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
var tab tables.FederationMailservers
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
tab, err = postgres.NewPostgresMailserversTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
tab, err = sqlite3.NewSQLiteMailserversTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
database = MailserversDatabase{
|
||||||
|
DB: db,
|
||||||
|
Writer: sqlutil.NewDummyWriter(),
|
||||||
|
Table: tab,
|
||||||
|
}
|
||||||
|
return database, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func Equal(a, b []gomatrixserverlib.ServerName) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range a {
|
||||||
|
if v != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldInsertMailservers(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateMailserversTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
expectedMailservers := []gomatrixserverlib.ServerName{server2, server3}
|
||||||
|
|
||||||
|
err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
mailservers, err := db.Table.SelectMailservers(ctx, nil, server1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !Equal(mailservers, expectedMailservers) {
|
||||||
|
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldDeleteCorrectMailservers(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateMailserversTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
expectedMailservers := []gomatrixserverlib.ServerName{server2, server3}
|
||||||
|
|
||||||
|
err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||||
|
}
|
||||||
|
err = db.Table.InsertMailservers(ctx, nil, server2, expectedMailservers)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Table.DeleteMailservers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed deleting mailservers for %s: %s", server1, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMailservers1 := []gomatrixserverlib.ServerName{server3}
|
||||||
|
mailservers, err := db.Table.SelectMailservers(ctx, nil, server1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
|
||||||
|
}
|
||||||
|
if !Equal(mailservers, expectedMailservers1) {
|
||||||
|
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers1, mailservers)
|
||||||
|
}
|
||||||
|
mailservers, err = db.Table.SelectMailservers(ctx, nil, server2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
|
||||||
|
}
|
||||||
|
if !Equal(mailservers, expectedMailservers) {
|
||||||
|
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldDeleteAllMailservers(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateMailserversTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
expectedMailservers := []gomatrixserverlib.ServerName{server2, server3}
|
||||||
|
|
||||||
|
err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||||
|
}
|
||||||
|
err = db.Table.InsertMailservers(ctx, nil, server2, expectedMailservers)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Table.DeleteAllMailservers(ctx, nil, server1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed deleting mailservers for %s: %s", server1, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMailservers1 := []gomatrixserverlib.ServerName{}
|
||||||
|
mailservers, err := db.Table.SelectMailservers(ctx, nil, server1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
|
||||||
|
}
|
||||||
|
if !Equal(mailservers, expectedMailservers1) {
|
||||||
|
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers1, mailservers)
|
||||||
|
}
|
||||||
|
mailservers, err = db.Table.SelectMailservers(ctx, nil, server2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
|
||||||
|
}
|
||||||
|
if !Equal(mailservers, expectedMailservers) {
|
||||||
|
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue