Remove name ghosting when referring to receipts

This commit is contained in:
Devon Hudson 2023-01-19 14:30:18 -07:00
parent b3af289136
commit 3f10d42c1b
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
9 changed files with 59 additions and 54 deletions

View file

@ -70,7 +70,7 @@ type destinationQueue struct {
// Send event adds the event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to
// start sending events to that destination.
func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, nid *receipt.Receipt) {
func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) {
if event == nil {
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return
@ -84,8 +84,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, ni
oq.pendingMutex.Lock()
if len(oq.pendingPDUs) < maxPDUsInMemory {
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
pdu: event,
receipt: nid,
pdu: event,
eventReceipt: dbReceipt,
})
} else {
oq.overflowed.Store(true)
@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, ni
// sendEDU adds the EDU event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to
// start sending events to that destination.
func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, nid *receipt.Receipt) {
func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, dbReceipt *receipt.Receipt) {
if event == nil {
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return
@ -115,8 +115,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, nid *receipt.R
oq.pendingMutex.Lock()
if len(oq.pendingEDUs) < maxEDUsInMemory {
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
edu: event,
receipt: nid,
edu: event,
eventReceipt: dbReceipt,
})
} else {
oq.overflowed.Store(true)
@ -210,10 +210,10 @@ func (oq *destinationQueue) getPendingFromDatabase() {
gotPDUs := map[string]struct{}{}
gotEDUs := map[string]struct{}{}
for _, pdu := range oq.pendingPDUs {
gotPDUs[pdu.receipt.String()] = struct{}{}
gotPDUs[pdu.eventReceipt.String()] = struct{}{}
}
for _, edu := range oq.pendingEDUs {
gotEDUs[edu.receipt.String()] = struct{}{}
gotEDUs[edu.eventReceipt.String()] = struct{}{}
}
overflowed := false
@ -518,7 +518,7 @@ func (oq *destinationQueue) createTransaction(
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
pduReceipts = append(pduReceipts, pdu.eventReceipt)
}
// Do the same for pending EDUS in the queue.
@ -528,7 +528,7 @@ func (oq *destinationQueue) createTransaction(
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
eduReceipts = append(eduReceipts, edu.eventReceipt)
}
return t, pduReceipts, eduReceipts

View file

@ -138,13 +138,13 @@ func NewOutgoingQueues(
}
type queuedPDU struct {
receipt *receipt.Receipt
pdu *gomatrixserverlib.HeaderedEvent
dbReceipt *receipt.Receipt
pdu *gomatrixserverlib.HeaderedEvent
}
type queuedEDU struct {
receipt *receipt.Receipt
edu *gomatrixserverlib.EDU
dbReceipt *receipt.Receipt
edu *gomatrixserverlib.EDU
}
func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue {

View file

@ -40,8 +40,8 @@ type Database interface {
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *receipt.Receipt) error
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error

View file

@ -20,6 +20,11 @@ package receipt
import "fmt"
// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry
// in some database table.
// The internal nid value cannot be modified after a Receipt has been created.
// This guarantees a receipt will always refer to the same table entry that it was created
// to represent.
type Receipt struct {
nid int64
}

View file

@ -42,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{
func (d *Database) AssociateEDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *receipt.Receipt,
dbReceipt *receipt.Receipt,
eduType string,
expireEDUTypes map[string]time.Duration,
) error {
@ -63,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations(
var err error
for destination := range destinations {
err = d.FederationQueueEDUs.InsertQueueEDU(
ctx, // context
txn, // SQL transaction
eduType, // EDU type for coalescing
destination, // destination server name
receipt.GetNID(), // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire
ctx, // context
txn, // SQL transaction
eduType, // EDU type for coalescing
destination, // destination server name
dbReceipt.GetNID(), // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire
)
}
return err

View file

@ -31,17 +31,17 @@ import (
func (d *Database) AssociatePDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *receipt.Receipt,
dbReceipt *receipt.Receipt,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
for destination := range destinations {
err = d.FederationQueuePDUs.InsertQueuePDU(
ctx, // context
txn, // SQL transaction
"", // transaction ID
destination, // destination server name
receipt.GetNID(), // NID from the federationapi_queue_json table
ctx, // context
txn, // SQL transaction
"", // transaction ID
destination, // destination server name
dbReceipt.GetNID(), // NID from the federationapi_queue_json table
)
}
return err

View file

@ -23,7 +23,7 @@ import (
type Database interface {
StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error)
AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, receipt *receipt.Receipt) error
AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error
CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error
GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error)
GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)

View file

@ -55,15 +55,15 @@ func (d *Database) StoreTransaction(
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
}
receipt := receipt.NewReceipt(nid)
return &receipt, nil
newReceipt := receipt.NewReceipt(nid)
return &newReceipt, nil
}
func (d *Database) AssociateTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
transactionID gomatrixserverlib.TransactionID,
receipt *receipt.Receipt,
dbReceipt *receipt.Receipt,
) error {
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var lastErr error
@ -74,7 +74,7 @@ func (d *Database) AssociateTransactionWithDestinations(
txn,
transactionID,
destination.Domain(),
receipt.GetNID(),
dbReceipt.GetNID(),
)
if err != nil {
lastErr = fmt.Errorf("d.insertQueueEntry: %w", err)
@ -92,8 +92,8 @@ func (d *Database) CleanTransactions(
receipts []*receipt.Receipt,
) error {
nids := make([]int64, len(receipts))
for i, receipt := range receipts {
nids[i] = receipt.GetNID()
for i, dbReceipt := range receipts {
nids[i] = dbReceipt.GetNID()
}
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -145,8 +145,8 @@ func (d *Database) GetTransaction(
return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err)
}
receipt := receipt.NewReceipt(firstNID)
return transaction, &receipt, nil
newReceipt := receipt.NewReceipt(firstNID)
return transaction, &newReceipt, nil
}
func (d *Database) GetTransactionCount(

View file

@ -97,9 +97,9 @@ func (d *InMemoryFederationDatabase) GetPendingPDUs(
pduCount := 0
pdus = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent)
if receipts, ok := d.associatedPDUs[serverName]; ok {
for nid := range receipts {
if event, ok := d.pendingPDUs[nid]; ok {
pdus[nid] = event
for dbReceipt := range receipts {
if event, ok := d.pendingPDUs[dbReceipt]; ok {
pdus[dbReceipt] = event
pduCount++
if pduCount == limit {
break
@ -121,9 +121,9 @@ func (d *InMemoryFederationDatabase) GetPendingEDUs(
eduCount := 0
edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU)
if receipts, ok := d.associatedEDUs[serverName]; ok {
for nid := range receipts {
if event, ok := d.pendingEDUs[nid]; ok {
edus[nid] = event
for dbReceipt := range receipts {
if event, ok := d.pendingEDUs[dbReceipt]; ok {
edus[dbReceipt] = event
eduCount++
if eduCount == limit {
break
@ -137,17 +137,17 @@ func (d *InMemoryFederationDatabase) GetPendingEDUs(
func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
nid *receipt.Receipt,
dbReceipt *receipt.Receipt,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if _, ok := d.pendingPDUs[nid]; ok {
if _, ok := d.pendingPDUs[dbReceipt]; ok {
for destination := range destinations {
if _, ok := d.associatedPDUs[destination]; !ok {
d.associatedPDUs[destination] = make(map[*receipt.Receipt]struct{})
}
d.associatedPDUs[destination][nid] = struct{}{}
d.associatedPDUs[destination][dbReceipt] = struct{}{}
}
return nil
@ -159,19 +159,19 @@ func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations(
func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
nid *receipt.Receipt,
dbReceipt *receipt.Receipt,
eduType string,
expireEDUTypes map[string]time.Duration,
) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if _, ok := d.pendingEDUs[nid]; ok {
if _, ok := d.pendingEDUs[dbReceipt]; ok {
for destination := range destinations {
if _, ok := d.associatedEDUs[destination]; !ok {
d.associatedEDUs[destination] = make(map[*receipt.Receipt]struct{})
}
d.associatedEDUs[destination][nid] = struct{}{}
d.associatedEDUs[destination][dbReceipt] = struct{}{}
}
return nil
@ -189,8 +189,8 @@ func (d *InMemoryFederationDatabase) CleanPDUs(
defer d.dbMutex.Unlock()
if pdus, ok := d.associatedPDUs[serverName]; ok {
for _, nid := range receipts {
delete(pdus, nid)
for _, dbReceipt := range receipts {
delete(pdus, dbReceipt)
}
}
@ -206,8 +206,8 @@ func (d *InMemoryFederationDatabase) CleanEDUs(
defer d.dbMutex.Unlock()
if edus, ok := d.associatedEDUs[serverName]; ok {
for _, nid := range receipts {
delete(edus, nid)
for _, dbReceipt := range receipts {
delete(edus, dbReceipt)
}
}