Federation backoff fixes and tests (#2792)

This fixes some edge cases where federation queue backoffs and
blacklisting weren't behaving as expected.
It also adds new tests for the federation queues to ensure their
behaviour continues to work correctly.
This commit is contained in:
devonh 2022-10-13 14:38:13 +00:00 committed by GitHub
parent 23a3e04579
commit dcedd1b6bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 441 additions and 2 deletions

View file

@ -75,6 +75,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return
}
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
@ -108,6 +109,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
case oq.notify <- struct{}{}:
default:
}
} else {
oq.overflowed.Store(true)
}
}
@ -153,6 +156,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
case oq.notify <- struct{}{}:
default:
}
} else {
oq.overflowed.Store(true)
}
}
@ -335,6 +340,11 @@ func (oq *destinationQueue) backgroundSend() {
// We failed to send the transaction. Mark it as a failure.
oq.statistics.Failure()
// Queue up another attempt since the transaction failed.
select {
case oq.notify <- struct{}{}:
default:
}
} else if transaction {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.

View file

@ -332,6 +332,7 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled {
return
}
oqs.statistics.ForServer(srv).RemoveBlacklist()
if queue := oqs.getQueue(srv); queue != nil {
queue.wakeQueueIfNeeded()
}

View file

@ -0,0 +1,422 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package queue
import (
"context"
"encoding/json"
"fmt"
"sync"
"testing"
"time"
"go.uber.org/atomic"
"gotest.tools/v3/poll"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
var dbMutex sync.Mutex
type fakeDatabase struct {
storage.Database
pendingPDUServers map[gomatrixserverlib.ServerName]struct{}
pendingEDUServers map[gomatrixserverlib.ServerName]struct{}
blacklistedServers map[gomatrixserverlib.ServerName]struct{}
pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent
pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU
associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
}
func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) {
dbMutex.Lock()
defer dbMutex.Unlock()
var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal([]byte(js), &event); err == nil {
receipt := &shared.Receipt{}
d.pendingPDUs[receipt] = &event
return receipt, nil
}
var edu gomatrixserverlib.EDU
if err := json.Unmarshal([]byte(js), &edu); err == nil {
receipt := &shared.Receipt{}
d.pendingEDUs[receipt] = &edu
return receipt, nil
}
return nil, errors.New("Failed to determine type of json to store")
}
func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) {
dbMutex.Lock()
defer dbMutex.Unlock()
pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent)
if receipts, ok := d.associatedPDUs[serverName]; ok {
for receipt := range receipts {
if event, ok := d.pendingPDUs[receipt]; ok {
pdus[receipt] = event
}
}
}
return pdus, nil
}
func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) {
dbMutex.Lock()
defer dbMutex.Unlock()
edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU)
if receipts, ok := d.associatedEDUs[serverName]; ok {
for receipt := range receipts {
if event, ok := d.pendingEDUs[receipt]; ok {
edus[receipt] = event
}
}
}
return edus, nil
}
func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error {
dbMutex.Lock()
defer dbMutex.Unlock()
if _, ok := d.pendingPDUs[receipt]; ok {
if _, ok := d.associatedPDUs[serverName]; !ok {
d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{})
}
d.associatedPDUs[serverName][receipt] = struct{}{}
return nil
} else {
return errors.New("PDU doesn't exist")
}
}
func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error {
dbMutex.Lock()
defer dbMutex.Unlock()
if _, ok := d.pendingEDUs[receipt]; ok {
if _, ok := d.associatedEDUs[serverName]; !ok {
d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{})
}
d.associatedEDUs[serverName][receipt] = struct{}{}
return nil
} else {
return errors.New("EDU doesn't exist")
}
}
func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
dbMutex.Lock()
defer dbMutex.Unlock()
if pdus, ok := d.associatedPDUs[serverName]; ok {
for _, receipt := range receipts {
delete(pdus, receipt)
}
}
return nil
}
func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
dbMutex.Lock()
defer dbMutex.Unlock()
if edus, ok := d.associatedEDUs[serverName]; ok {
for _, receipt := range receipts {
delete(edus, receipt)
}
}
return nil
}
func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
dbMutex.Lock()
defer dbMutex.Unlock()
var count int64
if pdus, ok := d.associatedPDUs[serverName]; ok {
count = int64(len(pdus))
}
return count, nil
}
func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
dbMutex.Lock()
defer dbMutex.Unlock()
var count int64
if edus, ok := d.associatedEDUs[serverName]; ok {
count = int64(len(edus))
}
return count, nil
}
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
dbMutex.Lock()
defer dbMutex.Unlock()
servers := []gomatrixserverlib.ServerName{}
for server := range d.pendingPDUServers {
servers = append(servers, server)
}
return servers, nil
}
func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
dbMutex.Lock()
defer dbMutex.Unlock()
servers := []gomatrixserverlib.ServerName{}
for server := range d.pendingEDUServers {
servers = append(servers, server)
}
return servers, nil
}
func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
dbMutex.Lock()
defer dbMutex.Unlock()
d.blacklistedServers[serverName] = struct{}{}
return nil
}
func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
dbMutex.Lock()
defer dbMutex.Unlock()
delete(d.blacklistedServers, serverName)
return nil
}
func (d *fakeDatabase) RemoveAllServersFromBlacklist() error {
dbMutex.Lock()
defer dbMutex.Unlock()
d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{})
return nil
}
func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
dbMutex.Lock()
defer dbMutex.Unlock()
isBlacklisted := false
if _, ok := d.blacklistedServers[serverName]; ok {
isBlacklisted = true
}
return isBlacklisted, nil
}
type stubFederationRoomServerAPI struct {
rsapi.FederationRoomserverAPI
}
func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error {
res.Banned = false
return nil
}
type stubFederationClient struct {
api.FederationClient
shouldTxSucceed bool
txCount atomic.Uint32
}
func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
var result error
if !f.shouldTxSucceed {
result = fmt.Errorf("transaction failed")
}
f.txCount.Add(1)
return gomatrixserverlib.RespSend{}, result
}
func createDatabase() storage.Database {
return &fakeDatabase{
pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}),
pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent),
pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU),
associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
}
}
func mustCreateEvent(t *testing.T) *gomatrixserverlib.HeaderedEvent {
t.Helper()
content := `{"type":"m.room.message"}`
ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10)
if err != nil {
t.Fatalf("failed to create event: %v", err)
}
return ev.Headered(gomatrixserverlib.RoomVersionV10)
}
func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool) (storage.Database, *stubFederationClient, *OutgoingQueues) {
db := createDatabase()
fc := &stubFederationClient{
shouldTxSucceed: shouldTxSucceed,
txCount: *atomic.NewUint32(0),
}
rs := &stubFederationRoomServerAPI{}
stats := &statistics.Statistics{
DB: db,
FailuresUntilBlacklist: failuresUntilBlacklist,
}
signingInfo := &SigningInfo{
KeyID: "ed25519:auto",
PrivateKey: test.PrivateKeyA,
ServerName: "localhost",
}
queues := NewOutgoingQueues(db, process.NewProcessContext(), false, "localhost", fc, rs, stats, signingInfo)
return db, fc, queues
}
func TestSendTransactionOnSuccessRemovedFromDB(t *testing.T) {
ctx := context.Background()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues := testSetup(failuresUntilBlacklist, true)
ev := mustCreateEvent(t)
err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() >= 1 {
data, err := db.GetPendingPDUs(ctx, destination, 100)
assert.NoError(t, err)
if len(data) == 0 {
return poll.Success()
}
return poll.Continue("waiting for event to be removed from database")
}
return poll.Continue("waiting for more send attempts before checking database")
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}
func TestSendTransactionOnFailStoredInDB(t *testing.T) {
ctx := context.Background()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues := testSetup(failuresUntilBlacklist, false)
ev := mustCreateEvent(t)
err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
// Wait for 2 backoff attempts to ensure there was adequate time to attempt sending
if fc.txCount.Load() >= 2 {
data, err := db.GetPendingPDUs(ctx, destination, 100)
assert.NoError(t, err)
if len(data) == 1 {
return poll.Success()
}
return poll.Continue("waiting for event to be added to database")
}
return poll.Continue("waiting for more send attempts before checking database")
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}
func TestSendTransactionMultipleFailuresBlacklisted(t *testing.T) {
ctx := context.Background()
failuresUntilBlacklist := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues := testSetup(failuresUntilBlacklist, false)
ev := mustCreateEvent(t)
err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() >= failuresUntilBlacklist {
data, err := db.GetPendingPDUs(ctx, destination, 100)
assert.NoError(t, err)
if len(data) == 1 {
if val, _ := db.IsServerBlacklisted(destination); val {
return poll.Success()
}
return poll.Continue("waiting for server to be blacklisted")
}
return poll.Continue("waiting for event to be added to database")
}
return poll.Continue("waiting for more send attempts before checking database")
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}
func TestRetryServerSendsSuccessfully(t *testing.T) {
ctx := context.Background()
failuresUntilBlacklist := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues := testSetup(failuresUntilBlacklist, false)
ev := mustCreateEvent(t)
err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
checkBlacklisted := func(log poll.LogT) poll.Result {
if fc.txCount.Load() >= failuresUntilBlacklist {
data, err := db.GetPendingPDUs(ctx, destination, 100)
assert.NoError(t, err)
if len(data) == 1 {
if val, _ := db.IsServerBlacklisted(destination); val {
return poll.Success()
}
return poll.Continue("waiting for server to be blacklisted")
}
return poll.Continue("waiting for event to be added to database")
}
return poll.Continue("waiting for more send attempts before checking database")
}
poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
fc.shouldTxSucceed = true
db.RemoveServerFromBlacklist(destination)
queues.RetryServer(destination)
checkRetry := func(log poll.LogT) poll.Result {
data, err := db.GetPendingPDUs(ctx, destination, 100)
assert.NoError(t, err)
if len(data) == 0 {
return poll.Success()
}
return poll.Continue("waiting for event to be removed from database")
}
poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}

View file

@ -95,8 +95,8 @@ func (s *ServerStatistics) cancel() {
// we will unblacklist it.
func (s *ServerStatistics) Success() {
s.cancel()
s.successCounter.Inc()
s.backoffCount.Store(0)
s.successCounter.Inc()
if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
@ -174,6 +174,12 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load()
}
// RemoveBlacklist removes the blacklisted status from the server.
func (s *ServerStatistics) RemoveBlacklist() {
s.cancel()
s.backoffCount.Store(0)
}
// SuccessCount returns the number of successful requests. This is
// usually useful in constructing transaction IDs.
func (s *ServerStatistics) SuccessCount() uint32 {

2
go.mod
View file

@ -50,6 +50,7 @@ require (
golang.org/x/term v0.0.0-20220919170432-7a66f970e087
gopkg.in/h2non/bimg.v1 v1.1.9
gopkg.in/yaml.v2 v2.4.0
gotest.tools/v3 v3.0.3
nhooyr.io/websocket v1.8.7
)
@ -128,7 +129,6 @@ require (
gopkg.in/macaroon.v2 v2.1.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.0.3 // indirect
)
go 1.18