diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index df63262bf..c2dff12ca 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -70,6 +70,7 @@ func createDatabase() storage.Database { pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}), pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), @@ -83,6 +84,7 @@ type fakeDatabase struct { pendingPDUServers map[gomatrixserverlib.ServerName]struct{} pendingEDUServers map[gomatrixserverlib.ServerName]struct{} blacklistedServers map[gomatrixserverlib.ServerName]struct{} + assumedOffline map[gomatrixserverlib.ServerName]struct{} pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} @@ -301,6 +303,42 @@ func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerNa return isBlacklisted, nil } +func (d *fakeDatabase) SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline[serverName] = struct{}{} + return nil +} + +func (d *fakeDatabase) RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.assumedOffline, serverName) + return nil +} + +func (d *fakeDatabase) RemoveAllServersAssumedOffine() error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *fakeDatabase) IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + assumedOffline := false + if _, ok := d.assumedOffline[serverName]; ok { + assumedOffline = true + } + + return assumedOffline, nil +} + func (d *fakeDatabase) GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) { return []gomatrixserverlib.ServerName{}, nil } @@ -345,7 +383,7 @@ func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} } -func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { +func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) fc := &stubFederationClient{ @@ -354,7 +392,6 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T } rs := &stubFederationRoomServerAPI{} - failuresUntilAssumedOffline := failuresUntilBlacklist + 1 stats := statistics.NewStatistics( db, failuresUntilBlacklist, failuresUntilAssumedOffline) signingInfo := &SigningInfo{ @@ -371,7 +408,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -400,7 +437,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -429,7 +466,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -459,7 +496,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -489,7 +526,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -540,7 +577,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -591,7 +628,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -623,7 +660,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -655,7 +692,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -689,7 +726,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -723,7 +760,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -774,7 +811,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -828,7 +865,7 @@ func TestSendPDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -872,7 +909,7 @@ func TestSendEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -916,7 +953,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -967,7 +1004,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -1005,7 +1042,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { destination := gomatrixserverlib.ServerName("remotehost") destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. defer close() defer func() { @@ -1065,3 +1102,36 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) }) } + +func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index ba484c8fa..f8bfa2457 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -168,6 +168,11 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { if backoffCount >= s.statistics.FailuresUntilAssumedOffline { s.assumedOffline.CompareAndSwap(false, true) + if s.statistics.DB != nil { + if err := s.statistics.DB.SetServerAssumedOffline(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName) + } + } } if backoffCount >= s.statistics.FailuresUntilBlacklist { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index c7dee9fd4..6108862f7 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -62,6 +62,12 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) + // these don't have contexts passed in as we want things to happen regardless of the request context + SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error + RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error + RemoveAllServersAssumedOffline() error + IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error) + AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error GetMailserversForServer(serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) RemoveMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error diff --git a/federationapi/storage/postgres/assumed_offline_table.go b/federationapi/storage/postgres/assumed_offline_table.go new file mode 100644 index 000000000..f102c092e --- /dev/null +++ b/federationapi/storage/postgres/assumed_offline_table.go @@ -0,0 +1,115 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT NOT NULL, + UNIQUE (server_name) +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "TRUNCATE federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + if s.insertAssumedOfflineStmt, err = db.Prepare(insertAssumedOfflineSQL); err != nil { + return + } + if s.selectAssumedOfflineStmt, err = db.Prepare(selectAssumedOfflineSQL); err != nil { + return + } + if s.deleteAssumedOfflineStmt, err = db.Prepare(deleteAssumedOfflineSQL); err != nil { + return + } + if s.deleteAllAssumedOfflineStmt, err = db.Prepare(deleteAllAssumedOfflineSQL); err != nil { + return + } + return +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index d32528f56..60936d352 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -70,6 +70,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewPostgresAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } mailservers, err := NewPostgresMailserversTable(d.db) if err != nil { return nil, err @@ -118,6 +122,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueTransactions: queueTransactions, FederationTransactionJSON: transactionJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, FederationMailservers: mailservers, FederationInboundPeeks: inboundPeeks, FederationOutboundPeeks: outboundPeeks, diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index a85dcff0b..1ec9f5d87 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -38,6 +38,7 @@ type Database struct { FederationQueueJSON tables.FederationQueueJSON FederationJoinedHosts tables.FederationJoinedHosts FederationBlacklist tables.FederationBlacklist + FederationAssumedOffline tables.FederationAssumedOffline FederationMailservers tables.FederationMailservers FederationOutboundPeeks tables.FederationOutboundPeeks FederationInboundPeeks tables.FederationInboundPeeks @@ -178,6 +179,28 @@ func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } +func (d *Database) SetServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.InsertAssumedOffline(context.TODO(), txn, serverName) + }) +} + +func (d *Database) RemoveServerAssumedOffline(serverName gomatrixserverlib.ServerName) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAssumedOffline(context.TODO(), txn, serverName) + }) +} + +func (d *Database) RemoveAllServersAssumedOffline() error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAllAssumedOffline(context.TODO(), txn) + }) +} + +func (d *Database) IsServerAssumedOffline(serverName gomatrixserverlib.ServerName) (bool, error) { + return d.FederationAssumedOffline.SelectAssumedOffline(context.TODO(), nil, serverName) +} + func (d *Database) AddMailserversForServer(serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationMailservers.InsertMailservers(context.TODO(), txn, serverName, mailservers) diff --git a/federationapi/storage/sqlite3/assumed_offline_table.go b/federationapi/storage/sqlite3/assumed_offline_table.go new file mode 100644 index 000000000..b0967cf6f --- /dev/null +++ b/federationapi/storage/sqlite3/assumed_offline_table.go @@ -0,0 +1,115 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT NOT NULL, + UNIQUE (server_name) +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + if s.insertAssumedOfflineStmt, err = db.Prepare(insertAssumedOfflineSQL); err != nil { + return + } + if s.selectAssumedOfflineStmt, err = db.Prepare(selectAssumedOfflineSQL); err != nil { + return + } + if s.deleteAssumedOfflineStmt, err = db.Prepare(deleteAssumedOfflineSQL); err != nil { + return + } + if s.deleteAllAssumedOfflineStmt, err = db.Prepare(deleteAllAssumedOfflineSQL); err != nil { + return + } + return +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index e1120c3a5..f5ac76c6a 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -63,6 +63,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } mailservers, err := NewSQLiteMailserversTable(d.db) if err != nil { return nil, err @@ -111,6 +115,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueTransactions: queueTransactions, FederationTransactionJSON: transactionJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, FederationMailservers: mailservers, FederationOutboundPeeks: outboundPeeks, FederationInboundPeeks: inboundPeeks, diff --git a/federationapi/storage/tables/assumed_offline_table_test.go b/federationapi/storage/tables/assumed_offline_table_test.go new file mode 100644 index 000000000..9c855ec79 --- /dev/null +++ b/federationapi/storage/tables/assumed_offline_table_test.go @@ -0,0 +1,152 @@ +package tables_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +type AssumedOfflineDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.FederationAssumedOffline +} + +func mustCreateAssumedOfflineTable(t *testing.T, dbType test.DBType) (database AssumedOfflineDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.FederationAssumedOffline + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresAssumedOfflineTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteAssumedOfflineTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = AssumedOfflineDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShouldInsertAssumedOfflineServer(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateAssumedOfflineTable(t, dbType) + defer close() + + err := db.Table.InsertAssumedOffline(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed inserting server: %s", err.Error()) + } + + isOffline, err := db.Table.SelectAssumedOffline(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving server: %s", err.Error()) + } + + assert.Equal(t, isOffline, true) + }) +} + +func TestShouldDeleteCorrectAssumedOfflineServer(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateAssumedOfflineTable(t, dbType) + defer close() + + err := db.Table.InsertAssumedOffline(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed inserting server: %s", err.Error()) + } + err = db.Table.InsertAssumedOffline(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed inserting server: %s", err.Error()) + } + + isOffline, err := db.Table.SelectAssumedOffline(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving server status: %s", err.Error()) + } + assert.Equal(t, isOffline, true) + + err = db.Table.DeleteAssumedOffline(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed deleting server: %s", err.Error()) + } + + isOffline, err = db.Table.SelectAssumedOffline(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving server status: %s", err.Error()) + } + assert.Equal(t, isOffline, false) + + isOffline2, err := db.Table.SelectAssumedOffline(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving server status: %s", err.Error()) + } + assert.Equal(t, isOffline2, true) + }) +} + +func TestShouldDeleteAllAssumedOfflineServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateAssumedOfflineTable(t, dbType) + defer close() + + err := db.Table.InsertAssumedOffline(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed inserting server: %s", err.Error()) + } + err = db.Table.InsertAssumedOffline(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed inserting server: %s", err.Error()) + } + + isOffline, err := db.Table.SelectAssumedOffline(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving server status: %s", err.Error()) + } + assert.Equal(t, isOffline, true) + isOffline2, err := db.Table.SelectAssumedOffline(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving server status: %s", err.Error()) + } + + assert.Equal(t, isOffline2, true) + + err = db.Table.DeleteAllAssumedOffline(ctx, nil) + if err != nil { + t.Fatalf("Failed deleting server: %s", err.Error()) + } + + isOffline, err = db.Table.SelectAssumedOffline(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving server status: %s", err.Error()) + } + assert.Equal(t, isOffline, false) + isOffline2, err = db.Table.SelectAssumedOffline(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving server status: %s", err.Error()) + } + assert.Equal(t, isOffline2, false) + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index e77163221..08344540c 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -81,6 +81,13 @@ type FederationBlacklist interface { DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } +type FederationAssumedOffline interface { + InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) + DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error +} + type FederationMailservers interface { InsertMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, mailservers []gomatrixserverlib.ServerName) error SelectMailservers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)