diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index b62e5d9c5..ce9632ed3 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -110,6 +110,7 @@ func (d *Database) GetPendingEDUs( return fmt.Errorf("json.Unmarshal: %w", err) } edus[&Receipt{nid}] = &event + d.Cache.StoreFederationQueuedEDU(nid, &event) } return nil @@ -177,20 +178,18 @@ func (d *Database) GetPendingEDUServerNames( return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil) } -// DeleteExpiredEDUs deletes expired EDUs +// DeleteExpiredEDUs deletes expired EDUs and evicts them from the cache. func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var jsonNIDs []int64 + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) (err error) { expiredBefore := gomatrixserverlib.AsTimestamp(time.Now()) - jsonNIDs, err := d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) + jsonNIDs, err = d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) if err != nil { return err } if len(jsonNIDs) == 0 { return nil } - for i := range jsonNIDs { - d.Cache.EvictFederationQueuedEDU(jsonNIDs[i]) - } if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil { return err @@ -198,4 +197,14 @@ func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore) }) + + if err != nil { + return err + } + + for i := range jsonNIDs { + d.Cache.EvictFederationQueuedEDU(jsonNIDs[i]) + } + + return nil }