From 339ea3d71121b6d64ccc264c41501fe1cf2c1315 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 30 Jun 2020 13:27:49 +0100 Subject: [PATCH] Put things into database (postgres for now) --- federationsender/queue/destinationqueue.go | 149 +++++++++------- federationsender/queue/queue.go | 10 +- federationsender/storage/interface.go | 6 +- .../storage/postgres/queue_json_table.go | 113 ++++++++++++ .../storage/postgres/queue_pdus_table.go | 141 +++++++++++++++ .../storage/postgres/queue_retry_table.go | 167 ------------------ federationsender/storage/postgres/storage.go | 103 ++++++++--- 7 files changed, 430 insertions(+), 259 deletions(-) create mode 100644 federationsender/storage/postgres/queue_json_table.go create mode 100644 federationsender/storage/postgres/queue_pdus_table.go delete mode 100644 federationsender/storage/postgres/queue_retry_table.go diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index d5144d562..35c59edff 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -30,28 +30,31 @@ import ( "go.uber.org/atomic" ) +const maxPDUsPerTransaction = 50 + // destinationQueue is a queue of events for a single destination. // It is responsible for sending the events to the destination and // ensures that only one request is in flight to a given destination // at a time. type destinationQueue struct { - db storage.Database - signing *SigningInfo - rsAPI api.RoomserverInternalAPI - client *gomatrixserverlib.FederationClient // federation client - origin gomatrixserverlib.ServerName // origin of requests - destination gomatrixserverlib.ServerName // destination of requests - running atomic.Bool // is the queue worker running? - backingOff atomic.Bool // true if we're backing off - statistics *types.ServerStatistics // statistics about this remote server - incomingPDUs chan *gomatrixserverlib.HeaderedEvent // PDUs to send - incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send - incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send - transactionID gomatrixserverlib.TransactionID // last transaction ID - pendingPDUs []*gomatrixserverlib.HeaderedEvent // owned by backgroundSend - pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend - pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend - retryServerCh chan bool // interrupts backoff + db storage.Database + signing *SigningInfo + rsAPI api.RoomserverInternalAPI + client *gomatrixserverlib.FederationClient // federation client + origin gomatrixserverlib.ServerName // origin of requests + destination gomatrixserverlib.ServerName // destination of requests + running atomic.Bool // is the queue worker running? + backingOff atomic.Bool // true if we're backing off + statistics *types.ServerStatistics // statistics about this remote server + incomingPDUs chan struct{} // signal that there are PDUs waiting + incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send + incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send + transactionID gomatrixserverlib.TransactionID // last transaction ID + transactionCount int // how many events in this transaction so far + pendingPDUs []*gomatrixserverlib.HeaderedEvent // owned by backgroundSend + pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend + pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend + retryServerCh chan bool // interrupts backoff } // retry will clear the blacklist state and attempt to send built up events to the server, @@ -81,15 +84,42 @@ func (oq *destinationQueue) retry() { // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(ev *gomatrixserverlib.HeaderedEvent) { +func (oq *destinationQueue) sendEvent(nid int64) { if oq.statistics.Blacklisted() { // If the destination is blacklisted then drop the event. return } + // Create a transaction ID. We'll either do this if we don't have + // one made up yet, or if we've exceeded the number of maximum + // events allowed in a single tranaction. We'll reset the counter + // when we do. + if oq.transactionID == "" || oq.transactionCount >= maxPDUsPerTransaction { + now := gomatrixserverlib.AsTimestamp(time.Now()) + oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) + oq.transactionCount = 0 + } + // Create a database entry that associates the given PDU NID with + // this destination queue. We'll then be able to retrieve the PDU + // later. + if err := oq.db.AssociatePDUWithDestination( + context.TODO(), + oq.transactionID, // the current transaction ID + oq.destination, // the destination server name + []int64{nid}, // NID from federationsender_queue_json table + ); err != nil { + log.WithError(err).Errorf("failed to associate PDU with ID %d with destination %q", oq.destination) + return + } + // We've successfully added a PDU to the transaction so increase + // the counter. + oq.transactionCount++ + // If the queue isn't running at this point then start it. if !oq.running.Load() { go oq.backgroundSend() } - oq.incomingPDUs <- ev + // Signal that we've sent a new PDU. This will cause the queue to + // wake up if it's asleep. + oq.incomingPDUs <- struct{}{} } // sendEDU adds the EDU event to the pending queue for the destination. @@ -131,21 +161,32 @@ func (oq *destinationQueue) backgroundSend() { defer oq.running.Store(false) for { + // For now we don't know the next transaction ID that we'll + // pluck from the database. + transactionID := gomatrixserverlib.TransactionID("") + + // Check to see if there are any pending PDUs in the database. + // If we haven't reached the PDU limit yet then retrieve those + // events so that they can be added into this transaction. + if len(oq.pendingPDUs) < maxPDUsPerTransaction { + txid, pdus, err := oq.db.GetNextTransactionPDUs( + context.TODO(), // context + oq.destination, // server name + maxPDUsPerTransaction-len(oq.pendingPDUs), // how many events to retrieve + ) + if err != nil { + log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination) + continue + } + transactionID = txid + oq.pendingPDUs = append(oq.pendingPDUs, pdus...) + } + // Wait either for incoming events, or until we hit an // idle timeout. select { - case pdu := <-oq.incomingPDUs: - // Ordering of PDUs is important so we add them to the end - // of the queue and they will all be added to transactions - // in order. - oq.pendingPDUs = append(oq.pendingPDUs, pdu) - // If there are any more things waiting in the channel queue - // then read them. This is safe because we guarantee only - // having one goroutine per destination queue, so the channel - // isn't being consumed anywhere else. - for len(oq.incomingPDUs) > 0 { - oq.pendingPDUs = append(oq.pendingPDUs, <-oq.incomingPDUs) - } + case <-oq.incomingPDUs: + // There are new PDUs waiting in the database. case edu := <-oq.incomingEDUs: // Likewise for EDUs, although we should probably not try // too hard with some EDUs (like typing notifications) after @@ -202,40 +243,20 @@ func (oq *destinationQueue) backgroundSend() { // If we have pending PDUs or EDUs then construct a transaction. if numPDUs > 0 || numEDUs > 0 { - // Generate a transaction ID. - if oq.transactionID == "" { + // If we haven't got a transaction ID then we should generate + // one. Ideally we'd know this already because something queued + // in the database would give us one, but if we're dealing with + // EDUs alone, we won't go via the database so we'll make one. + if transactionID == "" { now := gomatrixserverlib.AsTimestamp(time.Now()) - oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) + transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) } // Try sending the next transaction and see what happens. - transaction, terr := oq.nextTransaction(oq.transactionID, oq.pendingPDUs, oq.pendingEDUs, oq.statistics.SuccessCount()) + transaction, terr := oq.nextTransaction(transactionID, oq.pendingPDUs, oq.pendingEDUs, oq.statistics.SuccessCount()) if terr != nil { // We failed to send the transaction. - giveUp := oq.statistics.Failure() - // TODO: commit the transaction to the database - if terr = oq.db.StoreFailedPDUs( - context.TODO(), - oq.transactionID, - oq.destination, - oq.pendingPDUs, - ); terr != nil { - // We failed to persist the events to the database for some - // reason, so we'll keep them in memory for now. Hopefully - // it's a temporary condition but log it. - logrus.WithError(terr).Errorf("Failed to persist failed sends for server %q to database", oq.destination) - } else { - // Reallocate so that the underlying arrays can be GC'd, as - // opposed to growing forever. - for i := 0; i < numPDUs; i++ { - oq.pendingPDUs[i] = nil - } - oq.pendingPDUs = append( - []*gomatrixserverlib.HeaderedEvent{}, - oq.pendingPDUs[numPDUs:]..., - ) - } - if giveUp { + if giveUp := oq.statistics.Failure(); giveUp { // It's been suggested that we should give up because // the backoff has exceeded a maximum allowable value. return @@ -261,6 +282,14 @@ func (oq *destinationQueue) backgroundSend() { []*gomatrixserverlib.EDU{}, oq.pendingEDUs[numEDUs:]..., ) + // Clean up the transaction in the database. + if err := oq.db.CleanTransactionPDUs( + context.TODO(), + oq.destination, + transactionID, + ); err != nil { + log.WithError(err).Errorf("failed to clean transaction %q for server %q", transactionID, oq.destination) + } } } @@ -308,8 +337,6 @@ func (oq *destinationQueue) nextTransaction( t.Destination = oq.destination t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) - oq.transactionID = t.TransactionID - for _, pdu := range pendingPDUs { // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index bb1ed2258..66ab55786 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -15,6 +15,7 @@ package queue import ( + "context" "crypto/ed25519" "fmt" "sync" @@ -86,7 +87,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d destination: destination, client: oqs.client, statistics: oqs.statistics.ForServer(destination), - incomingPDUs: make(chan *gomatrixserverlib.HeaderedEvent, 128), + incomingPDUs: make(chan struct{}, 128), incomingEDUs: make(chan *gomatrixserverlib.EDU, 128), incomingInvites: make(chan *gomatrixserverlib.InviteV2Request, 128), retryServerCh: make(chan bool), @@ -120,8 +121,13 @@ func (oqs *OutgoingQueues) SendEvent( "destinations": destinations, "event": ev.EventID(), }).Info("Sending event") + nid, err := oqs.db.StoreJSON(context.TODO(), ev.JSON()) + if err != nil { + return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) + } + for _, destination := range destinations { - oqs.getQueue(destination).sendEvent(ev) + oqs.getQueue(destination).sendEvent(nid) } return nil diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index d063f3c53..973eb474f 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -26,6 +26,8 @@ type Database interface { internal.PartitionStorer UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) - GetFailedPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]*gomatrixserverlib.HeaderedEvent, error) - StoreFailedPDUs(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, pdus []*gomatrixserverlib.HeaderedEvent) error + StoreJSON(ctx context.Context, js []byte) (int64, error) + AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nids []int64) error + GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) + CleanTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID) error } diff --git a/federationsender/storage/postgres/queue_json_table.go b/federationsender/storage/postgres/queue_json_table.go new file mode 100644 index 000000000..2852713ce --- /dev/null +++ b/federationsender/storage/postgres/queue_json_table.go @@ -0,0 +1,113 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const queueJSONSchema = ` +-- The queue_retry_json table contains event contents that +-- we failed to send. +CREATE TABLE IF NOT EXISTS federationsender_queue_retry_json ( + -- The JSON NID. This allows the federationsender_queue_retry table to + -- cross-reference to find the JSON blob. + json_nid BIGSERIAL, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); +` + +const insertJSONSQL = "" + + "INSERT INTO federationsender_queue_retry_json (json_body)" + + " VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const deleteJSONSQL = "" + + "DELETE FROM federationsender_queue_retry_json WHERE json_nid = ANY($1)" + +const selectJSONSQL = "" + + "SELECT json_nid, json_body FROM federationsender_queue_retry_json" + + " WHERE json_nid = ANY($1)" + +type queueJSONStatements struct { + insertJSONStmt *sql.Stmt + deleteJSONStmt *sql.Stmt + selectJSONStmt *sql.Stmt +} + +func (s *queueJSONStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(queueJSONSchema) + if err != nil { + return + } + if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil { + return + } + if s.deleteJSONStmt, err = db.Prepare(deleteJSONSQL); err != nil { + return + } + if s.selectJSONStmt, err = db.Prepare(selectJSONSQL); err != nil { + return + } + return +} + +func (s *queueJSONStatements) insertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (int64, error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, err + } + lastid, err := res.LastInsertId() + return lastid, err +} + +func (s *queueJSONStatements) deleteQueueJSON( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) + _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) + return err +} + +func (s *queueJSONStatements) selectJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, s.selectJSONStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, err + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/federationsender/storage/postgres/queue_pdus_table.go b/federationsender/storage/postgres/queue_pdus_table.go new file mode 100644 index 000000000..609fd554c --- /dev/null +++ b/federationsender/storage/postgres/queue_pdus_table.go @@ -0,0 +1,141 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const queueSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_queue_pdus ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The domain part of the user ID the m.room.member event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the federationsender_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_event_id_idx + ON federationsender_queue (event_id, server_name); +` + +const insertQueueSQL = "" + + "INSERT INTO federationsender_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueTransactionSQL = "" + + "DELETE FROM federationsender_queue WHERE server_name = $1 AND transaction_id = $2" + +const selectQueueNextTransactionIDSQL = "" + + "SELECT transaction_id FROM federationsender_queue" + + " WHERE server_name = $1" + + " ORDER BY transaction_id ASC" + + " LIMIT 1" + +const selectQueuePDUsByTransactionSQL = "" + + "SELECT json_nid FROM federationsender_queue" + + " WHERE server_name = $1 AND transaction_id = $2" + + " LIMIT 50" + +type queueStatements struct { + insertQueueStmt *sql.Stmt + deleteQueueTransactionStmt *sql.Stmt + selectQueueNextTransactionIDStmt *sql.Stmt + selectQueuePDUsByTransactionStmt *sql.Stmt +} + +func (s *queueStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(queueSchema) + if err != nil { + return + } + if s.insertQueueStmt, err = db.Prepare(insertQueueSQL); err != nil { + return + } + if s.deleteQueueTransactionStmt, err = db.Prepare(deleteQueueTransactionSQL); err != nil { + return + } + if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { + return + } + if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { + return + } + return +} + +func (s *queueStatements) insertQueuePDU( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *queueStatements) deleteQueueTransaction( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionStmt) + _, err := stmt.ExecContext(ctx, serverName, transactionID) + return err +} + +func (s *queueStatements) selectQueueNextTransactionID( + ctx context.Context, txn *sql.Tx, serverName, sendType string, +) (string, error) { + var transactionID string + stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) + return transactionID, err +} + +func (s *queueStatements) selectQueuePDUs( + ctx context.Context, txn *sql.Tx, serverName string, transactionID string, limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsByTransactionStmt) + rows, err := stmt.QueryContext(ctx, serverName, transactionID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} diff --git a/federationsender/storage/postgres/queue_retry_table.go b/federationsender/storage/postgres/queue_retry_table.go deleted file mode 100644 index b2a0e86fd..000000000 --- a/federationsender/storage/postgres/queue_retry_table.go +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - - "github.com/lib/pq" - "github.com/matrix-org/dendrite/federationsender/types" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" -) - -const queueRetrySchema = ` --- The queue_retry table contains events that we failed to --- send to a destination host, such that we can try them again --- later. -CREATE TABLE IF NOT EXISTS federationsender_queue_retry ( - -- The string ID of the room. - transaction_id TEXT NOT NULL, - -- The event type: "pdu", "invite", "send_to_device". - send_type TEXT NOT NULL, - -- The event ID of the m.room.member join event. - event_id TEXT NOT NULL, - -- The origin server TS of the event. - origin_server_ts BIGINT NOT NULL, - -- The domain part of the user ID the m.room.member event is for. - server_name TEXT NOT NULL, - -- The JSON body. - json_body BYTEA NOT NULL -); - -CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_retry_event_id_idx - ON federationsender_queue_retry (event_id, server_name); -` - -const insertRetrySQL = "" + - "INSERT INTO federationsender_queue_retry (transaction_id, send_type, event_id, origin_server_ts, server_name, json_body)" + - " VALUES ($1, $2, $3, $4, $5, $6)" - -const deleteRetrySQL = "" + - "DELETE FROM federationsender_queue_retry WHERE event_id = ANY($1)" - -const selectRetryNextTransactionIDSQL = "" + - "SELECT transaction_id FROM federationsender_queue_retry" + - " WHERE server_name = $1 AND send_type = $2" + - " ORDER BY transaction_id ASC" + - " LIMIT 1" - -const selectRetryPDUsByTransactionSQL = "" + - "SELECT event_id, server_name, origin_server_ts, json_body FROM federationsender_queue_retry" + - " WHERE server_name = $1 AND send_type = $2 AND transaction_id = $3" + - " LIMIT 50" - -type queueRetryStatements struct { - insertRetryStmt *sql.Stmt - deleteRetryStmt *sql.Stmt - selectRetryNextTransactionIDStmt *sql.Stmt - selectRetryPDUsByTransactionStmt *sql.Stmt -} - -func (s *queueRetryStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(queueRetrySchema) - if err != nil { - return - } - if s.insertRetryStmt, err = db.Prepare(insertRetrySQL); err != nil { - return - } - if s.deleteRetryStmt, err = db.Prepare(deleteRetrySQL); err != nil { - return - } - if s.selectRetryNextTransactionIDStmt, err = db.Prepare(selectRetryNextTransactionIDSQL); err != nil { - return - } - if s.selectRetryPDUsByTransactionStmt, err = db.Prepare(selectRetryPDUsByTransactionSQL); err != nil { - return - } - return -} - -func (s *queueRetryStatements) insertQueueRetry( - ctx context.Context, - txn *sql.Tx, - transactionID string, - sendtype string, - event gomatrixserverlib.Event, - serverName gomatrixserverlib.ServerName, -) error { - stmt := sqlutil.TxStmt(txn, s.insertRetryStmt) - _, err := stmt.ExecContext( - ctx, - transactionID, // the transaction ID that we initially attempted - sendtype, // either "pdu", "invite", "send_to_device" - event.EventID(), // the event ID - event.OriginServerTS(), // the event origin server TS - serverName, // destination server name - event.JSON(), // JSON body - ) - return err -} - -func (s *queueRetryStatements) deleteQueueRetry( - ctx context.Context, txn *sql.Tx, eventIDs []string, -) error { - stmt := sqlutil.TxStmt(txn, s.deleteRetryStmt) - _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) - return err -} - -func (s *queueRetryStatements) selectRetryNextTransactionID( - ctx context.Context, txn *sql.Tx, serverName, sendType string, -) (string, error) { - var transactionID string - stmt := sqlutil.TxStmt(txn, s.selectRetryNextTransactionIDStmt) - err := stmt.QueryRowContext(ctx, serverName, types.FailedEventTypePDU).Scan(&transactionID) - return transactionID, err -} - -func (s *queueRetryStatements) selectQueueRetryPDUs( - ctx context.Context, txn *sql.Tx, serverName string, transactionID string, -) ([]*gomatrixserverlib.HeaderedEvent, error) { - - stmt := sqlutil.TxStmt(txn, s.selectRetryPDUsByTransactionStmt) - rows, err := stmt.QueryContext(ctx, serverName, types.FailedEventTypePDU, transactionID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "queueRetryFromStmt: rows.close() failed") - - var result []*gomatrixserverlib.HeaderedEvent - for rows.Next() { - var transactionID, eventID string - var originServerTS int64 - var jsonBody []byte - if err = rows.Scan(&transactionID, &eventID, &originServerTS, &jsonBody); err != nil { - return nil, err - } - var event gomatrixserverlib.HeaderedEvent - if err = json.Unmarshal(jsonBody, &event); err != nil { - return nil, fmt.Errorf("json.Unmarshal: %w", err) - } - if event.EventID() != eventID { - return nil, fmt.Errorf("event ID %q doesn't match expected %q", event.EventID(), eventID) - } - result = append(result, &event) - } - - return result, rows.Err() -} diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index 602ba862c..7a3d5384c 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -18,6 +18,7 @@ package postgres import ( "context" "database/sql" + "encoding/json" "fmt" "github.com/matrix-org/dendrite/federationsender/types" @@ -29,7 +30,8 @@ import ( type Database struct { joinedHostsStatements roomStatements - queueRetryStatements + queueStatements + queueJSONStatements sqlutil.PartitionOffsetStatements db *sql.DB } @@ -58,7 +60,11 @@ func (d *Database) prepare() error { return err } - if err = d.queueRetryStatements.prepare(d.db); err != nil { + if err = d.queueStatements.prepare(d.db); err != nil { + return err + } + + if err = d.queueJSONStatements.prepare(d.db); err != nil { return err } @@ -128,44 +134,87 @@ func (d *Database) GetJoinedHosts( return d.selectJoinedHosts(ctx, roomID) } -// GetFailedPDUs retrieves PDUs that we have failed to send on -// a specific destination queue. -func (d *Database) GetFailedPDUs( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) ([]*gomatrixserverlib.HeaderedEvent, error) { - transactionID, err := d.selectRetryNextTransactionID(ctx, nil, string(serverName), types.FailedEventTypePDU) +// StoreJSON adds a JSON blob into the queue JSON table and returns +// a NID. The NID will then be used when inserting the per-destination +// metadata entries. +func (d *Database) StoreJSON( + ctx context.Context, js []byte, +) (int64, error) { + res, err := d.insertJSONStmt.ExecContext(ctx, js) if err != nil { - return nil, fmt.Errorf("d.selectRetryNextTransactionID: %w", err) + return 0, fmt.Errorf("d.insertRetryJSONStmt: %w", err) } - - events, err := d.selectQueueRetryPDUs(ctx, nil, string(serverName), transactionID) + nid, err := res.LastInsertId() if err != nil { - return nil, fmt.Errorf("d.selectQueueRetryPDUs: %w", err) + return 0, fmt.Errorf("res.LastInsertID: %w", err) } - return events, nil + return nid, nil } -// StoreFailedPDUs stores PDUs that we have failed to send on -// a specific destination queue. -func (d *Database) StoreFailedPDUs( +// AssociatePDUWithDestination creates an association that the +// destination queues will use to determine which JSON blobs to send +// to which servers. +func (d *Database) AssociatePDUWithDestination( ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, - pdus []*gomatrixserverlib.HeaderedEvent, + nids []int64, ) error { - for _, pdu := range pdus { - if _, err := d.insertRetryStmt.ExecContext( - ctx, - string(transactionID), // transaction ID - types.FailedEventTypePDU, // type of event that was queued - pdu.EventID(), // event ID - pdu.OriginServerTS(), // event origin server TS - string(serverName), // destination server name - pdu.JSON(), // JSON body + for _, nid := range nids { + if err := d.insertQueuePDU( + ctx, // context + nil, // SQL transaction + transactionID, // transaction ID + serverName, // destination server name + nid, // NID from the federationsender_queue_json table ); err != nil { return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) } } return nil } + +// GetNextTransactionPDUs retrieves events from the database for +// the next pending transaction, up to the limit specified. +func (d *Database) GetNextTransactionPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) { + transactionID, err := d.selectQueueNextTransactionID(ctx, nil, string(serverName), types.FailedEventTypePDU) + if err != nil { + return "", nil, fmt.Errorf("d.selectRetryNextTransactionID: %w", err) + } + + nids, err := d.selectQueuePDUs(ctx, nil, string(serverName), transactionID, limit) + if err != nil { + return "", nil, fmt.Errorf("d.selectQueueRetryPDUs: %w", err) + } + + blobs, err := d.selectJSON(ctx, nil, nids) + if err != nil { + return "", nil, fmt.Errorf("d.selectJSON: %w", err) + } + + var events []*gomatrixserverlib.HeaderedEvent + for _, blob := range blobs { + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal(blob, &event); err != nil { + return "", nil, fmt.Errorf("json.Unmarshal: %w", err) + } + events = append(events, &event) + } + + return gomatrixserverlib.TransactionID(transactionID), events, nil +} + +// CleanTransactionPDUs cleans up all associated events for a +// given transaction. This is done when the transaction was sent +// successfully. +func (d *Database) CleanTransactionPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, +) error { + return d.deleteQueueTransaction(ctx, nil, serverName, transactionID) +}