From 58047f55b501d3804da27af9ae5cb47e0fb46640 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 21 Oct 2022 10:38:06 +0200 Subject: [PATCH] Associate EDU with destinations in one transaction --- federationapi/queue/destinationqueue.go | 13 ---------- federationapi/queue/queue.go | 16 +++++++++++++ federationapi/queue/queue_test.go | 22 +++++++++++------ federationapi/storage/interface.go | 2 +- federationapi/storage/shared/storage_edus.go | 25 ++++++++++---------- federationapi/storage/storage_test.go | 7 +++--- 6 files changed, 49 insertions(+), 36 deletions(-) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 6b2d640aa..256359ad0 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -108,19 +108,6 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return } - // Create a database entry that associates the given PDU NID with - // this destination queue. We'll then be able to retrieve the PDU - // later. - if err := oq.db.AssociateEDUWithDestination( - oq.process.Context(), - oq.destination, // the destination server name - receipt, // NIDs from federationapi_queue_json table - event.Type, - nil, // this will use the default expireEDUTypes map - ); err != nil { - logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination) - return - } // Check if the destination is blacklisted. If it isn't then wake // up the queue. if !oq.statistics.Blacklisted() { diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 7219be44f..fb6d45a90 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -338,9 +338,25 @@ func (oqs *OutgoingQueues) SendEDU( for destination := range destmap { if queue := oqs.getQueue(destination); queue != nil { queue.sendEDU(e, nid) + } else { + delete(destmap, destination) } } + // Create a database entry that associates the given PDU NID with + // this destination queue. We'll then be able to retrieve the PDU + // later. + if err := oqs.db.AssociateEDUWithDestinations( + oqs.process.Context(), + destmap, // the destination server name + nid, // NIDs from federationapi_queue_json table + e.Type, + nil, // this will use the default expireEDUTypes map + ); err != nil { + logrus.WithError(err).Errorf("failed to associate EDU with destinations") + return err + } + return nil } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 65b9d39fd..00bacd598 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -177,15 +177,18 @@ func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destina } } -func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { +func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() if _, ok := d.pendingEDUs[receipt]; ok { - if _, ok := d.associatedEDUs[serverName]; !ok { - d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{}) + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) + } + d.associatedEDUs[destination][receipt] = struct{}{} } - d.associatedEDUs[serverName][receipt] = struct{}{} + return nil } else { return errors.New("EDU doesn't exist") @@ -870,13 +873,15 @@ func TestSendEDUBatches(t *testing.T) { <-pc.WaitForShutdown() }() + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} // Populate database with > maxEDUsPerTransaction eduMultiplier := uint32(3) for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { ev := mustCreateEDU(t) ephemeralJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) - db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") } ev := mustCreateEDU(t) @@ -929,7 +934,8 @@ func TestSendPDUAndEDUBatches(t *testing.T) { ev := mustCreateEDU(t) ephemeralJSON, _ := json.Marshal(ev) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) - db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) + err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") } ev := mustCreateEDU(t) @@ -993,6 +999,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") + destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. @@ -1014,7 +1021,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { edu := mustCreateEDU(t) ephemeralJSON, _ := json.Marshal(edu) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) - db.AssociateEDUWithDestination(pc.Context(), destination, nid, edu.Type, nil) + err = db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, edu.Type, nil) + assert.NoError(t, err, "failed to associate EDU with destinations") checkBlacklisted := func(log poll.LogT) poll.Result { if fc.txCount.Load() == failuresUntilBlacklist { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index fbfe0672a..09098cd1e 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -40,7 +40,7 @@ type Database interface { GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error - AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index e0c740c11..c796d2f8f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -38,9 +38,9 @@ var defaultExpireEDUTypes = map[string]time.Duration{ // AssociateEDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. -func (d *Database) AssociateEDUWithDestination( +func (d *Database) AssociateEDUWithDestinations( ctx context.Context, - serverName gomatrixserverlib.ServerName, + destinations map[gomatrixserverlib.ServerName]struct{}, receipt *Receipt, eduType string, expireEDUTypes map[string]time.Duration, @@ -59,17 +59,18 @@ func (d *Database) AssociateEDUWithDestination( expiresAt = 0 } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - serverName, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire - ); err != nil { - return fmt.Errorf("InsertQueueEDU: %w", err) + var err error + for destination := range destinations { + err = d.FederationQueueEDUs.InsertQueueEDU( + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + receipt.nid, // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire + ) } - return nil + return err }) } diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 3b0268e55..6272fd2b1 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -35,6 +35,7 @@ func TestExpireEDUs(t *testing.T) { } ctx := context.Background() + destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateFederationDatabase(t, dbType) defer close() @@ -43,7 +44,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err := db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MReceipt, expireEDUTypes) assert.NoError(t, err) } // add data without expiry @@ -51,7 +52,7 @@ func TestExpireEDUs(t *testing.T) { assert.NoError(t, err) // m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes) assert.NoError(t, err) // Delete expired EDUs @@ -67,7 +68,7 @@ func TestExpireEDUs(t *testing.T) { receipt, err = db.StoreJSON(ctx, "{}") assert.NoError(t, err) - err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) + err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) assert.NoError(t, err) err = db.DeleteExpiredEDUs(ctx)