From 3f10d42c1bb096c5b39ee2163cc5f1aed883197d Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 19 Jan 2023 14:30:18 -0700 Subject: [PATCH] Remove name ghosting when referring to receipts --- federationapi/queue/destinationqueue.go | 20 ++++++------ federationapi/queue/queue.go | 8 ++--- federationapi/storage/interface.go | 4 +-- .../storage/shared/receipt/receipt.go | 5 +++ federationapi/storage/shared/storage_edus.go | 14 ++++---- federationapi/storage/shared/storage_pdus.go | 12 +++---- relayapi/storage/interface.go | 2 +- relayapi/storage/shared/storage.go | 16 +++++----- test/memory_federation_db.go | 32 +++++++++---------- 9 files changed, 59 insertions(+), 54 deletions(-) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 67b9f6253..971b74d29 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -70,7 +70,7 @@ type destinationQueue struct { // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, nid *receipt.Receipt) { +func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return @@ -84,8 +84,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, ni oq.pendingMutex.Lock() if len(oq.pendingPDUs) < maxPDUsInMemory { oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: nid, + pdu: event, + eventReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, ni // sendEDU adds the EDU event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, nid *receipt.Receipt) { +func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return @@ -115,8 +115,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, nid *receipt.R oq.pendingMutex.Lock() if len(oq.pendingEDUs) < maxEDUsInMemory { oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: nid, + edu: event, + eventReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -210,10 +210,10 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotPDUs := map[string]struct{}{} gotEDUs := map[string]struct{}{} for _, pdu := range oq.pendingPDUs { - gotPDUs[pdu.receipt.String()] = struct{}{} + gotPDUs[pdu.eventReceipt.String()] = struct{}{} } for _, edu := range oq.pendingEDUs { - gotEDUs[edu.receipt.String()] = struct{}{} + gotEDUs[edu.eventReceipt.String()] = struct{}{} } overflowed := false @@ -518,7 +518,7 @@ func (oq *destinationQueue) createTransaction( // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) + pduReceipts = append(pduReceipts, pdu.eventReceipt) } // Do the same for pending EDUS in the queue. @@ -528,7 +528,7 @@ func (oq *destinationQueue) createTransaction( continue } t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) + eduReceipts = append(eduReceipts, edu.eventReceipt) } return t, pduReceipts, eduReceipts diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index dcb303e79..5d6b8d44c 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -138,13 +138,13 @@ func NewOutgoingQueues( } type queuedPDU struct { - receipt *receipt.Receipt - pdu *gomatrixserverlib.HeaderedEvent + dbReceipt *receipt.Receipt + pdu *gomatrixserverlib.HeaderedEvent } type queuedEDU struct { - receipt *receipt.Receipt - edu *gomatrixserverlib.EDU + dbReceipt *receipt.Receipt + edu *gomatrixserverlib.EDU } func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index b73bd2841..14201e882 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -40,8 +40,8 @@ type Database interface { GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *receipt.Receipt) error - AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error diff --git a/federationapi/storage/shared/receipt/receipt.go b/federationapi/storage/shared/receipt/receipt.go index 33a71a1a8..b347269c1 100644 --- a/federationapi/storage/shared/receipt/receipt.go +++ b/federationapi/storage/shared/receipt/receipt.go @@ -20,6 +20,11 @@ package receipt import "fmt" +// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry +// in some database table. +// The internal nid value cannot be modified after a Receipt has been created. +// This guarantees a receipt will always refer to the same table entry that it was created +// to represent. type Receipt struct { nid int64 } diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index ae7e884b8..cff1ade6f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -42,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{ func (d *Database) AssociateEDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *receipt.Receipt, + dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, ) error { @@ -63,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations( var err error for destination := range destinations { err = d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - destination, // destination server name - receipt.GetNID(), // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ) } return err diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index 0f5844520..854e00553 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -31,17 +31,17 @@ import ( func (d *Database) AssociatePDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *receipt.Receipt, + dbReceipt *receipt.Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error for destination := range destinations { err = d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - "", // transaction ID - destination, // destination server name - receipt.GetNID(), // NID from the federationapi_queue_json table + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table ) } return err diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go index d39b89aae..c754c21a3 100644 --- a/relayapi/storage/interface.go +++ b/relayapi/storage/interface.go @@ -23,7 +23,7 @@ import ( type Database interface { StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error) - AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *receipt.Receipt) error + AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go index 4b7588b90..040ae0c06 100644 --- a/relayapi/storage/shared/storage.go +++ b/relayapi/storage/shared/storage.go @@ -55,15 +55,15 @@ func (d *Database) StoreTransaction( return nil, fmt.Errorf("d.insertQueueJSON: %w", err) } - receipt := receipt.NewReceipt(nid) - return &receipt, nil + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil } func (d *Database) AssociateTransactionWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, - receipt *receipt.Receipt, + dbReceipt *receipt.Receipt, ) error { err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var lastErr error @@ -74,7 +74,7 @@ func (d *Database) AssociateTransactionWithDestinations( txn, transactionID, destination.Domain(), - receipt.GetNID(), + dbReceipt.GetNID(), ) if err != nil { lastErr = fmt.Errorf("d.insertQueueEntry: %w", err) @@ -92,8 +92,8 @@ func (d *Database) CleanTransactions( receipts []*receipt.Receipt, ) error { nids := make([]int64, len(receipts)) - for i, receipt := range receipts { - nids[i] = receipt.GetNID() + for i, dbReceipt := range receipts { + nids[i] = dbReceipt.GetNID() } err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -145,8 +145,8 @@ func (d *Database) GetTransaction( return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err) } - receipt := receipt.NewReceipt(firstNID) - return transaction, &receipt, nil + newReceipt := receipt.NewReceipt(firstNID) + return transaction, &newReceipt, nil } func (d *Database) GetTransactionCount( diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go index 74eb2e7f7..99ec2abdf 100644 --- a/test/memory_federation_db.go +++ b/test/memory_federation_db.go @@ -97,9 +97,9 @@ func (d *InMemoryFederationDatabase) GetPendingPDUs( pduCount := 0 pdus = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) if receipts, ok := d.associatedPDUs[serverName]; ok { - for nid := range receipts { - if event, ok := d.pendingPDUs[nid]; ok { - pdus[nid] = event + for dbReceipt := range receipts { + if event, ok := d.pendingPDUs[dbReceipt]; ok { + pdus[dbReceipt] = event pduCount++ if pduCount == limit { break @@ -121,9 +121,9 @@ func (d *InMemoryFederationDatabase) GetPendingEDUs( eduCount := 0 edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) if receipts, ok := d.associatedEDUs[serverName]; ok { - for nid := range receipts { - if event, ok := d.pendingEDUs[nid]; ok { - edus[nid] = event + for dbReceipt := range receipts { + if event, ok := d.pendingEDUs[dbReceipt]; ok { + edus[dbReceipt] = event eduCount++ if eduCount == limit { break @@ -137,17 +137,17 @@ func (d *InMemoryFederationDatabase) GetPendingEDUs( func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - nid *receipt.Receipt, + dbReceipt *receipt.Receipt, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() - if _, ok := d.pendingPDUs[nid]; ok { + if _, ok := d.pendingPDUs[dbReceipt]; ok { for destination := range destinations { if _, ok := d.associatedPDUs[destination]; !ok { d.associatedPDUs[destination] = make(map[*receipt.Receipt]struct{}) } - d.associatedPDUs[destination][nid] = struct{}{} + d.associatedPDUs[destination][dbReceipt] = struct{}{} } return nil @@ -159,19 +159,19 @@ func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - nid *receipt.Receipt, + dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, ) error { d.dbMutex.Lock() defer d.dbMutex.Unlock() - if _, ok := d.pendingEDUs[nid]; ok { + if _, ok := d.pendingEDUs[dbReceipt]; ok { for destination := range destinations { if _, ok := d.associatedEDUs[destination]; !ok { d.associatedEDUs[destination] = make(map[*receipt.Receipt]struct{}) } - d.associatedEDUs[destination][nid] = struct{}{} + d.associatedEDUs[destination][dbReceipt] = struct{}{} } return nil @@ -189,8 +189,8 @@ func (d *InMemoryFederationDatabase) CleanPDUs( defer d.dbMutex.Unlock() if pdus, ok := d.associatedPDUs[serverName]; ok { - for _, nid := range receipts { - delete(pdus, nid) + for _, dbReceipt := range receipts { + delete(pdus, dbReceipt) } } @@ -206,8 +206,8 @@ func (d *InMemoryFederationDatabase) CleanEDUs( defer d.dbMutex.Unlock() if edus, ok := d.associatedEDUs[serverName]; ok { - for _, nid := range receipts { - delete(edus, nid) + for _, dbReceipt := range receipts { + delete(edus, dbReceipt) } }