Rename mailserver to relay server

This commit is contained in:
Devon Hudson 2022-12-09 13:06:16 -07:00
parent 0ffa0a5317
commit ee8a1c5680
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
24 changed files with 678 additions and 681 deletions

View file

@ -67,27 +67,27 @@ import (
) )
const ( const (
PeerTypeRemote = pineconeRouter.PeerTypeRemote PeerTypeRemote = pineconeRouter.PeerTypeRemote
PeerTypeMulticast = pineconeRouter.PeerTypeMulticast PeerTypeMulticast = pineconeRouter.PeerTypeMulticast
PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth
PeerTypeBonjour = pineconeRouter.PeerTypeBonjour PeerTypeBonjour = pineconeRouter.PeerTypeBonjour
mailserverRetryInterval = time.Second * 30 relayServerRetryInterval = time.Second * 30
) )
type DendriteMonolith struct { type DendriteMonolith struct {
logger logrus.Logger logger logrus.Logger
PineconeRouter *pineconeRouter.Router PineconeRouter *pineconeRouter.Router
PineconeMulticast *pineconeMulticast.Multicast PineconeMulticast *pineconeMulticast.Multicast
PineconeQUIC *pineconeSessions.Sessions PineconeQUIC *pineconeSessions.Sessions
PineconeManager *pineconeConnections.ConnectionManager PineconeManager *pineconeConnections.ConnectionManager
StorageDirectory string StorageDirectory string
CacheDirectory string CacheDirectory string
listener net.Listener listener net.Listener
httpServer *http.Server httpServer *http.Server
processContext *process.ProcessContext processContext *process.ProcessContext
userAPI userapiAPI.UserInternalAPI userAPI userapiAPI.UserInternalAPI
federationAPI api.FederationInternalAPI federationAPI api.FederationInternalAPI
mailserversQueried map[gomatrixserverlib.ServerName]bool relayServersQueried map[gomatrixserverlib.ServerName]bool
} }
func (m *DendriteMonolith) PublicKey() string { func (m *DendriteMonolith) PublicKey() string {
@ -439,30 +439,30 @@ func (m *DendriteMonolith) Start() {
go func(ch <-chan pineconeEvents.Event) { go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events") eLog := logrus.WithField("pinecone", "events")
mailserverSyncRunning := atomic.NewBool(false) relayServerSyncRunning := atomic.NewBool(false)
stopMailserverSync := make(chan bool) stopRelayServerSync := make(chan bool)
// Setup mailserver info // Setup relay server info
request := api.QueryMailserversRequest{Server: gomatrixserverlib.ServerName(m.PublicKey())} request := api.QueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.PublicKey())}
response := api.QueryMailserversResponse{} response := api.QueryRelayServersResponse{}
err := m.federationAPI.QueryMailservers(m.processContext.Context(), &request, &response) err := m.federationAPI.QueryRelayServers(m.processContext.Context(), &request, &response)
if err != nil { if err != nil {
// TODO // TODO
} }
m.mailserversQueried = make(map[gomatrixserverlib.ServerName]bool) m.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool)
for _, server := range response.Mailservers { for _, server := range response.RelayServers {
m.mailserversQueried[server] = false m.relayServersQueried[server] = false
} }
for event := range ch { for event := range ch {
switch e := event.(type) { switch e := event.(type) {
case pineconeEvents.PeerAdded: case pineconeEvents.PeerAdded:
if !mailserverSyncRunning.Load() { if !relayServerSyncRunning.Load() {
go m.syncMailservers(stopMailserverSync, *mailserverSyncRunning) go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
} }
case pineconeEvents.PeerRemoved: case pineconeEvents.PeerRemoved:
if mailserverSyncRunning.Load() && m.PineconeRouter.PeerCount(-1) == 0 { if relayServerSyncRunning.Load() && m.PineconeRouter.PeerCount(-1) == 0 {
stopMailserverSync <- true stopRelayServerSync <- true
} }
case pineconeEvents.TreeParentUpdate: case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate: case pineconeEvents.SnakeDescUpdate:
@ -486,23 +486,23 @@ func (m *DendriteMonolith) Start() {
}(pineconeEventChannel) }(pineconeEventChannel)
} }
func (m *DendriteMonolith) syncMailservers(stop <-chan bool, running atomic.Bool) { func (m *DendriteMonolith) syncRelayServers(stop <-chan bool, running atomic.Bool) {
defer running.Store(false) defer running.Store(false)
t := time.NewTimer(mailserverRetryInterval) t := time.NewTimer(relayServerRetryInterval)
for { for {
mailserversToQuery := []gomatrixserverlib.ServerName{} relayServersToQuery := []gomatrixserverlib.ServerName{}
for server, complete := range m.mailserversQueried { for server, complete := range m.relayServersQueried {
if !complete { if !complete {
mailserversToQuery = append(mailserversToQuery, server) relayServersToQuery = append(relayServersToQuery, server)
} }
} }
if len(mailserversToQuery) == 0 { if len(relayServersToQuery) == 0 {
// All mailservers have been synced. // All relay servers have been synced.
return return
} }
m.queryMailservers(mailserversToQuery) m.queryRelayServers(relayServersToQuery)
t.Reset(mailserverRetryInterval) t.Reset(relayServerRetryInterval)
select { select {
case <-stop: case <-stop:
@ -515,13 +515,13 @@ func (m *DendriteMonolith) syncMailservers(stop <-chan bool, running atomic.Bool
} }
} }
func (m *DendriteMonolith) queryMailservers(mailservers []gomatrixserverlib.ServerName) { func (m *DendriteMonolith) queryRelayServers(relayServers []gomatrixserverlib.ServerName) {
for _, server := range mailservers { for _, server := range relayServers {
request := api.PerformMailserverSyncRequest{Mailserver: server} request := api.PerformRelayServerSyncRequest{RelayServer: server}
response := api.PerformMailserverSyncResponse{} response := api.PerformRelayServerSyncResponse{}
err := m.federationAPI.PerformMailserverSync(m.processContext.Context(), &request, &response) err := m.federationAPI.PerformRelayServerSync(m.processContext.Context(), &request, &response)
if err == nil { if err == nil {
m.mailserversQueried[server] = true m.relayServersQueried[server] = true
} }
} }
} }

View file

@ -1,7 +1,7 @@
## Relay Server Architecture ## Relay Server Architecture
Relay Servers function similar to the way physical mail drop boxes do. Relay Servers function similar to the way physical mail drop boxes do.
A node can have many associated relay servers. Matrix events can be sent to them instead of to the destination node, and the destination node will eventually retrieve them from the mailserver. A node can have many associated relay servers. Matrix events can be sent to them instead of to the destination node, and the destination node will eventually retrieve them from the relay server.
Nodes that want to send events to an offline node need to know what relay servers are associated with their intended destination. Nodes that want to send events to an offline node need to know what relay servers are associated with their intended destination.
Currently this is manually configured in the dendrite database. In the future this information could be configurable in the app and shared automatically via other means. Currently this is manually configured in the dendrite database. In the future this information could be configurable in the app and shared automatically via other means.

View file

@ -67,7 +67,7 @@ var (
instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)")
) )
const mailserverRetryInterval = time.Second * 30 const relayServerRetryInterval = time.Second * 30
// nolint:gocyclo // nolint:gocyclo
func main() { func main() {
@ -308,26 +308,26 @@ func main() {
go func(ch <-chan pineconeEvents.Event) { go func(ch <-chan pineconeEvents.Event) {
eLog := logrus.WithField("pinecone", "events") eLog := logrus.WithField("pinecone", "events")
mailserverSyncRunning := atomic.NewBool(false) relayServerSyncRunning := atomic.NewBool(false)
stopMailserverSync := make(chan bool) stopRelayServerSync := make(chan bool)
m := MailserverRetriever{ m := RelayServerRetriever{
Context: context.Background(), Context: context.Background(),
ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()), ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()),
FederationAPI: fsAPI, FederationAPI: fsAPI,
MailserversQueried: make(map[gomatrixserverlib.ServerName]bool), RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool),
} }
m.InitializeMailservers(eLog) m.InitializeRelayServers(eLog)
for event := range ch { for event := range ch {
switch e := event.(type) { switch e := event.(type) {
case pineconeEvents.PeerAdded: case pineconeEvents.PeerAdded:
if !mailserverSyncRunning.Load() { if !relayServerSyncRunning.Load() {
go m.syncMailservers(stopMailserverSync, *mailserverSyncRunning) go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning)
} }
case pineconeEvents.PeerRemoved: case pineconeEvents.PeerRemoved:
if mailserverSyncRunning.Load() && pRouter.PeerCount(-1) == 0 { if relayServerSyncRunning.Load() && pRouter.PeerCount(-1) == 0 {
stopMailserverSync <- true stopRelayServerSync <- true
} }
case pineconeEvents.TreeParentUpdate: case pineconeEvents.TreeParentUpdate:
case pineconeEvents.SnakeDescUpdate: case pineconeEvents.SnakeDescUpdate:
@ -353,44 +353,44 @@ func main() {
base.WaitForShutdown() base.WaitForShutdown()
} }
type MailserverRetriever struct { type RelayServerRetriever struct {
Context context.Context Context context.Context
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
FederationAPI api.FederationInternalAPI FederationAPI api.FederationInternalAPI
MailserversQueried map[gomatrixserverlib.ServerName]bool RelayServersQueried map[gomatrixserverlib.ServerName]bool
} }
func (m *MailserverRetriever) InitializeMailservers(eLog *logrus.Entry) { func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) {
request := api.QueryMailserversRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} request := api.QueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)}
response := api.QueryMailserversResponse{} response := api.QueryRelayServersResponse{}
err := m.FederationAPI.QueryMailservers(m.Context, &request, &response) err := m.FederationAPI.QueryRelayServers(m.Context, &request, &response)
if err != nil { if err != nil {
// TODO // TODO
} }
for _, server := range response.Mailservers { for _, server := range response.RelayServers {
m.MailserversQueried[server] = false m.RelayServersQueried[server] = false
} }
eLog.Infof("Registered mailservers: %v", response.Mailservers) eLog.Infof("Registered relay servers: %v", response.RelayServers)
} }
func (m *MailserverRetriever) syncMailservers(stop <-chan bool, running atomic.Bool) { func (m *RelayServerRetriever) syncRelayServers(stop <-chan bool, running atomic.Bool) {
defer running.Store(false) defer running.Store(false)
t := time.NewTimer(mailserverRetryInterval) t := time.NewTimer(relayServerRetryInterval)
for { for {
mailserversToQuery := []gomatrixserverlib.ServerName{} relayServersToQuery := []gomatrixserverlib.ServerName{}
for server, complete := range m.MailserversQueried { for server, complete := range m.RelayServersQueried {
if !complete { if !complete {
mailserversToQuery = append(mailserversToQuery, server) relayServersToQuery = append(relayServersToQuery, server)
} }
} }
if len(mailserversToQuery) == 0 { if len(relayServersToQuery) == 0 {
// All mailservers have been synced. // All relay servers have been synced.
return return
} }
m.queryMailservers(mailserversToQuery) m.queryRelayServers(relayServersToQuery)
t.Reset(mailserverRetryInterval) t.Reset(relayServerRetryInterval)
select { select {
case <-stop: case <-stop:
@ -403,16 +403,16 @@ func (m *MailserverRetriever) syncMailservers(stop <-chan bool, running atomic.B
} }
} }
func (m *MailserverRetriever) queryMailservers(mailservers []gomatrixserverlib.ServerName) { func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) {
logrus.Info("querying mailservers for async_events") logrus.Info("querying relay servers for async_events")
for _, server := range mailservers { for _, server := range relayServers {
request := api.PerformMailserverSyncRequest{Mailserver: server} request := api.PerformRelayServerSyncRequest{RelayServer: server}
response := api.PerformMailserverSyncResponse{} response := api.PerformRelayServerSyncResponse{}
err := m.FederationAPI.PerformMailserverSync(m.Context, &request, &response) err := m.FederationAPI.PerformRelayServerSync(m.Context, &request, &response)
if err == nil { if err == nil {
m.MailserversQueried[server] = true m.RelayServersQueried[server] = true
} else { } else {
logrus.Errorf("Failed querying mailserver: %s", err.Error()) logrus.Errorf("Failed querying relay server: %s", err.Error())
} }
} }
} }

View file

@ -18,7 +18,7 @@ type FederationInternalAPI interface {
gomatrixserverlib.KeyDatabase gomatrixserverlib.KeyDatabase
ClientFederationAPI ClientFederationAPI
RoomserverFederationAPI RoomserverFederationAPI
MailserverAPI RelayServerAPI
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
@ -37,20 +37,20 @@ type FederationInternalAPI interface {
response *PerformWakeupServersResponse, response *PerformWakeupServersResponse,
) error ) error
// Mailserver sync api used in the pinecone demos. // Relay Server sync api used in the pinecone demos.
QueryMailservers( QueryRelayServers(
ctx context.Context, ctx context.Context,
request *QueryMailserversRequest, request *QueryRelayServersRequest,
response *QueryMailserversResponse, response *QueryRelayServersResponse,
) error ) error
PerformMailserverSync( PerformRelayServerSync(
ctx context.Context, ctx context.Context,
request *PerformMailserverSyncRequest, request *PerformRelayServerSyncRequest,
response *PerformMailserverSyncResponse, response *PerformRelayServerSyncResponse,
) error ) error
} }
type MailserverAPI interface { type RelayServerAPI interface {
// Store async transactions for forwarding to the destination at a later time. // Store async transactions for forwarding to the destination at a later time.
PerformStoreAsync( PerformStoreAsync(
ctx context.Context, ctx context.Context,
@ -114,7 +114,7 @@ type FederationClient interface {
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
SendAsyncTransaction(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) SendAsyncTransaction(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error)
GetAsyncEvents(ctx context.Context, u gomatrixserverlib.UserID, mailserver gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetAsyncEvents, err error) GetAsyncEvents(ctx context.Context, u gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetAsyncEvents, err error)
// Perform operations // Perform operations
LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error)
@ -265,19 +265,19 @@ type InputPublicKeysRequest struct {
type InputPublicKeysResponse struct { type InputPublicKeysResponse struct {
} }
type QueryMailserversRequest struct { type QueryRelayServersRequest struct {
Server gomatrixserverlib.ServerName Server gomatrixserverlib.ServerName
} }
type QueryMailserversResponse struct { type QueryRelayServersResponse struct {
Mailservers []gomatrixserverlib.ServerName RelayServers []gomatrixserverlib.ServerName
} }
type PerformMailserverSyncRequest struct { type PerformRelayServerSyncRequest struct {
Mailserver gomatrixserverlib.ServerName RelayServer gomatrixserverlib.ServerName
} }
type PerformMailserverSyncResponse struct { type PerformRelayServerSyncResponse struct {
SyncComplete bool SyncComplete bool
} }

View file

@ -27,8 +27,8 @@ func (r *FederationInternalAPI) PerformDirectoryLookup(
response *api.PerformDirectoryLookupResponse, response *api.PerformDirectoryLookupResponse,
) (err error) { ) (err error) {
stats := r.statistics.ForServer(request.ServerName) stats := r.statistics.ForServer(request.ServerName)
if stats.AssumedOffline() && len(stats.KnownMailservers()) > 0 { if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 {
return fmt.Errorf("not performing federation since server is assumed offline with known mailboxes") return fmt.Errorf("not performing federation since server is assumed offline with known relay servers")
} }
dir, err := r.federation.LookupRoomAlias( dir, err := r.federation.LookupRoomAlias(
@ -152,8 +152,8 @@ func (r *FederationInternalAPI) performJoinUsingServer(
unsigned map[string]interface{}, unsigned map[string]interface{},
) error { ) error {
stats := r.statistics.ForServer(serverName) stats := r.statistics.ForServer(serverName)
if stats.AssumedOffline() && len(stats.KnownMailservers()) > 0 { if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 {
return fmt.Errorf("not performing federation since server is assumed offline with known mailboxes") return fmt.Errorf("not performing federation since server is assumed offline with known relay servers")
} }
_, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID)
@ -420,8 +420,8 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
supportedVersions []gomatrixserverlib.RoomVersion, supportedVersions []gomatrixserverlib.RoomVersion,
) error { ) error {
stats := r.statistics.ForServer(serverName) stats := r.statistics.ForServer(serverName)
if stats.AssumedOffline() && len(stats.KnownMailservers()) > 0 { if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 {
return fmt.Errorf("not performing federation since server is assumed offline with known mailboxes") return fmt.Errorf("not performing federation since server is assumed offline with known relay servers")
} }
// create a unique ID for this peek. // create a unique ID for this peek.
@ -534,7 +534,7 @@ func (r *FederationInternalAPI) PerformLeave(
// successfully completes the make-leave send-leave dance. // successfully completes the make-leave send-leave dance.
for _, serverName := range request.ServerNames { for _, serverName := range request.ServerNames {
stats := r.statistics.ForServer(serverName) stats := r.statistics.ForServer(serverName)
if stats.AssumedOffline() && len(stats.KnownMailservers()) > 0 { if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 {
continue continue
} }
@ -639,8 +639,8 @@ func (r *FederationInternalAPI) PerformInvite(
} }
stats := r.statistics.ForServer(destination) stats := r.statistics.ForServer(destination)
if stats.AssumedOffline() && len(stats.KnownMailservers()) > 0 { if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 {
return fmt.Errorf("not performing federation since server is assumed offline with known mailboxes") return fmt.Errorf("not performing federation since server is assumed offline with known relay servers")
} }
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -833,34 +833,34 @@ func federatedAuthProvider(
} }
} }
// QueryMailservers implements api.FederationInternalAPI // QueryRelayServers implements api.FederationInternalAPI
func (r *FederationInternalAPI) QueryMailservers( func (r *FederationInternalAPI) QueryRelayServers(
ctx context.Context, ctx context.Context,
request *api.QueryMailserversRequest, request *api.QueryRelayServersRequest,
response *api.QueryMailserversResponse, response *api.QueryRelayServersResponse,
) error { ) error {
logrus.Infof("Getting mailservers for: %s", request.Server) logrus.Infof("Getting relay servers for: %s", request.Server)
mailservers, err := r.db.GetMailserversForServer(request.Server) relayServers, err := r.db.GetRelayServersForServer(request.Server)
if err != nil { if err != nil {
return err return err
} }
response.Mailservers = mailservers response.RelayServers = relayServers
return nil return nil
} }
// PerformMailserverSync implements api.FederationInternalAPI // PerformRelayServerSync implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformMailserverSync( func (r *FederationInternalAPI) PerformRelayServerSync(
ctx context.Context, ctx context.Context,
request *api.PerformMailserverSyncRequest, request *api.PerformRelayServerSyncRequest,
response *api.PerformMailserverSyncResponse, response *api.PerformRelayServerSyncResponse,
) error { ) error {
userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.cfg.Matrix.ServerName), false) userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.cfg.Matrix.ServerName), false)
if err != nil { if err != nil {
return err return err
} }
asyncResponse, err := r.federation.GetAsyncEvents(ctx, *userID, request.Mailserver) asyncResponse, err := r.federation.GetAsyncEvents(ctx, *userID, request.RelayServer)
if err != nil { if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error()) logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err return err
@ -868,7 +868,7 @@ func (r *FederationInternalAPI) PerformMailserverSync(
r.processTransaction(&asyncResponse.Transaction) r.processTransaction(&asyncResponse.Transaction)
for asyncResponse.Remaining > 0 { for asyncResponse.Remaining > 0 {
asyncResponse, err := r.federation.GetAsyncEvents(ctx, *userID, request.Mailserver) asyncResponse, err := r.federation.GetAsyncEvents(ctx, *userID, request.RelayServer)
if err != nil { if err != nil {
logrus.Errorf("GetAsyncEvents: %s", err.Error()) logrus.Errorf("GetAsyncEvents: %s", err.Error())
return err return err
@ -880,7 +880,7 @@ func (r *FederationInternalAPI) PerformMailserverSync(
} }
func (r *FederationInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) { func (r *FederationInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) {
logrus.Warn("Processing transaction from mailserver") logrus.Warn("Processing transaction from relay server")
mu := internal.NewMutexByRoom() mu := internal.NewMutexByRoom()
// js, _ := base.NATS.Prepare(base.ProcessContext, &r.cfg.Matrix.JetStream) // js, _ := base.NATS.Prepare(base.ProcessContext, &r.cfg.Matrix.JetStream)
// producer := &producers.SyncAPIProducer{ // producer := &producers.SyncAPIProducer{
@ -948,6 +948,8 @@ func (r *FederationInternalAPI) QueryAsyncTransactions(
// TODO : Shouldn't be deleting unless the transaction was successfully returned... // TODO : Shouldn't be deleting unless the transaction was successfully returned...
// TODO : Should delete transaction json from table if no more associations // TODO : Should delete transaction json from table if no more associations
// Maybe track last received transaction, and send that as part of the request,
// then delete before getting the new events from the db.
if transaction != nil && receipt != nil { if transaction != nil && receipt != nil {
err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt}) err = r.db.CleanAsyncTransactions(ctx, request.UserID, []*shared.Receipt{receipt})
if err != nil { if err != nil {

View file

@ -24,8 +24,8 @@ const (
FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest"
FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU"
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
FederationAPIQueryMailservers = "/federationapi/queryMailservers" FederationAPIQueryRelayServers = "/federationapi/queryRelayServers"
FederationAPIPerformMailserverSync = "/federationapi/performMailserverSync" FederationAPIPerformRelayServerSync = "/federationapi/performRelayServerSync"
FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync" FederationAPIPerformStoreAsyncPath = "/federationapi/performStoreAsync"
FederationAPIQueryAsyncTransactionsPath = "/federationapi/queryAsyncTransactions" FederationAPIQueryAsyncTransactionsPath = "/federationapi/queryAsyncTransactions"
@ -516,25 +516,25 @@ func (h *httpFederationInternalAPI) QueryPublicKeys(
) )
} }
func (h *httpFederationInternalAPI) QueryMailservers( func (h *httpFederationInternalAPI) QueryRelayServers(
ctx context.Context, ctx context.Context,
request *api.QueryMailserversRequest, request *api.QueryRelayServersRequest,
response *api.QueryMailserversResponse, response *api.QueryRelayServersResponse,
) error { ) error {
return httputil.CallInternalRPCAPI( return httputil.CallInternalRPCAPI(
"QueryMailservers", h.federationAPIURL+FederationAPIQueryMailservers, "QueryRelayServers", h.federationAPIURL+FederationAPIQueryRelayServers,
h.httpClient, ctx, request, response, h.httpClient, ctx, request, response,
) )
} }
// PerformMailserverSync implements api.FederationInternalAPI // PerformRelayServerSync implements api.FederationInternalAPI
func (h *httpFederationInternalAPI) PerformMailserverSync( func (h *httpFederationInternalAPI) PerformRelayServerSync(
ctx context.Context, ctx context.Context,
request *api.PerformMailserverSyncRequest, request *api.PerformRelayServerSyncRequest,
response *api.PerformMailserverSyncResponse, response *api.PerformRelayServerSyncResponse,
) error { ) error {
return httputil.CallInternalRPCAPI( return httputil.CallInternalRPCAPI(
"PerformMailserverSync", h.federationAPIURL+FederationAPIPerformMailserverSync, "PerformRelayServerSync", h.federationAPIURL+FederationAPIPerformRelayServerSync,
h.httpClient, ctx, request, response, h.httpClient, ctx, request, response,
) )
} }

View file

@ -396,7 +396,7 @@ func (oq *destinationQueue) backgroundSend() {
// nextTransaction creates a new transaction from the pending event // nextTransaction creates a new transaction from the pending event
// queue and sends it. // queue and sends it.
// Returns an error if the transaction wasn't sent. And whether the success // Returns an error if the transaction wasn't sent. And whether the success
// was to an async mailserver or not. // was to an async relay server or not.
func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU, pdus []*queuedPDU,
edus []*queuedEDU, edus []*queuedEDU,
@ -409,16 +409,16 @@ func (oq *destinationQueue) nextTransaction(
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
defer cancel() defer cancel()
mailservers := oq.statistics.KnownMailservers() relayServers := oq.statistics.KnownRelayServers()
if oq.statistics.AssumedOffline() && len(mailservers) > 0 { if oq.statistics.AssumedOffline() && len(relayServers) > 0 {
logrus.Infof("Sending to mailservers: %v", mailservers) logrus.Infof("Sending to relay servers: %v", relayServers)
// TODO : how to pass through actual userID here?!?!?!?! // TODO : how to pass through actual userID here?!?!?!?!
userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false)
if userErr != nil { if userErr != nil {
return userErr, false return userErr, false
} }
for _, mailserver := range mailservers { for _, relayServer := range relayServers {
_, asyncErr := oq.client.SendAsyncTransaction(ctx, *userID, t, mailserver) _, asyncErr := oq.client.SendAsyncTransaction(ctx, *userID, t, relayServer)
if asyncErr != nil { if asyncErr != nil {
err = asyncErr err = asyncErr
} else { } else {

View file

@ -75,7 +75,7 @@ func createDatabase() storage.Database {
pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU),
associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
mailservers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName), relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName),
} }
} }
@ -90,7 +90,7 @@ type fakeDatabase struct {
pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU
associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
mailservers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName
} }
var nidMutex sync.Mutex var nidMutex sync.Mutex
@ -341,32 +341,32 @@ func (d *fakeDatabase) IsServerAssumedOffline(serverName gomatrixserverlib.Serve
return assumedOffline, nil return assumedOffline, nil
} }
func (d *fakeDatabase) GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) { func (d *fakeDatabase) GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock() d.dbMutex.Lock()
defer d.dbMutex.Unlock() defer d.dbMutex.Unlock()
knownMailservers := []gomatrixserverlib.ServerName{} knownRelayServers := []gomatrixserverlib.ServerName{}
if mailservers, ok := d.mailservers[serverName]; ok { if relayServers, ok := d.relayServers[serverName]; ok {
knownMailservers = mailservers knownRelayServers = relayServers
} }
return knownMailservers, nil return knownRelayServers, nil
} }
func (d *fakeDatabase) AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error { func (d *fakeDatabase) AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error {
d.dbMutex.Lock() d.dbMutex.Lock()
defer d.dbMutex.Unlock() defer d.dbMutex.Unlock()
if knownMailservers, ok := d.mailservers[serverName]; ok { if knownRelayServers, ok := d.relayServers[serverName]; ok {
for _, mailserver := range mailservers { for _, relayServer := range relayServers {
alreadyKnown := false alreadyKnown := false
for _, knownMailserver := range knownMailservers { for _, knownRelayServer := range knownRelayServers {
if mailserver == knownMailserver { if relayServer == knownRelayServer {
alreadyKnown = true alreadyKnown = true
} }
} }
if !alreadyKnown { if !alreadyKnown {
d.mailservers[serverName] = append(d.mailservers[serverName], mailserver) d.relayServers[serverName] = append(d.relayServers[serverName], relayServer)
} }
} }
} }
@ -1227,8 +1227,8 @@ func TestSendPDUOnAsyncSuccessRemovedFromDB(t *testing.T) {
<-pc.WaitForShutdown() <-pc.WaitForShutdown()
}() }()
mailservers := []gomatrixserverlib.ServerName{"mailserver"} relayServers := []gomatrixserverlib.ServerName{"relayserver"}
queues.statistics.ForServer(destination).AddMailservers(mailservers) queues.statistics.ForServer(destination).AddRelayServers(relayServers)
ev := mustCreatePDU(t) ev := mustCreatePDU(t)
err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
@ -1266,8 +1266,8 @@ func TestSendEDUOnAsyncSuccessRemovedFromDB(t *testing.T) {
<-pc.WaitForShutdown() <-pc.WaitForShutdown()
}() }()
mailservers := []gomatrixserverlib.ServerName{"mailserver"} relayServers := []gomatrixserverlib.ServerName{"relayserver"}
queues.statistics.ForServer(destination).AddMailservers(mailservers) queues.statistics.ForServer(destination).AddRelayServers(relayServers)
ev := mustCreateEDU(t) ev := mustCreateEDU(t)
err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})

View file

@ -15,7 +15,7 @@ type AsyncEventsResponse struct {
} }
// GetAsyncEvents implements /_matrix/federation/v1/async_events/{userID} // GetAsyncEvents implements /_matrix/federation/v1/async_events/{userID}
// This endpoint can be extracted into a separate mailserver service. // This endpoint can be extracted into a separate relay server service.
func GetAsyncEvents( func GetAsyncEvents(
httpReq *http.Request, httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest, fedReq *gomatrixserverlib.FederationRequest,

View file

@ -11,7 +11,7 @@ import (
) )
// ForwardAsync implements /_matrix/federation/v1/forward_async/{txnID}/{userID} // ForwardAsync implements /_matrix/federation/v1/forward_async/{txnID}/{userID}
// This endpoint can be extracted into a separate mailserver service. // This endpoint can be extracted into a separate relay server service.
func ForwardAsync( func ForwardAsync(
httpReq *http.Request, httpReq *http.Request,
fedReq *gomatrixserverlib.FederationRequest, fedReq *gomatrixserverlib.FederationRequest,
@ -62,10 +62,5 @@ func ForwardAsync(
} }
} }
// Naming:
// mailServer? assign mailserver for user?
// configure my mailserver
// Homeserver, idendity server, mailserver... why not?
return util.JSONResponse{Code: 200} return util.JSONResponse{Code: 200}
} }

View file

@ -31,7 +31,7 @@ type Statistics struct {
// How many times should we tolerate consecutive failures before we // How many times should we tolerate consecutive failures before we
// mark the destination as offline. At this point we should attempt // mark the destination as offline. At this point we should attempt
// to send messages to the user's async mailservers if we know them. // to send messages to the user's async relay servers if we know them.
FailuresUntilAssumedOffline uint32 FailuresUntilAssumedOffline uint32
} }
@ -65,9 +65,9 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
if !found { if !found {
s.mutex.Lock() s.mutex.Lock()
server = &ServerStatistics{ server = &ServerStatistics{
statistics: s, statistics: s,
serverName: serverName, serverName: serverName,
knownMailservers: []gomatrixserverlib.ServerName{}, knownRelayServers: []gomatrixserverlib.ServerName{},
} }
s.servers[serverName] = server s.servers[serverName] = server
s.mutex.Unlock() s.mutex.Unlock()
@ -78,11 +78,11 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
server.blacklisted.Store(blacklisted) server.blacklisted.Store(blacklisted)
} }
knownMailservers, err := s.DB.GetMailserversForServer(serverName) knownRelayServers, err := s.DB.GetRelayServersForServer(serverName)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to get mailserver list for %q", serverName) logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName)
} else { } else {
server.knownMailservers = knownMailservers server.knownRelayServers = knownRelayServers
} }
} }
return server return server
@ -93,17 +93,17 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
// many times we failed etc. It also manages the backoff time and black- // many times we failed etc. It also manages the backoff time and black-
// listing a remote host if it remains uncooperative. // listing a remote host if it remains uncooperative.
type ServerStatistics struct { type ServerStatistics struct {
statistics *Statistics // statistics *Statistics //
serverName gomatrixserverlib.ServerName // serverName gomatrixserverlib.ServerName //
blacklisted atomic.Bool // is the node blacklisted blacklisted atomic.Bool // is the node blacklisted
assumedOffline atomic.Bool // is the node assumed to be offline assumedOffline atomic.Bool // is the node assumed to be offline
backoffStarted atomic.Bool // is the backoff started backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called backoffCount atomic.Uint32 // number of times BackoffDuration has been called
successCounter atomic.Uint32 // how many times have we succeeded? successCounter atomic.Uint32 // how many times have we succeeded?
backoffNotifier func() // notifies destination queue when backoff completes backoffNotifier func() // notifies destination queue when backoff completes
notifierMutex sync.Mutex notifierMutex sync.Mutex
knownMailservers []gomatrixserverlib.ServerName knownRelayServers []gomatrixserverlib.ServerName
} }
const maxJitterMultiplier = 1.4 const maxJitterMultiplier = 1.4
@ -139,11 +139,11 @@ func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) {
// failure counters. If a host was blacklisted at this point then // failure counters. If a host was blacklisted at this point then
// we will unblacklist it. // we will unblacklist it.
// `async` specifies whether the success was to the actual destination // `async` specifies whether the success was to the actual destination
// or one of their mailservers. // or one of their relay servers.
func (s *ServerStatistics) Success(async bool) { func (s *ServerStatistics) Success(async bool) {
s.cancel() s.cancel()
s.backoffCount.Store(0) s.backoffCount.Store(0)
// NOTE : Sending to the final destination vs. a mailserver has // NOTE : Sending to the final destination vs. a relay server has
// slightly different semantics. // slightly different semantics.
if !async { if !async {
s.successCounter.Inc() s.successCounter.Inc()
@ -271,16 +271,16 @@ func (s *ServerStatistics) SuccessCount() uint32 {
return s.successCounter.Load() return s.successCounter.Load()
} }
// KnownMailservers returns the list of mailservers associated with this // KnownRelayServers returns the list of relay servers associated with this
// server. // server.
func (s *ServerStatistics) KnownMailservers() []gomatrixserverlib.ServerName { func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName {
return s.knownMailservers return s.knownRelayServers
} }
func (s *ServerStatistics) AddMailservers(mailservers []gomatrixserverlib.ServerName) { func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) {
seenSet := make(map[gomatrixserverlib.ServerName]bool) seenSet := make(map[gomatrixserverlib.ServerName]bool)
uniqueList := []gomatrixserverlib.ServerName{} uniqueList := []gomatrixserverlib.ServerName{}
for _, srv := range mailservers { for _, srv := range relayServers {
if seenSet[srv] { if seenSet[srv] {
continue continue
} }
@ -288,18 +288,18 @@ func (s *ServerStatistics) AddMailservers(mailservers []gomatrixserverlib.Server
uniqueList = append(uniqueList, srv) uniqueList = append(uniqueList, srv)
} }
err := s.statistics.DB.AddMailserversForServer(s.serverName, uniqueList) err := s.statistics.DB.AddRelayServersForServer(s.serverName, uniqueList)
if err == nil { if err == nil {
for _, newServer := range uniqueList { for _, newServer := range uniqueList {
alreadyKnown := false alreadyKnown := false
for _, srv := range s.knownMailservers { for _, srv := range s.knownRelayServers {
if srv == newServer { if srv == newServer {
alreadyKnown = true alreadyKnown = true
} }
} }
if !alreadyKnown { if !alreadyKnown {
s.knownMailservers = append(s.knownMailservers, newServer) s.knownRelayServers = append(s.knownRelayServers, newServer)
} }
} }
} }

View file

@ -69,10 +69,10 @@ type Database interface {
RemoveAllServersAssumedOffline() error RemoveAllServersAssumedOffline() error
IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error) IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error)
AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
RemoveMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
RemoveAllMailserversForServer(serverName gomatrixserverlib.ServerName) error RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error

View file

@ -1,147 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const mailserversSchema = `
CREATE TABLE IF NOT EXISTS federationsender_mailservers (
-- The destination server name
server_name TEXT NOT NULL,
-- The mailserver name for a given destination
mailserver_name TEXT NOT NULL,
UNIQUE (server_name, mailserver_name)
);
CREATE INDEX IF NOT EXISTS federationsender_mailservers_server_name_idx
ON federationsender_mailservers (server_name);
`
const insertMailserversSQL = "" +
"INSERT INTO federationsender_mailservers (server_name, mailserver_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectMailserversSQL = "" +
"SELECT mailserver_name FROM federationsender_mailservers WHERE server_name = $1"
const deleteMailserversSQL = "" +
"DELETE FROM federationsender_mailservers WHERE server_name = $1 AND mailserver_name IN ($2)"
const deleteAllMailserversSQL = "" +
"DELETE FROM federationsender_mailservers WHERE server_name = $1"
type mailserversStatements struct {
db *sql.DB
insertMailserversStmt *sql.Stmt
selectMailserversStmt *sql.Stmt
deleteMailserversStmt *sql.Stmt
deleteAllMailserversStmt *sql.Stmt
}
func NewPostgresMailserversTable(db *sql.DB) (s *mailserversStatements, err error) {
s = &mailserversStatements{
db: db,
}
_, err = db.Exec(mailserversSchema)
if err != nil {
return
}
if s.insertMailserversStmt, err = db.Prepare(insertMailserversSQL); err != nil {
return
}
if s.selectMailserversStmt, err = db.Prepare(selectMailserversSQL); err != nil {
return
}
if s.deleteMailserversStmt, err = db.Prepare(deleteMailserversSQL); err != nil {
return
}
if s.deleteAllMailserversStmt, err = db.Prepare(deleteAllMailserversSQL); err != nil {
return
}
return
}
func (s *mailserversStatements) InsertMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
mailservers []gomatrixserverlib.ServerName,
) error {
for _, mailserver := range mailservers {
stmt := sqlutil.TxStmt(txn, s.insertMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
return err
}
}
return nil
}
func (s *mailserversStatements) SelectMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectMailserversStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectMailservers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var mailserver string
if err = rows.Scan(&mailserver); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(mailserver))
}
return result, nil
}
func (s *mailserversStatements) DeleteMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
mailservers []gomatrixserverlib.ServerName,
) error {
for _, mailserver := range mailservers {
stmt := sqlutil.TxStmt(txn, s.deleteMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
return err
}
}
return nil
}
func (s *mailserversStatements) DeleteAllMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,147 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const relayServersSchema = `
CREATE TABLE IF NOT EXISTS federationsender_relay_servers (
-- The destination server name
server_name TEXT NOT NULL,
-- The relay server name for a given destination
relay_server_name TEXT NOT NULL,
UNIQUE (server_name, relay_server_name)
);
CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx
ON federationsender_relay_servers (server_name);
`
const insertRelayServersSQL = "" +
"INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectRelayServersSQL = "" +
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
const deleteRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)"
const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
type relayServersStatements struct {
db *sql.DB
insertRelayServersStmt *sql.Stmt
selectRelayServersStmt *sql.Stmt
deleteRelayServersStmt *sql.Stmt
deleteAllRelayServersStmt *sql.Stmt
}
func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) {
s = &relayServersStatements{
db: db,
}
_, err = db.Exec(relayServersSchema)
if err != nil {
return
}
if s.insertRelayServersStmt, err = db.Prepare(insertRelayServersSQL); err != nil {
return
}
if s.selectRelayServersStmt, err = db.Prepare(selectRelayServersSQL); err != nil {
return
}
if s.deleteRelayServersStmt, err = db.Prepare(deleteRelayServersSQL); err != nil {
return
}
if s.deleteAllRelayServersStmt, err = db.Prepare(deleteAllRelayServersSQL); err != nil {
return
}
return
}
func (s *relayServersStatements) InsertRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
}
return nil
}
func (s *relayServersStatements) SelectRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var relayServer string
if err = rows.Scan(&relayServer); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(relayServer))
}
return result, nil
}
func (s *relayServersStatements) DeleteRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
}
return nil
}
func (s *relayServersStatements) DeleteAllRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View file

@ -74,7 +74,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
mailservers, err := NewPostgresMailserversTable(d.db) relayServers, err := NewPostgresRelayServersTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -123,7 +123,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationTransactionJSON: transactionJSON, FederationTransactionJSON: transactionJSON,
FederationBlacklist: blacklist, FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline, FederationAssumedOffline: assumedOffline,
FederationMailservers: mailservers, FederationRelayServers: relayServers,
FederationInboundPeeks: inboundPeeks, FederationInboundPeeks: inboundPeeks,
FederationOutboundPeeks: outboundPeeks, FederationOutboundPeeks: outboundPeeks,
NotaryServerKeysJSON: notaryJSON, NotaryServerKeysJSON: notaryJSON,

View file

@ -39,7 +39,7 @@ type Database struct {
FederationJoinedHosts tables.FederationJoinedHosts FederationJoinedHosts tables.FederationJoinedHosts
FederationBlacklist tables.FederationBlacklist FederationBlacklist tables.FederationBlacklist
FederationAssumedOffline tables.FederationAssumedOffline FederationAssumedOffline tables.FederationAssumedOffline
FederationMailservers tables.FederationMailservers FederationRelayServers tables.FederationRelayServers
FederationOutboundPeeks tables.FederationOutboundPeeks FederationOutboundPeeks tables.FederationOutboundPeeks
FederationInboundPeeks tables.FederationInboundPeeks FederationInboundPeeks tables.FederationInboundPeeks
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
@ -203,25 +203,25 @@ func (d *Database) IsServerAssumedOffline(serverName gomatrixserverlib.ServerNam
return d.FederationAssumedOffline.SelectAssumedOffline(context.TODO(), nil, serverName) return d.FederationAssumedOffline.SelectAssumedOffline(context.TODO(), nil, serverName)
} }
func (d *Database) AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error { func (d *Database) AddRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationMailservers.InsertMailservers(context.TODO(), txn, serverName, mailservers) return d.FederationRelayServers.InsertRelayServers(context.TODO(), txn, serverName, relayServers)
}) })
} }
func (d *Database) GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) { func (d *Database) GetRelayServersForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) {
return d.FederationMailservers.SelectMailservers(context.TODO(), nil, serverName) return d.FederationRelayServers.SelectRelayServers(context.TODO(), nil, serverName)
} }
func (d *Database) RemoveMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error { func (d *Database) RemoveRelayServersForServer(serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationMailservers.DeleteMailservers(context.TODO(), txn, serverName, mailservers) return d.FederationRelayServers.DeleteRelayServers(context.TODO(), txn, serverName, relayServers)
}) })
} }
func (d *Database) RemoveAllMailserversForServer(serverName gomatrixserverlib.ServerName) error { func (d *Database) RemoveAllRelayServersForServer(serverName gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationMailservers.DeleteAllMailservers(context.TODO(), txn, serverName) return d.FederationRelayServers.DeleteAllRelayServers(context.TODO(), txn, serverName)
}) })
} }

View file

@ -1,147 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const mailserversSchema = `
CREATE TABLE IF NOT EXISTS federationsender_mailservers (
-- The destination server name
server_name TEXT NOT NULL,
-- The mailserver name for a given destination
mailserver_name TEXT NOT NULL,
UNIQUE (server_name, mailserver_name)
);
CREATE INDEX IF NOT EXISTS federationsender_mailservers_server_name_idx
ON federationsender_mailservers (server_name);
`
const insertMailserversSQL = "" +
"INSERT INTO federationsender_mailservers (server_name, mailserver_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectMailserversSQL = "" +
"SELECT mailserver_name FROM federationsender_mailservers WHERE server_name = $1"
const deleteMailserversSQL = "" +
"DELETE FROM federationsender_mailservers WHERE server_name = $1 AND mailserver_name IN ($2)"
const deleteAllMailserversSQL = "" +
"DELETE FROM federationsender_mailservers WHERE server_name = $1"
type mailserversStatements struct {
db *sql.DB
insertMailserversStmt *sql.Stmt
selectMailserversStmt *sql.Stmt
deleteMailserversStmt *sql.Stmt
deleteAllMailserversStmt *sql.Stmt
}
func NewSQLiteMailserversTable(db *sql.DB) (s *mailserversStatements, err error) {
s = &mailserversStatements{
db: db,
}
_, err = db.Exec(mailserversSchema)
if err != nil {
return
}
if s.insertMailserversStmt, err = db.Prepare(insertMailserversSQL); err != nil {
return
}
if s.selectMailserversStmt, err = db.Prepare(selectMailserversSQL); err != nil {
return
}
if s.deleteMailserversStmt, err = db.Prepare(deleteMailserversSQL); err != nil {
return
}
if s.deleteAllMailserversStmt, err = db.Prepare(deleteAllMailserversSQL); err != nil {
return
}
return
}
func (s *mailserversStatements) InsertMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
mailservers []gomatrixserverlib.ServerName,
) error {
for _, mailserver := range mailservers {
stmt := sqlutil.TxStmt(txn, s.insertMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
return err
}
}
return nil
}
func (s *mailserversStatements) SelectMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectMailserversStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectMailservers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var mailserver string
if err = rows.Scan(&mailserver); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(mailserver))
}
return result, nil
}
func (s *mailserversStatements) DeleteMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
mailservers []gomatrixserverlib.ServerName,
) error {
for _, mailserver := range mailservers {
stmt := sqlutil.TxStmt(txn, s.deleteMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName, mailserver); err != nil {
return err
}
}
return nil
}
func (s *mailserversStatements) DeleteAllMailservers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllMailserversStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,147 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const relayServersSchema = `
CREATE TABLE IF NOT EXISTS federationsender_relay_servers (
-- The destination server name
server_name TEXT NOT NULL,
-- The relay server name for a given destination
relay_server_name TEXT NOT NULL,
UNIQUE (server_name, relay_server_name)
);
CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx
ON federationsender_relay_servers (server_name);
`
const insertRelayServersSQL = "" +
"INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectRelayServersSQL = "" +
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
const deleteRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)"
const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
type relayServersStatements struct {
db *sql.DB
insertRelayServersStmt *sql.Stmt
selectRelayServersStmt *sql.Stmt
deleteRelayServersStmt *sql.Stmt
deleteAllRelayServersStmt *sql.Stmt
}
func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) {
s = &relayServersStatements{
db: db,
}
_, err = db.Exec(relayServersSchema)
if err != nil {
return
}
if s.insertRelayServersStmt, err = db.Prepare(insertRelayServersSQL); err != nil {
return
}
if s.selectRelayServersStmt, err = db.Prepare(selectRelayServersSQL); err != nil {
return
}
if s.deleteRelayServersStmt, err = db.Prepare(deleteRelayServersSQL); err != nil {
return
}
if s.deleteAllRelayServersStmt, err = db.Prepare(deleteAllRelayServersSQL); err != nil {
return
}
return
}
func (s *relayServersStatements) InsertRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
}
return nil
}
func (s *relayServersStatements) SelectRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var relayServer string
if err = rows.Scan(&relayServer); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(relayServer))
}
return result, nil
}
func (s *relayServersStatements) DeleteRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
}
return nil
}
func (s *relayServersStatements) DeleteAllRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View file

@ -67,7 +67,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
mailservers, err := NewSQLiteMailserversTable(d.db) relayServers, err := NewSQLiteRelayServersTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -116,7 +116,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationTransactionJSON: transactionJSON, FederationTransactionJSON: transactionJSON,
FederationBlacklist: blacklist, FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline, FederationAssumedOffline: assumedOffline,
FederationMailservers: mailservers, FederationRelayServers: relayServers,
FederationOutboundPeeks: outboundPeeks, FederationOutboundPeeks: outboundPeeks,
FederationInboundPeeks: inboundPeeks, FederationInboundPeeks: inboundPeeks,
NotaryServerKeysJSON: notaryKeys, NotaryServerKeysJSON: notaryKeys,

View file

@ -88,11 +88,11 @@ type FederationAssumedOffline interface {
DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error
} }
type FederationMailservers interface { type FederationRelayServers interface {
InsertMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
SelectMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
DeleteMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
DeleteAllMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
} }
type FederationOutboundPeeks interface { type FederationOutboundPeeks interface {

View file

@ -1,167 +0,0 @@
package tables_test
import (
"context"
"database/sql"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
server1 = "server1"
server2 = "server2"
server3 = "server3"
)
type MailserversDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationMailservers
}
func mustCreateMailserversTable(t *testing.T, dbType test.DBType) (database MailserversDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationMailservers
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresMailserversTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteMailserversTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = MailserversDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func Equal(a, b []gomatrixserverlib.ServerName) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func TestShouldInsertMailservers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateMailserversTable(t, dbType)
defer close()
expectedMailservers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
mailservers, err := db.Table.SelectMailservers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers)
}
})
}
func TestShouldDeleteCorrectMailservers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateMailserversTable(t, dbType)
defer close()
expectedMailservers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertMailservers(ctx, nil, server2, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteMailservers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2})
if err != nil {
t.Fatalf("Failed deleting mailservers for %s: %s", server1, err.Error())
}
expectedMailservers1 := []gomatrixserverlib.ServerName{server3}
mailservers, err := db.Table.SelectMailservers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers1, mailservers)
}
mailservers, err = db.Table.SelectMailservers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers)
}
})
}
func TestShouldDeleteAllMailservers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateMailserversTable(t, dbType)
defer close()
expectedMailservers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertMailservers(ctx, nil, server1, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertMailservers(ctx, nil, server2, expectedMailservers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteAllMailservers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed deleting mailservers for %s: %s", server1, err.Error())
}
expectedMailservers1 := []gomatrixserverlib.ServerName{}
mailservers, err := db.Table.SelectMailservers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers1, mailservers)
}
mailservers, err = db.Table.SelectMailservers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving mailservers for %s: %s", mailservers, err.Error())
}
if !Equal(mailservers, expectedMailservers) {
t.Fatalf("Expected: %v \nActual: %v", expectedMailservers, mailservers)
}
})
}

View file

@ -0,0 +1,167 @@
package tables_test
import (
"context"
"database/sql"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
server1 = "server1"
server2 = "server2"
server3 = "server3"
)
type RelayServersDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationRelayServers
}
func mustCreateRelayServersTable(t *testing.T, dbType test.DBType) (database RelayServersDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationRelayServers
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresRelayServersTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteRelayServersTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = RelayServersDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func Equal(a, b []gomatrixserverlib.ServerName) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func TestShouldInsertRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldDeleteCorrectRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2})
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
}
expectedRelayServers1 := []gomatrixserverlib.ServerName{server3}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers)
}
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldDeleteAllRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteAllRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
}
expectedRelayServers1 := []gomatrixserverlib.ServerName{}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers)
}
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}

View file

@ -21,7 +21,7 @@ type FederationAPI struct {
// How many consecutive failures that we should tolerate when sending federation // How many consecutive failures that we should tolerate when sending federation
// requests to a specific server until we should assume they are offline. If we // requests to a specific server until we should assume they are offline. If we
// assume they are offline then we will attempt to send messages to their async // assume they are offline then we will attempt to send messages to their async
// mailserver if we know of one that is appropriate. // relay server if we know of one that is appropriate.
FederationRetriesUntilAssumedOffline uint32 `yaml:"retries_until_assumed_offline"` FederationRetriesUntilAssumedOffline uint32 `yaml:"retries_until_assumed_offline"`
// FederationDisableTLSValidation disables the validation of X.509 TLS certs // FederationDisableTLSValidation disables the validation of X.509 TLS certs