diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index d4e28a9c4..a1cc7992a 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -69,6 +69,10 @@ type destinationQueue struct { // If the queue is empty then it starts a background goroutine to // start sending events to that destination. func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) { + if event == nil { + log.Errorf("attempt to send nil PDU with destination %q", oq.destination) + return + } // Create a transaction ID. We'll either do this if we don't have // one made up yet, or if we've exceeded the number of maximum // events allowed in a single tranaction. We'll reset the counter @@ -116,6 +120,10 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re // If the queue is empty then it starts a background goroutine to // start sending events to that destination. func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) { + if event == nil { + log.Errorf("attempt to send nil EDU with destination %q", oq.destination) + return + } // Create a database entry that associates the given PDU NID with // this destination queue. We'll then be able to retrieve the PDU // later. @@ -370,6 +378,9 @@ func (oq *destinationQueue) nextTransaction( // Go through PDUs that we retrieved from the database, if any, // and add them into the transaction. for _, pdu := range pdus { + if pdu.pdu == nil { + continue + } // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct t.PDUs = append(t.PDUs, pdu.pdu.JSON()) @@ -378,6 +389,9 @@ func (oq *destinationQueue) nextTransaction( // Do the same for pending EDUS in the queue. for _, edu := range edus { + if edu.edu == nil { + continue + } t.EDUs = append(t.EDUs, *edu.edu) eduReceipts = append(pduReceipts, edu.receipt) } diff --git a/federationsender/storage/shared/storage_edus.go b/federationsender/storage/shared/storage_edus.go index 86fee1a37..2b9e2622e 100644 --- a/federationsender/storage/shared/storage_edus.go +++ b/federationsender/storage/shared/storage_edus.go @@ -52,42 +52,36 @@ func (d *Database) GetPendingEDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, limit int, -) ( - edus map[*Receipt]*gomatrixserverlib.EDU, - err error, -) { - edus = make(map[*Receipt]*gomatrixserverlib.EDU) - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - nids, err := d.FederationSenderQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) - if err != nil { - return fmt.Errorf("SelectQueueEDUs: %w", err) - } +) (map[*Receipt]*gomatrixserverlib.EDU, error) { + edus := make(map[*Receipt]*gomatrixserverlib.EDU) + nids, err := d.FederationSenderQueueEDUs.SelectQueueEDUs(ctx, nil, serverName, limit) + if err != nil { + return nil, fmt.Errorf("SelectQueueEDUs: %w", err) + } - retrieve := make([]int64, 0, len(nids)) - for _, nid := range nids { - if edu, ok := d.Cache.GetFederationSenderQueuedEDU(nid); ok { - edus[&Receipt{nid}] = edu - } else { - retrieve = append(retrieve, nid) - } + retrieve := make([]int64, 0, len(nids)) + for _, nid := range nids { + if edu, ok := d.Cache.GetFederationSenderQueuedEDU(nid); ok { + edus[&Receipt{nid}] = edu + } else { + retrieve = append(retrieve, nid) } + } - blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, retrieve) - if err != nil { - return fmt.Errorf("SelectQueueJSON: %w", err) + blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, nil, retrieve) + if err != nil { + return nil, fmt.Errorf("SelectQueueJSON: %w", err) + } + + for nid, blob := range blobs { + var event gomatrixserverlib.EDU + if err := json.Unmarshal(blob, &event); err != nil { + return nil, fmt.Errorf("json.Unmarshal: %w", err) } + edus[&Receipt{nid}] = &event + } - for nid, blob := range blobs { - var event gomatrixserverlib.EDU - if err := json.Unmarshal(blob, &event); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - edus[&Receipt{nid}] = &event - } - - return nil - }) - return + return edus, nil } // CleanEDUs cleans up all specified EDUs. This is done when a diff --git a/federationsender/storage/shared/storage_pdus.go b/federationsender/storage/shared/storage_pdus.go index bc298a905..a9c4c447b 100644 --- a/federationsender/storage/shared/storage_pdus.go +++ b/federationsender/storage/shared/storage_pdus.go @@ -53,48 +53,42 @@ func (d *Database) GetPendingPDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, limit int, -) ( - events map[*Receipt]*gomatrixserverlib.HeaderedEvent, - err error, -) { +) (map[*Receipt]*gomatrixserverlib.HeaderedEvent, error) { // Strictly speaking this doesn't need to be using the writer // since we are only performing selects, but since we don't have // a guarantee of transactional isolation, it's actually useful // to know in SQLite mode that nothing else is trying to modify // the database. - events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent) - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - nids, err := d.FederationSenderQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit) - if err != nil { - return fmt.Errorf("SelectQueuePDUs: %w", err) - } + events := make(map[*Receipt]*gomatrixserverlib.HeaderedEvent) + nids, err := d.FederationSenderQueuePDUs.SelectQueuePDUs(ctx, nil, serverName, limit) + if err != nil { + return nil, fmt.Errorf("SelectQueuePDUs: %w", err) + } - retrieve := make([]int64, 0, len(nids)) - for _, nid := range nids { - if event, ok := d.Cache.GetFederationSenderQueuedPDU(nid); ok { - events[&Receipt{nid}] = event - } else { - retrieve = append(retrieve, nid) - } + retrieve := make([]int64, 0, len(nids)) + for _, nid := range nids { + if event, ok := d.Cache.GetFederationSenderQueuedPDU(nid); ok { + events[&Receipt{nid}] = event + } else { + retrieve = append(retrieve, nid) } + } - blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, retrieve) - if err != nil { - return fmt.Errorf("SelectQueueJSON: %w", err) + blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, nil, retrieve) + if err != nil { + return nil, fmt.Errorf("SelectQueueJSON: %w", err) + } + + for nid, blob := range blobs { + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal(blob, &event); err != nil { + return nil, fmt.Errorf("json.Unmarshal: %w", err) } + events[&Receipt{nid}] = &event + d.Cache.StoreFederationSenderQueuedPDU(nid, &event) + } - for nid, blob := range blobs { - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal(blob, &event); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - events[&Receipt{nid}] = &event - d.Cache.StoreFederationSenderQueuedPDU(nid, &event) - } - - return nil - }) - return + return events, nil } // CleanTransactionPDUs cleans up all associated events for a