From b26caac3f312ee316b32e88021b331924477412d Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 21 Oct 2022 10:30:01 +0200 Subject: [PATCH] Associate PDU with destinations in one transaction --- federationapi/queue/destinationqueue.go | 12 ------ federationapi/queue/queue.go | 15 +++++++ federationapi/queue/queue_test.go | 41 +++++++++++--------- federationapi/storage/interface.go | 5 ++- federationapi/storage/shared/storage_pdus.go | 24 ++++++------ 5 files changed, 53 insertions(+), 44 deletions(-) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 768ed1f2b..6b2d640aa 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -76,18 +76,6 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re return } - // Create a database entry that associates the given PDU NID with - // this destination queue. We'll then be able to retrieve the PDU - // later. - if err := oq.db.AssociatePDUWithDestination( - oq.process.Context(), - "", // TODO: remove this, as we don't need to persist the transaction ID - oq.destination, // the destination server name - receipt, // NIDs from federationapi_queue_json table - ); err != nil { - logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination) - return - } // Check if the destination is blacklisted. If it isn't then wake // up the queue. if !oq.statistics.Blacklisted() { diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 68f789e37..7219be44f 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -24,6 +24,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -249,9 +250,23 @@ func (oqs *OutgoingQueues) SendEvent( for destination := range destmap { if queue := oqs.getQueue(destination); queue != nil { queue.sendEvent(ev, nid) + } else { + delete(destmap, destination) } } + // Create a database entry that associates the given PDU NID with + // this destinations queue. We'll then be able to retrieve the PDU + // later. + if err := oqs.db.AssociatePDUWithDestinations( + oqs.process.Context(), + destmap, + nid, // NIDs from federationapi_queue_json table + ); err != nil { + logrus.WithError(err).Errorf("failed to associate PDUs %q with destinations", nid) + return err + } + return nil } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 40419b91f..65b9d39fd 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -25,6 +25,10 @@ import ( "go.uber.org/atomic" "gotest.tools/v3/poll" + "github.com/matrix-org/gomatrixserverlib" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" @@ -34,9 +38,6 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/gomatrixserverlib" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) { @@ -158,15 +159,18 @@ func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixse return edus, nil } -func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error { +func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() if _, ok := d.pendingPDUs[receipt]; ok { - if _, ok := d.associatedPDUs[serverName]; !ok { - d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{}) + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) + } + d.associatedPDUs[destination][receipt] = struct{}{} } - d.associatedPDUs[serverName][receipt] = struct{}{} + return nil } else { return errors.New("PDU doesn't exist") @@ -365,6 +369,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { }() ev := mustCreatePDU(t) + t.Logf("Queues: %+v", queues) err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) assert.NoError(t, err) @@ -821,15 +826,15 @@ func TestSendPDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} // Populate database with > maxPDUsPerTransaction pduMultiplier := uint32(3) for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ { ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) - db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") } ev := mustCreatePDU(t) @@ -907,16 +912,17 @@ func TestSendPDUAndEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction multiplier := uint32(3) - for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ { ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i)) - db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + t.Logf("DB: %+v", db) + t.Logf("Destinations: %+v", destinations) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") } for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ { @@ -960,13 +966,12 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { dest := queues.getQueue(destination) queues.statistics.ForServer(destination).Failure() - + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} ev := mustCreatePDU(t) headeredJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON)) - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, 1)) - db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid) + err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid) + assert.NoError(t, err, "failed to associate PDU with destinations") pollEnd := time.Now().Add(3 * time.Second) runningCheck := func(log poll.LogT) poll.Result { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index b8109b432..fbfe0672a 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -18,9 +18,10 @@ import ( "context" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/gomatrixserverlib" ) type Database interface { @@ -38,7 +39,7 @@ type Database interface { GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index 5a12c388a..dc37d7507 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -27,23 +27,23 @@ import ( // AssociatePDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. -func (d *Database) AssociatePDUWithDestination( +func (d *Database) AssociatePDUWithDestinations( ctx context.Context, - transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, + destinations map[gomatrixserverlib.ServerName]struct{}, receipt *Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - ); err != nil { - return fmt.Errorf("InsertQueuePDU: %w", err) + var err error + for destination := range destinations { + err = d.FederationQueuePDUs.InsertQueuePDU( + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + receipt.nid, // NID from the federationapi_queue_json table + ) } - return nil + return err }) }