diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index c2763c16c..75b54bbcb 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -33,7 +33,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationEventCache) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) { var d Database var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go index 44f06f88f..af9d0d6a3 100644 --- a/federationsender/storage/shared/storage.go +++ b/federationsender/storage/shared/storage.go @@ -29,7 +29,7 @@ import ( type Database struct { DB *sql.DB - Cache caching.FederationEventCache + Cache caching.FederationSenderCache Writer sqlutil.Writer FederationSenderQueuePDUs tables.FederationSenderQueuePDUs FederationSenderQueueEDUs tables.FederationSenderQueueEDUs diff --git a/federationsender/storage/shared/storage_edus.go b/federationsender/storage/shared/storage_edus.go index 529b46aa9..ae1d15118 100644 --- a/federationsender/storage/shared/storage_edus.go +++ b/federationsender/storage/shared/storage_edus.go @@ -69,7 +69,16 @@ func (d *Database) GetNextTransactionEDUs( nids: nids, } - blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, nids) + retrieve := make([]int64, 0, len(nids)) + for _, nid := range nids { + if edu, ok := d.Cache.GetFederationSenderQueuedEDU(nid); ok { + edus = append(edus, edu) + } else { + retrieve = append(retrieve, nid) + } + } + + blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, retrieve) if err != nil { return fmt.Errorf("SelectQueueJSON: %w", err) } @@ -111,6 +120,7 @@ func (d *Database) CleanEDUs( } if count == 0 { deleteNIDs = append(deleteNIDs, nid) + d.Cache.EvictFederationSenderQueuedEDU(nid) } } diff --git a/federationsender/storage/shared/storage_pdus.go b/federationsender/storage/shared/storage_pdus.go index 5a3e35a7b..09235a5ec 100644 --- a/federationsender/storage/shared/storage_pdus.go +++ b/federationsender/storage/shared/storage_pdus.go @@ -87,7 +87,7 @@ func (d *Database) GetNextTransactionPDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { - if event, ok := d.Cache.GetFederationEvent(nid); ok { + if event, ok := d.Cache.GetFederationSenderQueuedPDU(nid); ok { events = append(events, event) } else { retrieve = append(retrieve, nid) @@ -105,7 +105,7 @@ func (d *Database) GetNextTransactionPDUs( return fmt.Errorf("json.Unmarshal: %w", err) } events = append(events, &event) - d.Cache.StoreFederationEvent(nid, &event) + d.Cache.StoreFederationSenderQueuedPDU(nid, &event) } return nil @@ -138,7 +138,7 @@ func (d *Database) CleanPDUs( } if count == 0 { deleteNIDs = append(deleteNIDs, nid) - d.Cache.EvictFederationEvent(nid) + d.Cache.EvictFederationSenderQueuedPDU(nid) } } diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 308c515a4..e66d76909 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -35,7 +35,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationEventCache) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) { var d Database var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { diff --git a/federationsender/storage/storage.go b/federationsender/storage/storage.go index 69ea28222..5462c3523 100644 --- a/federationsender/storage/storage.go +++ b/federationsender/storage/storage.go @@ -26,7 +26,7 @@ import ( ) // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationEventCache) (Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, cache) diff --git a/federationsender/storage/storage_wasm.go b/federationsender/storage/storage_wasm.go index af93f7296..bc52bd9bb 100644 --- a/federationsender/storage/storage_wasm.go +++ b/federationsender/storage/storage_wasm.go @@ -23,7 +23,7 @@ import ( ) // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationEventCache) (Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, cache) diff --git a/internal/caching/cache_federationevents.go b/internal/caching/cache_federationevents.go index ebfbf1095..a48c11fd2 100644 --- a/internal/caching/cache_federationevents.go +++ b/internal/caching/cache_federationevents.go @@ -12,15 +12,19 @@ const ( FederationEventCacheMutable = true // to allow use of Unset only ) -// FederationEventCache contains the subset of functions needed for +// FederationSenderCache contains the subset of functions needed for // a federation event cache. -type FederationEventCache interface { - GetFederationEvent(eventNID int64) (event *gomatrixserverlib.HeaderedEvent, ok bool) - StoreFederationEvent(eventNID int64, event *gomatrixserverlib.HeaderedEvent) - EvictFederationEvent(eventNID int64) +type FederationSenderCache interface { + GetFederationSenderQueuedPDU(eventNID int64) (event *gomatrixserverlib.HeaderedEvent, ok bool) + StoreFederationSenderQueuedPDU(eventNID int64, event *gomatrixserverlib.HeaderedEvent) + EvictFederationSenderQueuedPDU(eventNID int64) + + GetFederationSenderQueuedEDU(eventNID int64) (event *gomatrixserverlib.EDU, ok bool) + StoreFederationSenderQueuedEDU(eventNID int64, event *gomatrixserverlib.EDU) + EvictFederationSenderQueuedEDU(eventNID int64) } -func (c Caches) GetFederationEvent(eventNID int64) (*gomatrixserverlib.HeaderedEvent, bool) { +func (c Caches) GetFederationSenderQueuedPDU(eventNID int64) (*gomatrixserverlib.HeaderedEvent, bool) { key := fmt.Sprintf("%d", eventNID) val, found := c.FederationEvents.Get(key) if found && val != nil { @@ -31,12 +35,33 @@ func (c Caches) GetFederationEvent(eventNID int64) (*gomatrixserverlib.HeaderedE return nil, false } -func (c Caches) StoreFederationEvent(eventNID int64, event *gomatrixserverlib.HeaderedEvent) { +func (c Caches) StoreFederationSenderQueuedPDU(eventNID int64, event *gomatrixserverlib.HeaderedEvent) { key := fmt.Sprintf("%d", eventNID) c.FederationEvents.Set(key, event) } -func (c Caches) EvictFederationEvent(eventNID int64) { +func (c Caches) EvictFederationSenderQueuedPDU(eventNID int64) { + key := fmt.Sprintf("%d", eventNID) + c.FederationEvents.Unset(key) +} + +func (c Caches) GetFederationSenderQueuedEDU(eventNID int64) (*gomatrixserverlib.EDU, bool) { + key := fmt.Sprintf("%d", eventNID) + val, found := c.FederationEvents.Get(key) + if found && val != nil { + if event, ok := val.(*gomatrixserverlib.EDU); ok { + return event, true + } + } + return nil, false +} + +func (c Caches) StoreFederationSenderQueuedEDU(eventNID int64, event *gomatrixserverlib.EDU) { + key := fmt.Sprintf("%d", eventNID) + c.FederationEvents.Set(key, event) +} + +func (c Caches) EvictFederationSenderQueuedEDU(eventNID int64) { key := fmt.Sprintf("%d", eventNID) c.FederationEvents.Unset(key) }