Associate EDU with destinations in one transaction

This commit is contained in:
Till Faelligen 2022-10-21 10:38:06 +02:00
parent b26caac3f3
commit 58047f55b5
No known key found for this signature in database
GPG key ID: 3DF82D8AB9211D4E
6 changed files with 49 additions and 36 deletions

View file

@ -108,19 +108,6 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return return
} }
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociateEDUWithDestination(
oq.process.Context(),
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
event.Type,
nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return
}
// Check if the destination is blacklisted. If it isn't then wake // Check if the destination is blacklisted. If it isn't then wake
// up the queue. // up the queue.
if !oq.statistics.Blacklisted() { if !oq.statistics.Blacklisted() {

View file

@ -338,9 +338,25 @@ func (oqs *OutgoingQueues) SendEDU(
for destination := range destmap { for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil { if queue := oqs.getQueue(destination); queue != nil {
queue.sendEDU(e, nid) queue.sendEDU(e, nid)
} else {
delete(destmap, destination)
} }
} }
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oqs.db.AssociateEDUWithDestinations(
oqs.process.Context(),
destmap, // the destination server name
nid, // NIDs from federationapi_queue_json table
e.Type,
nil, // this will use the default expireEDUTypes map
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destinations")
return err
}
return nil return nil
} }

View file

@ -177,15 +177,18 @@ func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destina
} }
} }
func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error {
d.dbMutex.Lock() d.dbMutex.Lock()
defer d.dbMutex.Unlock() defer d.dbMutex.Unlock()
if _, ok := d.pendingEDUs[receipt]; ok { if _, ok := d.pendingEDUs[receipt]; ok {
if _, ok := d.associatedEDUs[serverName]; !ok { for destination := range destinations {
d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{}) if _, ok := d.associatedEDUs[destination]; !ok {
d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{})
}
d.associatedEDUs[destination][receipt] = struct{}{}
} }
d.associatedEDUs[serverName][receipt] = struct{}{}
return nil return nil
} else { } else {
return errors.New("EDU doesn't exist") return errors.New("EDU doesn't exist")
@ -870,13 +873,15 @@ func TestSendEDUBatches(t *testing.T) {
<-pc.WaitForShutdown() <-pc.WaitForShutdown()
}() }()
destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}}
// Populate database with > maxEDUsPerTransaction // Populate database with > maxEDUsPerTransaction
eduMultiplier := uint32(3) eduMultiplier := uint32(3)
for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ { for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ {
ev := mustCreateEDU(t) ev := mustCreateEDU(t)
ephemeralJSON, _ := json.Marshal(ev) ephemeralJSON, _ := json.Marshal(ev)
nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil)
assert.NoError(t, err, "failed to associate EDU with destinations")
} }
ev := mustCreateEDU(t) ev := mustCreateEDU(t)
@ -929,7 +934,8 @@ func TestSendPDUAndEDUBatches(t *testing.T) {
ev := mustCreateEDU(t) ev := mustCreateEDU(t)
ephemeralJSON, _ := json.Marshal(ev) ephemeralJSON, _ := json.Marshal(ev)
nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil) err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil)
assert.NoError(t, err, "failed to associate EDU with destinations")
} }
ev := mustCreateEDU(t) ev := mustCreateEDU(t)
@ -993,6 +999,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
t.Parallel() t.Parallel()
failuresUntilBlacklist := uint32(1) failuresUntilBlacklist := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost") destination := gomatrixserverlib.ServerName("remotehost")
destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true)
// NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up.
@ -1014,7 +1021,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
edu := mustCreateEDU(t) edu := mustCreateEDU(t)
ephemeralJSON, _ := json.Marshal(edu) ephemeralJSON, _ := json.Marshal(edu)
nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON)) nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
db.AssociateEDUWithDestination(pc.Context(), destination, nid, edu.Type, nil) err = db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, edu.Type, nil)
assert.NoError(t, err, "failed to associate EDU with destinations")
checkBlacklisted := func(log poll.LogT) poll.Result { checkBlacklisted := func(log poll.LogT) poll.Result {
if fc.txCount.Load() == failuresUntilBlacklist { if fc.txCount.Load() == failuresUntilBlacklist {

View file

@ -40,7 +40,7 @@ type Database interface {
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error

View file

@ -38,9 +38,9 @@ var defaultExpireEDUTypes = map[string]time.Duration{
// AssociateEDUWithDestination creates an association that the // AssociateEDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send // destination queues will use to determine which JSON blobs to send
// to which servers. // to which servers.
func (d *Database) AssociateEDUWithDestination( func (d *Database) AssociateEDUWithDestinations(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt, receipt *Receipt,
eduType string, eduType string,
expireEDUTypes map[string]time.Duration, expireEDUTypes map[string]time.Duration,
@ -59,17 +59,18 @@ func (d *Database) AssociateEDUWithDestination(
expiresAt = 0 expiresAt = 0
} }
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueueEDUs.InsertQueueEDU( var err error
ctx, // context for destination := range destinations {
txn, // SQL transaction err = d.FederationQueueEDUs.InsertQueueEDU(
eduType, // EDU type for coalescing ctx, // context
serverName, // destination server name txn, // SQL transaction
receipt.nid, // NID from the federationapi_queue_json table eduType, // EDU type for coalescing
expiresAt, // The timestamp this EDU will expire destination, // destination server name
); err != nil { receipt.nid, // NID from the federationapi_queue_json table
return fmt.Errorf("InsertQueueEDU: %w", err) expiresAt, // The timestamp this EDU will expire
)
} }
return nil return err
}) })
} }

View file

@ -35,6 +35,7 @@ func TestExpireEDUs(t *testing.T) {
} }
ctx := context.Background() ctx := context.Background()
destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateFederationDatabase(t, dbType) db, close := mustCreateFederationDatabase(t, dbType)
defer close() defer close()
@ -43,7 +44,7 @@ func TestExpireEDUs(t *testing.T) {
receipt, err := db.StoreJSON(ctx, "{}") receipt, err := db.StoreJSON(ctx, "{}")
assert.NoError(t, err) assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes) err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
assert.NoError(t, err) assert.NoError(t, err)
} }
// add data without expiry // add data without expiry
@ -51,7 +52,7 @@ func TestExpireEDUs(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test // m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes) err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes)
assert.NoError(t, err) assert.NoError(t, err)
// Delete expired EDUs // Delete expired EDUs
@ -67,7 +68,7 @@ func TestExpireEDUs(t *testing.T) {
receipt, err = db.StoreJSON(ctx, "{}") receipt, err = db.StoreJSON(ctx, "{}")
assert.NoError(t, err) assert.NoError(t, err)
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
assert.NoError(t, err) assert.NoError(t, err)
err = db.DeleteExpiredEDUs(ctx) err = db.DeleteExpiredEDUs(ctx)