From dabb304d99f2e6c1e4984c732fdbb2c8b599148f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 20 Jul 2020 14:12:20 +0100 Subject: [PATCH] Deduplicate FS database, add some EDU persistence groundwork --- federationsender/queue/queue.go | 2 +- federationsender/storage/interface.go | 2 +- .../storage/postgres/joined_hosts_table.go | 26 +- .../storage/postgres/queue_edus_table.go | 179 +++++++++++ .../storage/postgres/queue_json_table.go | 20 +- .../storage/postgres/queue_pdus_table.go | 62 ++-- .../storage/postgres/room_table.go | 21 +- federationsender/storage/postgres/storage.go | 278 +++-------------- federationsender/storage/shared/storage.go | 231 ++++++++++++++ .../storage/sqlite3/joined_hosts_table.go | 16 +- .../storage/sqlite3/queue_edus_table.go | 179 +++++++++++ .../storage/sqlite3/queue_json_table.go | 12 +- .../storage/sqlite3/queue_pdus_table.go | 28 +- .../storage/sqlite3/room_table.go | 12 +- federationsender/storage/sqlite3/storage.go | 282 ++---------------- federationsender/storage/tables/interface.go | 47 +++ 16 files changed, 818 insertions(+), 579 deletions(-) create mode 100644 federationsender/storage/postgres/queue_edus_table.go create mode 100644 federationsender/storage/shared/storage.go create mode 100644 federationsender/storage/sqlite3/queue_edus_table.go create mode 100644 federationsender/storage/tables/interface.go diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 46c9fddbe..812267e63 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -61,7 +61,7 @@ func NewOutgoingQueues( queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, } // Look up which servers we have pending items for and then rehydrate those queues. - if serverNames, err := db.GetPendingServerNames(context.Background()); err == nil { + if serverNames, err := db.GetPendingPDUServerNames(context.Background()); err == nil { for _, serverName := range serverNames { queues.getQueue(serverName).wakeQueueIfNeeded() } diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index 6fff35186..a24158033 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -32,5 +32,5 @@ type Database interface { GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) CleanTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID) error GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - GetPendingServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) + GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) } diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index 2612e7e08..af0a52581 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -61,33 +61,37 @@ const selectAllJoinedHostsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" type joinedHostsStatements struct { + db *sql.DB insertJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt } -func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(joinedHostsSchema) +func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { + s = &joinedHostsStatements{ + db: db, + } + _, err = s.db.Exec(joinedHostsSchema) if err != nil { return } - if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil { + if s.insertJoinedHostsStmt, err = s.db.Prepare(insertJoinedHostsSQL); err != nil { return } - if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil { + if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil { return } - if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { + if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil { return } - if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { + if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil { return } return } -func (s *joinedHostsStatements) insertJoinedHosts( +func (s *joinedHostsStatements) InsertJoinedHosts( ctx context.Context, txn *sql.Tx, roomID, eventID string, @@ -98,7 +102,7 @@ func (s *joinedHostsStatements) insertJoinedHosts( return err } -func (s *joinedHostsStatements) deleteJoinedHosts( +func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) @@ -106,20 +110,20 @@ func (s *joinedHostsStatements) deleteJoinedHosts( return err } -func (s *joinedHostsStatements) selectJoinedHostsWithTx( +func (s *joinedHostsStatements) SelectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) return joinedHostsFromStmt(ctx, stmt, roomID) } -func (s *joinedHostsStatements) selectJoinedHosts( +func (s *joinedHostsStatements) SelectJoinedHosts( ctx context.Context, roomID string, ) ([]types.JoinedHost, error) { return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } -func (s *joinedHostsStatements) selectAllJoinedHosts( +func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, ) ([]gomatrixserverlib.ServerName, error) { rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) diff --git a/federationsender/storage/postgres/queue_edus_table.go b/federationsender/storage/postgres/queue_edus_table.go new file mode 100644 index 000000000..b531907e5 --- /dev/null +++ b/federationsender/storage/postgres/queue_edus_table.go @@ -0,0 +1,179 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const queueEDUsSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( + -- The type of the event (informational). + edu_type TEXT NOT NULL, + -- The domain part of the user ID the EDU event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the federationsender_queue_edus_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx + ON federationsender_queue_edus (json_nid, server_name); +` + +const insertQueueEDUSQL = "" + + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const selectQueueEDUSQL = "" + + "SELECT json_nid FROM federationsender_queue_edus" + + " WHERE server_name = $1" + +const selectQueueEDUReferenceJSONCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE json_nid = $1" + +const selectQueueEDUCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE server_name = $1" + +const selectQueueServerNamesSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_queue_edus" + +type queueEDUsStatements struct { + db *sql.DB + insertQueueEDUStmt *sql.Stmt + selectQueueEDUStmt *sql.Stmt + selectQueueEDUReferenceJSONCountStmt *sql.Stmt + selectQueueEDUCountStmt *sql.Stmt + selectQueueEDUServerNamesStmt *sql.Stmt +} + +func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { + s = &queueEDUsStatements{ + db: db, + } + _, err = s.db.Exec(queueEDUsSchema) + if err != nil { + return + } + if s.insertQueueEDUStmt, err = s.db.Prepare(insertQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUStmt, err = s.db.Prepare(selectQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { + return + } + if s.selectQueueEDUCountStmt, err = s.db.Prepare(selectQueueEDUCountSQL); err != nil { + return + } + if s.selectQueueEDUServerNamesStmt, err = s.db.Prepare(selectQueueServerNamesSQL); err != nil { + return + } + return +} + +func (s *queueEDUsStatements) InsertQueueEDU( + ctx context.Context, + txn *sql.Tx, + userID, deviceID string, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) + _, err := stmt.ExecContext( + ctx, + userID, // destination user ID + deviceID, // destination device ID + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *queueEDUsStatements) SelectQueueEDU( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, nil +} + +func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( + ctx context.Context, txn *sql.Tx, jsonNID int64, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt) + err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + if err == sql.ErrNoRows { + return -1, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUServerNames( + ctx context.Context, txn *sql.Tx, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName gomatrixserverlib.ServerName + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, serverName) + } + + return result, rows.Err() +} diff --git a/federationsender/storage/postgres/queue_json_table.go b/federationsender/storage/postgres/queue_json_table.go index eac2ea988..853073741 100644 --- a/federationsender/storage/postgres/queue_json_table.go +++ b/federationsender/storage/postgres/queue_json_table.go @@ -48,29 +48,33 @@ const selectJSONSQL = "" + " WHERE json_nid = ANY($1)" type queueJSONStatements struct { + db *sql.DB insertJSONStmt *sql.Stmt deleteJSONStmt *sql.Stmt selectJSONStmt *sql.Stmt } -func (s *queueJSONStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(queueJSONSchema) +func NewPostgresQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { + s = &queueJSONStatements{ + db: db, + } + _, err = s.db.Exec(queueJSONSchema) if err != nil { return } - if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil { + if s.insertJSONStmt, err = s.db.Prepare(insertJSONSQL); err != nil { return } - if s.deleteJSONStmt, err = db.Prepare(deleteJSONSQL); err != nil { + if s.deleteJSONStmt, err = s.db.Prepare(deleteJSONSQL); err != nil { return } - if s.selectJSONStmt, err = db.Prepare(selectJSONSQL); err != nil { + if s.selectJSONStmt, err = s.db.Prepare(selectJSONSQL); err != nil { return } return } -func (s *queueJSONStatements) insertQueueJSON( +func (s *queueJSONStatements) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, ) (int64, error) { stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) @@ -81,7 +85,7 @@ func (s *queueJSONStatements) insertQueueJSON( return lastid, nil } -func (s *queueJSONStatements) deleteQueueJSON( +func (s *queueJSONStatements) DeleteQueueJSON( ctx context.Context, txn *sql.Tx, nids []int64, ) error { stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) @@ -89,7 +93,7 @@ func (s *queueJSONStatements) deleteQueueJSON( return err } -func (s *queueJSONStatements) selectQueueJSON( +func (s *queueJSONStatements) SelectQueueJSON( ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ) (map[int64][]byte, error) { blobs := map[int64][]byte{} diff --git a/federationsender/storage/postgres/queue_pdus_table.go b/federationsender/storage/postgres/queue_pdus_table.go index dab6003e9..1740487ee 100644 --- a/federationsender/storage/postgres/queue_pdus_table.go +++ b/federationsender/storage/postgres/queue_pdus_table.go @@ -44,7 +44,7 @@ const insertQueuePDUSQL = "" + const deleteQueueTransactionPDUsSQL = "" + "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND transaction_id = $2" -const selectQueueNextTransactionIDSQL = "" + +const selectQueuePDUNextTransactionIDSQL = "" + "SELECT transaction_id FROM federationsender_queue_pdus" + " WHERE server_name = $1" + " ORDER BY transaction_id ASC" + @@ -55,7 +55,7 @@ const selectQueuePDUsByTransactionSQL = "" + " WHERE server_name = $1 AND transaction_id = $2" + " LIMIT $3" -const selectQueueReferenceJSONCountSQL = "" + +const selectQueuePDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" @@ -63,49 +63,53 @@ const selectQueuePDUsCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE server_name = $1" -const selectQueueServerNamesSQL = "" + +const selectQueuePDUServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" type queuePDUsStatements struct { - insertQueuePDUStmt *sql.Stmt - deleteQueueTransactionPDUsStmt *sql.Stmt - selectQueueNextTransactionIDStmt *sql.Stmt - selectQueuePDUsByTransactionStmt *sql.Stmt - selectQueueReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt - selectQueueServerNamesStmt *sql.Stmt + db *sql.DB + insertQueuePDUStmt *sql.Stmt + deleteQueueTransactionPDUsStmt *sql.Stmt + selectQueuePDUNextTransactionIDStmt *sql.Stmt + selectQueuePDUsByTransactionStmt *sql.Stmt + selectQueuePDUReferenceJSONCountStmt *sql.Stmt + selectQueuePDUsCountStmt *sql.Stmt + selectQueuePDUServerNamesStmt *sql.Stmt } -func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(queuePDUsSchema) +func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { + s = &queuePDUsStatements{ + db: db, + } + _, err = s.db.Exec(queuePDUsSchema) if err != nil { return } - if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil { + if s.insertQueuePDUStmt, err = s.db.Prepare(insertQueuePDUSQL); err != nil { return } - if s.deleteQueueTransactionPDUsStmt, err = db.Prepare(deleteQueueTransactionPDUsSQL); err != nil { + if s.deleteQueueTransactionPDUsStmt, err = s.db.Prepare(deleteQueueTransactionPDUsSQL); err != nil { return } - if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { + if s.selectQueuePDUNextTransactionIDStmt, err = s.db.Prepare(selectQueuePDUNextTransactionIDSQL); err != nil { return } - if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { + if s.selectQueuePDUsByTransactionStmt, err = s.db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { return } - if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil { + if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { + if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil { return } - if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil { return } return } -func (s *queuePDUsStatements) insertQueuePDU( +func (s *queuePDUsStatements) InsertQueuePDU( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, @@ -122,7 +126,7 @@ func (s *queuePDUsStatements) insertQueuePDU( return err } -func (s *queuePDUsStatements) deleteQueueTransaction( +func (s *queuePDUsStatements) DeleteQueuePDUTransaction( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, @@ -132,11 +136,11 @@ func (s *queuePDUsStatements) deleteQueueTransaction( return err } -func (s *queuePDUsStatements) selectQueueNextTransactionID( +func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (gomatrixserverlib.TransactionID, error) { var transactionID gomatrixserverlib.TransactionID - stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUNextTransactionIDStmt) err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) if err == sql.ErrNoRows { return "", nil @@ -144,11 +148,11 @@ func (s *queuePDUsStatements) selectQueueNextTransactionID( return transactionID, err } -func (s *queuePDUsStatements) selectQueueReferenceJSONCount( +func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( ctx context.Context, txn *sql.Tx, jsonNID int64, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUReferenceJSONCountStmt) err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) if err == sql.ErrNoRows { // It's acceptable for there to be no rows referencing a given @@ -159,7 +163,7 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUCount( +func (s *queuePDUsStatements) SelectQueuePDUCount( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 @@ -174,7 +178,7 @@ func (s *queuePDUsStatements) selectQueuePDUCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUs( +func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, @@ -198,10 +202,10 @@ func (s *queuePDUsStatements) selectQueuePDUs( return result, rows.Err() } -func (s *queuePDUsStatements) selectQueueServerNames( +func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { - stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err diff --git a/federationsender/storage/postgres/room_table.go b/federationsender/storage/postgres/room_table.go index e5266c635..8d3ed20ff 100644 --- a/federationsender/storage/postgres/room_table.go +++ b/federationsender/storage/postgres/room_table.go @@ -43,24 +43,27 @@ const updateRoomSQL = "" + "UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1" type roomStatements struct { + db *sql.DB insertRoomStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt updateRoomStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(roomSchema) +func NewPostgresRoomsTable(db *sql.DB) (s *roomStatements, err error) { + s = &roomStatements{ + db: db, + } + _, err = s.db.Exec(roomSchema) if err != nil { return } - - if s.insertRoomStmt, err = db.Prepare(insertRoomSQL); err != nil { + if s.insertRoomStmt, err = s.db.Prepare(insertRoomSQL); err != nil { return } - if s.selectRoomForUpdateStmt, err = db.Prepare(selectRoomForUpdateSQL); err != nil { + if s.selectRoomForUpdateStmt, err = s.db.Prepare(selectRoomForUpdateSQL); err != nil { return } - if s.updateRoomStmt, err = db.Prepare(updateRoomSQL); err != nil { + if s.updateRoomStmt, err = s.db.Prepare(updateRoomSQL); err != nil { return } return @@ -68,7 +71,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { // insertRoom inserts the room if it didn't already exist. // If the room didn't exist then last_event_id is set to the empty string. -func (s *roomStatements) insertRoom( +func (s *roomStatements) InsertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) @@ -78,7 +81,7 @@ func (s *roomStatements) insertRoom( // selectRoomForUpdate locks the row for the room and returns the last_event_id. // The row must already exist in the table. Callers can ensure that the row // exists by calling insertRoom first. -func (s *roomStatements) selectRoomForUpdate( +func (s *roomStatements) SelectRoomForUpdate( ctx context.Context, txn *sql.Tx, roomID string, ) (string, error) { var lastEventID string @@ -92,7 +95,7 @@ func (s *roomStatements) selectRoomForUpdate( // updateRoom updates the last_event_id for the room. selectRoomForUpdate should // have already been called earlier within the transaction. -func (s *roomStatements) updateRoom( +func (s *roomStatements) UpdateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index 1535ebdf1..66388bfe4 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -16,266 +16,56 @@ package postgres import ( - "context" "database/sql" - "encoding/json" - "fmt" - "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" ) // Database stores information needed by the federation sender type Database struct { - joinedHostsStatements - roomStatements - queuePDUsStatements - queueJSONStatements + shared.Database sqlutil.PartitionOffsetStatements db *sql.DB } // NewDatabase opens a new database func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties) (*Database, error) { - var result Database + var d Database var err error - if result.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { + if d.db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { return nil, err } - if err = result.prepare(); err != nil { - return nil, err - } - return &result, nil -} - -func (d *Database) prepare() error { - var err error - - if err = d.joinedHostsStatements.prepare(d.db); err != nil { - return err - } - - if err = d.roomStatements.prepare(d.db); err != nil { - return err - } - - if err = d.queuePDUsStatements.prepare(d.db); err != nil { - return err - } - - if err = d.queueJSONStatements.prepare(d.db); err != nil { - return err - } - - return d.PartitionOffsetStatements.Prepare(d.db, "federationsender") -} - -// UpdateRoom updates the joined hosts for a room and returns what the joined -// hosts were before the update, or nil if this was a duplicate message. -// This is called when we receive a message from kafka, so we pass in -// oldEventID and newEventID to check that we haven't missed any messages or -// this isn't a duplicate message. -func (d *Database) UpdateRoom( - ctx context.Context, - roomID, oldEventID, newEventID string, - addHosts []types.JoinedHost, - removeHosts []string, -) (joinedHosts []types.JoinedHost, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - err = d.insertRoom(ctx, txn, roomID) - if err != nil { - return err - } - - lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID) - if err != nil { - return err - } - - if lastSentEventID == newEventID { - // We've handled this message before, so let's just ignore it. - // We can only get a duplicate for the last message we processed, - // so its enough just to compare the newEventID with lastSentEventID - return nil - } - - if lastSentEventID != "" && lastSentEventID != oldEventID { - return types.EventIDMismatchError{ - DatabaseID: lastSentEventID, RoomServerID: oldEventID, - } - } - - joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID) - if err != nil { - return err - } - - for _, add := range addHosts { - err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName) - if err != nil { - return err - } - } - if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil { - return err - } - return d.updateRoom(ctx, txn, roomID, newEventID) - }) - return -} - -// GetJoinedHosts returns the currently joined hosts for room, -// as known to federationserver. -// Returns an error if something goes wrong. -func (d *Database) GetJoinedHosts( - ctx context.Context, roomID string, -) ([]types.JoinedHost, error) { - return d.selectJoinedHosts(ctx, roomID) -} - -// GetAllJoinedHosts returns the currently joined hosts for -// all rooms known to the federation sender. -// Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - return d.selectAllJoinedHosts(ctx) -} - -// StoreJSON adds a JSON blob into the queue JSON table and returns -// a NID. The NID will then be used when inserting the per-destination -// metadata entries. -func (d *Database) StoreJSON( - ctx context.Context, js string, -) (int64, error) { - nid, err := d.insertQueueJSON(ctx, nil, js) + joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { - return 0, fmt.Errorf("d.insertQueueJSON: %w", err) + return nil, err } - return nid, nil -} - -// AssociatePDUWithDestination creates an association that the -// destination queues will use to determine which JSON blobs to send -// to which servers. -func (d *Database) AssociatePDUWithDestination( - ctx context.Context, - transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, - nids []int64, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - for _, nid := range nids { - if err := d.insertQueuePDU( - ctx, // context - txn, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - nid, // NID from the federationsender_queue_json table - ); err != nil { - return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) - } - } - return nil - }) -} - -// GetNextTransactionPDUs retrieves events from the database for -// the next pending transaction, up to the limit specified. -func (d *Database) GetNextTransactionPDUs( - ctx context.Context, - serverName gomatrixserverlib.ServerName, - limit int, -) ( - transactionID gomatrixserverlib.TransactionID, - events []*gomatrixserverlib.HeaderedEvent, - err error, -) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - transactionID, err = d.selectQueueNextTransactionID(ctx, txn, serverName) - if err != nil { - return fmt.Errorf("d.selectQueueNextTransactionID: %w", err) - } - - if transactionID == "" { - return nil - } - - nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, limit) - if err != nil { - return fmt.Errorf("d.selectQueuePDUs: %w", err) - } - - blobs, err := d.selectQueueJSON(ctx, txn, nids) - if err != nil { - return fmt.Errorf("d.selectJSON: %w", err) - } - - for _, blob := range blobs { - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal(blob, &event); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - events = append(events, &event) - } - - return nil - }) - return -} - -// CleanTransactionPDUs cleans up all associated events for a -// given transaction. This is done when the transaction was sent -// successfully. -func (d *Database) CleanTransactionPDUs( - ctx context.Context, - serverName gomatrixserverlib.ServerName, - transactionID gomatrixserverlib.TransactionID, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50) - if err != nil { - return fmt.Errorf("d.selectQueuePDUs: %w", err) - } - - if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil { - return fmt.Errorf("d.deleteQueueTransaction: %w", err) - } - - var count int64 - var deleteNIDs []int64 - for _, nid := range nids { - count, err = d.selectQueueReferenceJSONCount(ctx, txn, nid) - if err != nil { - return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) - } - if count == 0 { - deleteNIDs = append(deleteNIDs, nid) - } - } - - if len(deleteNIDs) > 0 { - if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil { - return fmt.Errorf("d.deleteQueueJSON: %w", err) - } - } - - return nil - }) -} - -// GetPendingPDUCount returns the number of PDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingPDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.selectQueuePDUCount(ctx, nil, serverName) -} - -// GetPendingServerNames returns the server names that have PDUs -// waiting to be sent. -func (d *Database) GetPendingServerNames( - ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { - return d.selectQueueServerNames(ctx, nil) + queuePDUs, err := NewPostgresQueuePDUsTable(d.db) + if err != nil { + return nil, err + } + queueEDUs, err := NewPostgresQueueEDUsTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewPostgresQueueJSONTable(d.db) + if err != nil { + return nil, err + } + rooms, err := NewPostgresRoomsTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + FederationSenderJoinedHosts: joinedHosts, + FederationSenderQueuePDUs: queuePDUs, + FederationSenderQueueEDUs: queueEDUs, + FederationSenderQueueJSON: queueJSON, + FederationSenderRooms: rooms, + } + if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { + return nil, err + } + return &d, nil } diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go new file mode 100644 index 000000000..e5ac3876b --- /dev/null +++ b/federationsender/storage/shared/storage.go @@ -0,0 +1,231 @@ +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/federationsender/storage/tables" + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + FederationSenderQueuePDUs tables.FederationSenderQueuePDUs + FederationSenderQueueEDUs tables.FederationSenderQueueEDUs + FederationSenderQueueJSON tables.FederationSenderQueueJSON + FederationSenderJoinedHosts tables.FederationSenderJoinedHosts + FederationSenderRooms tables.FederationSenderRooms +} + +// UpdateRoom updates the joined hosts for a room and returns what the joined +// hosts were before the update, or nil if this was a duplicate message. +// This is called when we receive a message from kafka, so we pass in +// oldEventID and newEventID to check that we haven't missed any messages or +// this isn't a duplicate message. +func (d *Database) UpdateRoom( + ctx context.Context, + roomID, oldEventID, newEventID string, + addHosts []types.JoinedHost, + removeHosts []string, +) (joinedHosts []types.JoinedHost, err error) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.FederationSenderRooms.InsertRoom(ctx, txn, roomID) + if err != nil { + return err + } + + lastSentEventID, err := d.FederationSenderRooms.SelectRoomForUpdate(ctx, txn, roomID) + if err != nil { + return err + } + + if lastSentEventID == newEventID { + // We've handled this message before, so let's just ignore it. + // We can only get a duplicate for the last message we processed, + // so its enough just to compare the newEventID with lastSentEventID + return nil + } + + if lastSentEventID != "" && lastSentEventID != oldEventID { + return types.EventIDMismatchError{ + DatabaseID: lastSentEventID, RoomServerID: oldEventID, + } + } + + joinedHosts, err = d.FederationSenderJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID) + if err != nil { + return err + } + + for _, add := range addHosts { + err = d.FederationSenderJoinedHosts.InsertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName) + if err != nil { + return err + } + } + if err = d.FederationSenderJoinedHosts.DeleteJoinedHosts(ctx, txn, removeHosts); err != nil { + return err + } + return d.FederationSenderRooms.UpdateRoom(ctx, txn, roomID, newEventID) + }) + return +} + +// GetJoinedHosts returns the currently joined hosts for room, +// as known to federationserver. +// Returns an error if something goes wrong. +func (d *Database) GetJoinedHosts( + ctx context.Context, roomID string, +) ([]types.JoinedHost, error) { + return d.FederationSenderJoinedHosts.SelectJoinedHosts(ctx, roomID) +} + +// GetAllJoinedHosts returns the currently joined hosts for +// all rooms known to the federation sender. +// Returns an error if something goes wrong. +func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return d.FederationSenderJoinedHosts.SelectAllJoinedHosts(ctx) +} + +// StoreJSON adds a JSON blob into the queue JSON table and returns +// a NID. The NID will then be used when inserting the per-destination +// metadata entries. +func (d *Database) StoreJSON( + ctx context.Context, js string, +) (int64, error) { + nid, err := d.FederationSenderQueueJSON.InsertQueueJSON(ctx, nil, js) + if err != nil { + return 0, fmt.Errorf("d.insertQueueJSON: %w", err) + } + return nid, nil +} + +// AssociatePDUWithDestination creates an association that the +// destination queues will use to determine which JSON blobs to send +// to which servers. +func (d *Database) AssociatePDUWithDestination( + ctx context.Context, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nids []int64, +) error { + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for _, nid := range nids { + if err := d.FederationSenderQueuePDUs.InsertQueuePDU( + ctx, // context + txn, // SQL transaction + transactionID, // transaction ID + serverName, // destination server name + nid, // NID from the federationsender_queue_json table + ); err != nil { + return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) + } + } + return nil + }) +} + +// GetNextTransactionPDUs retrieves events from the database for +// the next pending transaction, up to the limit specified. +func (d *Database) GetNextTransactionPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) ( + transactionID gomatrixserverlib.TransactionID, + events []*gomatrixserverlib.HeaderedEvent, + err error, +) { + err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + transactionID, err = d.FederationSenderQueuePDUs.SelectQueuePDUNextTransactionID(ctx, txn, serverName) + if err != nil { + return fmt.Errorf("d.selectQueueNextTransactionID: %w", err) + } + + if transactionID == "" { + return nil + } + + nids, err := d.FederationSenderQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, transactionID, limit) + if err != nil { + return fmt.Errorf("d.selectQueuePDUs: %w", err) + } + + blobs, err := d.FederationSenderQueueJSON.SelectQueueJSON(ctx, txn, nids) + if err != nil { + return fmt.Errorf("d.selectJSON: %w", err) + } + + for _, blob := range blobs { + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal(blob, &event); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + events = append(events, &event) + } + + return nil + }) + return +} + +// CleanTransactionPDUs cleans up all associated events for a +// given transaction. This is done when the transaction was sent +// successfully. +func (d *Database) CleanTransactionPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, +) error { + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + nids, err := d.FederationSenderQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, transactionID, 50) + if err != nil { + return fmt.Errorf("d.selectQueuePDUs: %w", err) + } + + if err = d.FederationSenderQueuePDUs.DeleteQueuePDUTransaction(ctx, txn, serverName, transactionID); err != nil { + return fmt.Errorf("d.deleteQueueTransaction: %w", err) + } + + var count int64 + var deleteNIDs []int64 + for _, nid := range nids { + count, err = d.FederationSenderQueuePDUs.SelectQueuePDUReferenceJSONCount(ctx, txn, nid) + if err != nil { + return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) + } + if count == 0 { + deleteNIDs = append(deleteNIDs, nid) + } + } + + if len(deleteNIDs) > 0 { + if err = d.FederationSenderQueueJSON.DeleteQueueJSON(ctx, txn, deleteNIDs); err != nil { + return fmt.Errorf("d.deleteQueueJSON: %w", err) + } + } + + return nil + }) +} + +// GetPendingPDUCount returns the number of PDUs waiting to be +// sent for a given servername. +func (d *Database) GetPendingPDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return d.FederationSenderQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName) +} + +// GetPendingServerNames returns the server names that have PDUs +// waiting to be sent. +func (d *Database) GetPendingPDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + return d.FederationSenderQueuePDUs.SelectQueuePDUServerNames(ctx, nil) +} diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index fd9ffedc1..4338e8182 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -60,13 +60,17 @@ const selectAllJoinedHostsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" type joinedHostsStatements struct { + db *sql.DB insertJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt } -func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { +func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { + s = &joinedHostsStatements{ + db: db, + } _, err = db.Exec(joinedHostsSchema) if err != nil { return @@ -86,7 +90,7 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { return } -func (s *joinedHostsStatements) insertJoinedHosts( +func (s *joinedHostsStatements) InsertJoinedHosts( ctx context.Context, txn *sql.Tx, roomID, eventID string, @@ -97,7 +101,7 @@ func (s *joinedHostsStatements) insertJoinedHosts( return err } -func (s *joinedHostsStatements) deleteJoinedHosts( +func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { for _, eventID := range eventIDs { @@ -109,20 +113,20 @@ func (s *joinedHostsStatements) deleteJoinedHosts( return nil } -func (s *joinedHostsStatements) selectJoinedHostsWithTx( +func (s *joinedHostsStatements) SelectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) return joinedHostsFromStmt(ctx, stmt, roomID) } -func (s *joinedHostsStatements) selectJoinedHosts( +func (s *joinedHostsStatements) SelectJoinedHosts( ctx context.Context, roomID string, ) ([]types.JoinedHost, error) { return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } -func (s *joinedHostsStatements) selectAllJoinedHosts( +func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, ) ([]gomatrixserverlib.ServerName, error) { rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) diff --git a/federationsender/storage/sqlite3/queue_edus_table.go b/federationsender/storage/sqlite3/queue_edus_table.go new file mode 100644 index 000000000..46b44c047 --- /dev/null +++ b/federationsender/storage/sqlite3/queue_edus_table.go @@ -0,0 +1,179 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const queueEDUsSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( + -- The type of the event (informational). + edu_type TEXT NOT NULL, + -- The domain part of the user ID the EDU event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the federationsender_queue_edus_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx + ON federationsender_queue_edus (json_nid, server_name); +` + +const insertQueueEDUSQL = "" + + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const selectQueueEDUSQL = "" + + "SELECT json_nid FROM federationsender_queue_edus" + + " WHERE server_name = $1" + +const selectQueueEDUReferenceJSONCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE json_nid = $1" + +const selectQueueEDUCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_edus" + + " WHERE server_name = $1" + +const selectQueueServerNamesSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_queue_edus" + +type queueEDUsStatements struct { + db *sql.DB + insertQueueEDUStmt *sql.Stmt + selectQueueEDUStmt *sql.Stmt + selectQueueEDUReferenceJSONCountStmt *sql.Stmt + selectQueueEDUCountStmt *sql.Stmt + selectQueueEDUServerNamesStmt *sql.Stmt +} + +func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { + s = &queueEDUsStatements{ + db: db, + } + _, err = db.Exec(queueEDUsSchema) + if err != nil { + return + } + if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil { + return + } + if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { + return + } + if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil { + return + } + if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + return + } + return +} + +func (s *queueEDUsStatements) InsertQueueEDU( + ctx context.Context, + txn *sql.Tx, + userID, deviceID string, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) + _, err := stmt.ExecContext( + ctx, + userID, // destination user ID + deviceID, // destination device ID + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *queueEDUsStatements) SelectQueueEDU( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, nil +} + +func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( + ctx context.Context, txn *sql.Tx, jsonNID int64, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt) + err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + if err == sql.ErrNoRows { + return -1, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} + +func (s *queueEDUsStatements) SelectQueueEDUServerNames( + ctx context.Context, txn *sql.Tx, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName gomatrixserverlib.ServerName + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, serverName) + } + + return result, rows.Err() +} diff --git a/federationsender/storage/sqlite3/queue_json_table.go b/federationsender/storage/sqlite3/queue_json_table.go index 01b7160db..95e6cd206 100644 --- a/federationsender/storage/sqlite3/queue_json_table.go +++ b/federationsender/storage/sqlite3/queue_json_table.go @@ -49,12 +49,16 @@ const selectJSONSQL = "" + " WHERE json_nid IN ($1)" type queueJSONStatements struct { + db *sql.DB insertJSONStmt *sql.Stmt //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic } -func (s *queueJSONStatements) prepare(db *sql.DB) (err error) { +func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { + s = &queueJSONStatements{ + db: db, + } _, err = db.Exec(queueJSONSchema) if err != nil { return @@ -65,7 +69,7 @@ func (s *queueJSONStatements) prepare(db *sql.DB) (err error) { return } -func (s *queueJSONStatements) insertQueueJSON( +func (s *queueJSONStatements) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, ) (int64, error) { stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) @@ -80,7 +84,7 @@ func (s *queueJSONStatements) insertQueueJSON( return lastid, nil } -func (s *queueJSONStatements) deleteQueueJSON( +func (s *queueJSONStatements) DeleteQueueJSON( ctx context.Context, txn *sql.Tx, nids []int64, ) error { deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) @@ -99,7 +103,7 @@ func (s *queueJSONStatements) deleteQueueJSON( return err } -func (s *queueJSONStatements) selectQueueJSON( +func (s *queueJSONStatements) SelectQueueJSON( ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ) (map[int64][]byte, error) { selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go index 33eef91ed..de278c4ef 100644 --- a/federationsender/storage/sqlite3/queue_pdus_table.go +++ b/federationsender/storage/sqlite3/queue_pdus_table.go @@ -56,7 +56,7 @@ const selectQueuePDUsByTransactionSQL = "" + " WHERE server_name = $1 AND transaction_id = $2" + " LIMIT $3" -const selectQueueReferenceJSONCountSQL = "" + +const selectQueuePDUsReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" @@ -64,10 +64,11 @@ const selectQueuePDUsCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE server_name = $1" -const selectQueueServerNamesSQL = "" + +const selectQueuePDUsServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" type queuePDUsStatements struct { + db *sql.DB insertQueuePDUStmt *sql.Stmt deleteQueueTransactionPDUsStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt @@ -77,7 +78,10 @@ type queuePDUsStatements struct { selectQueueServerNamesStmt *sql.Stmt } -func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { +func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { + s = &queuePDUsStatements{ + db: db, + } _, err = db.Exec(queuePDUsSchema) if err != nil { return @@ -94,19 +98,19 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { return } - if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil { + if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { return } if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { return } - if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { return } return } -func (s *queuePDUsStatements) insertQueuePDU( +func (s *queuePDUsStatements) InsertQueuePDU( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, @@ -123,7 +127,7 @@ func (s *queuePDUsStatements) insertQueuePDU( return err } -func (s *queuePDUsStatements) deleteQueueTransaction( +func (s *queuePDUsStatements) DeleteQueuePDUTransaction( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, @@ -133,7 +137,7 @@ func (s *queuePDUsStatements) deleteQueueTransaction( return err } -func (s *queuePDUsStatements) selectQueueNextTransactionID( +func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (gomatrixserverlib.TransactionID, error) { var transactionID gomatrixserverlib.TransactionID @@ -145,7 +149,7 @@ func (s *queuePDUsStatements) selectQueueNextTransactionID( return transactionID, err } -func (s *queuePDUsStatements) selectQueueReferenceJSONCount( +func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( ctx context.Context, txn *sql.Tx, jsonNID int64, ) (int64, error) { var count int64 @@ -157,7 +161,7 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUCount( +func (s *queuePDUsStatements) SelectQueuePDUCount( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 @@ -172,7 +176,7 @@ func (s *queuePDUsStatements) selectQueuePDUCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUs( +func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, @@ -196,7 +200,7 @@ func (s *queuePDUsStatements) selectQueuePDUs( return result, rows.Err() } -func (s *queuePDUsStatements) selectQueueServerNames( +func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go index ca0c4d0b6..0710ccca3 100644 --- a/federationsender/storage/sqlite3/room_table.go +++ b/federationsender/storage/sqlite3/room_table.go @@ -43,12 +43,16 @@ const updateRoomSQL = "" + "UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1" type roomStatements struct { + db *sql.DB insertRoomStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt updateRoomStmt *sql.Stmt } -func (s *roomStatements) prepare(db *sql.DB) (err error) { +func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) { + s = &roomStatements{ + db: db, + } _, err = db.Exec(roomSchema) if err != nil { return @@ -68,7 +72,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { // insertRoom inserts the room if it didn't already exist. // If the room didn't exist then last_event_id is set to the empty string. -func (s *roomStatements) insertRoom( +func (s *roomStatements) InsertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) @@ -78,7 +82,7 @@ func (s *roomStatements) insertRoom( // selectRoomForUpdate locks the row for the room and returns the last_event_id. // The row must already exist in the table. Callers can ensure that the row // exists by calling insertRoom first. -func (s *roomStatements) selectRoomForUpdate( +func (s *roomStatements) SelectRoomForUpdate( ctx context.Context, txn *sql.Tx, roomID string, ) (string, error) { var lastEventID string @@ -92,7 +96,7 @@ func (s *roomStatements) selectRoomForUpdate( // updateRoom updates the last_event_id for the room. selectRoomForUpdate should // have already been called earlier within the transaction. -func (s *roomStatements) updateRoom( +func (s *roomStatements) UpdateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index b23a2dbe6..a24b6c354 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -16,283 +16,65 @@ package sqlite3 import ( - "context" "database/sql" - "encoding/json" - "fmt" _ "github.com/mattn/go-sqlite3" - "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" ) // Database stores information needed by the federation sender type Database struct { - joinedHostsStatements - roomStatements - queuePDUsStatements - queueJSONStatements + shared.Database sqlutil.PartitionOffsetStatements db *sql.DB queuePDUsWriter *sqlutil.TransactionWriter + queueEDUsWriter *sqlutil.TransactionWriter queueJSONWriter *sqlutil.TransactionWriter } // NewDatabase opens a new database func NewDatabase(dataSourceName string) (*Database, error) { - var result Database + var d Database var err error cs, err := sqlutil.ParseFileURI(dataSourceName) if err != nil { return nil, err } - if result.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { + if d.db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { return nil, err } - if err = result.prepare(); err != nil { - return nil, err - } - return &result, nil -} - -func (d *Database) prepare() error { - var err error - - if err = d.joinedHostsStatements.prepare(d.db); err != nil { - return err - } - - if err = d.roomStatements.prepare(d.db); err != nil { - return err - } - - if err = d.queuePDUsStatements.prepare(d.db); err != nil { - return err - } - - if err = d.queueJSONStatements.prepare(d.db); err != nil { - return err - } - - d.queuePDUsWriter = sqlutil.NewTransactionWriter() - d.queueJSONWriter = sqlutil.NewTransactionWriter() - - return d.PartitionOffsetStatements.Prepare(d.db, "federationsender") -} - -// UpdateRoom updates the joined hosts for a room and returns what the joined -// hosts were before the update, or nil if this was a duplicate message. -// This is called when we receive a message from kafka, so we pass in -// oldEventID and newEventID to check that we haven't missed any messages or -// this isn't a duplicate message. -func (d *Database) UpdateRoom( - ctx context.Context, - roomID, oldEventID, newEventID string, - addHosts []types.JoinedHost, - removeHosts []string, -) (joinedHosts []types.JoinedHost, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - err = d.insertRoom(ctx, txn, roomID) - if err != nil { - return err - } - - lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID) - if err != nil { - return err - } - - if lastSentEventID == newEventID { - // We've handled this message before, so let's just ignore it. - // We can only get a duplicate for the last message we processed, - // so its enough just to compare the newEventID with lastSentEventID - return nil - } - - if lastSentEventID != "" && lastSentEventID != oldEventID { - return types.EventIDMismatchError{ - DatabaseID: lastSentEventID, RoomServerID: oldEventID, - } - } - - joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID) - if err != nil { - return err - } - - for _, add := range addHosts { - err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName) - if err != nil { - return err - } - } - if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil { - return err - } - return d.updateRoom(ctx, txn, roomID, newEventID) - }) - return -} - -// GetJoinedHosts returns the currently joined hosts for room, -// as known to federationserver. -// Returns an error if something goes wrong. -func (d *Database) GetJoinedHosts( - ctx context.Context, roomID string, -) ([]types.JoinedHost, error) { - return d.selectJoinedHosts(ctx, roomID) -} - -// GetAllJoinedHosts returns the currently joined hosts for -// all rooms known to the federation sender. -// Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - return d.selectAllJoinedHosts(ctx) -} - -// StoreJSON adds a JSON blob into the queue JSON table and returns -// a NID. The NID will then be used when inserting the per-destination -// metadata entries. -func (d *Database) StoreJSON( - ctx context.Context, js string, -) (nid int64, err error) { - err = d.queueJSONWriter.Do(d.db, func(txn *sql.Tx) error { - n, e := d.insertQueueJSON(ctx, nil, js) - if e != nil { - return fmt.Errorf("d.insertQueueJSON: %w", e) - } - nid = n - return nil - }) - return -} - -// AssociatePDUWithDestination creates an association that the -// destination queues will use to determine which JSON blobs to send -// to which servers. -func (d *Database) AssociatePDUWithDestination( - ctx context.Context, - transactionID gomatrixserverlib.TransactionID, - serverName gomatrixserverlib.ServerName, - nids []int64, -) error { - return d.queuePDUsWriter.Do(d.db, func(txn *sql.Tx) error { - for _, nid := range nids { - if err := d.insertQueuePDU( - ctx, // context - txn, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - nid, // NID from the federationsender_queue_json table - ); err != nil { - return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) - } - } - return nil - }) -} - -// GetNextTransactionPDUs retrieves events from the database for -// the next pending transaction, up to the limit specified. -func (d *Database) GetNextTransactionPDUs( - ctx context.Context, - serverName gomatrixserverlib.ServerName, - limit int, -) ( - transactionID gomatrixserverlib.TransactionID, - events []*gomatrixserverlib.HeaderedEvent, - err error, -) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - transactionID, err = d.selectQueueNextTransactionID(ctx, txn, serverName) - if err != nil { - return fmt.Errorf("d.selectQueueNextTransactionID: %w", err) - } - - if transactionID == "" { - return nil - } - - nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, limit) - if err != nil { - return fmt.Errorf("d.selectQueuePDUs: %w", err) - } - - blobs, err := d.selectQueueJSON(ctx, txn, nids) - if err != nil { - return fmt.Errorf("d.selectJSON: %w", err) - } - - for _, blob := range blobs { - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal(blob, &event); err != nil { - return fmt.Errorf("json.Unmarshal: %w", err) - } - events = append(events, &event) - } - - return nil - }) - return -} - -// CleanTransactionPDUs cleans up all associated events for a -// given transaction. This is done when the transaction was sent -// successfully. -func (d *Database) CleanTransactionPDUs( - ctx context.Context, - serverName gomatrixserverlib.ServerName, - transactionID gomatrixserverlib.TransactionID, -) error { - var deleteNIDs []int64 - nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, 50) + joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { - return fmt.Errorf("d.selectQueuePDUs: %w", err) + return nil, err } - if err = d.queuePDUsWriter.Do(d.db, func(txn *sql.Tx) error { - if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil { - return fmt.Errorf("d.deleteQueueTransaction: %w", err) - } - return nil - }); err != nil { - return err + rooms, err := NewSQLiteRoomsTable(d.db) + if err != nil { + return nil, err } - var count int64 - for _, nid := range nids { - count, err = d.selectQueueReferenceJSONCount(ctx, nil, nid) - if err != nil { - return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) - } - if count == 0 { - deleteNIDs = append(deleteNIDs, nid) - } + queuePDUs, err := NewSQLiteQueuePDUsTable(d.db) + if err != nil { + return nil, err } - if len(deleteNIDs) > 0 { - err = d.queueJSONWriter.Do(d.db, func(txn *sql.Tx) error { - if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil { - return fmt.Errorf("d.deleteQueueJSON: %w", err) - } - return nil - }) + queueEDUs, err := NewSQLiteQueueEDUsTable(d.db) + if err != nil { + return nil, err } - return err -} - -// GetPendingPDUCount returns the number of PDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingPDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.selectQueuePDUCount(ctx, nil, serverName) -} - -// GetPendingServerNames returns the server names that have PDUs -// waiting to be sent. -func (d *Database) GetPendingServerNames( - ctx context.Context, -) ([]gomatrixserverlib.ServerName, error) { - return d.selectQueueServerNames(ctx, nil) + queueJSON, err := NewSQLiteQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + FederationSenderJoinedHosts: joinedHosts, + FederationSenderQueuePDUs: queuePDUs, + FederationSenderQueueEDUs: queueEDUs, + FederationSenderQueueJSON: queueJSON, + FederationSenderRooms: rooms, + } + if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { + return nil, err + } + return &d, nil } diff --git a/federationsender/storage/tables/interface.go b/federationsender/storage/tables/interface.go new file mode 100644 index 000000000..e4155f0c2 --- /dev/null +++ b/federationsender/storage/tables/interface.go @@ -0,0 +1,47 @@ +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type FederationSenderQueuePDUs interface { + InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueuePDUTransaction(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID) error + SelectQueuePDUNextTransactionID(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (gomatrixserverlib.TransactionID, error) + SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) + SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) + SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, limit int) ([]int64, error) + SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) +} + +type FederationSenderQueueEDUs interface { + InsertQueueEDU(ctx context.Context, txn *sql.Tx, userID, deviceID string, serverName gomatrixserverlib.ServerName, nid int64) error + SelectQueueEDU(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]int64, error) + SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) + SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) + SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) +} + +type FederationSenderQueueJSON interface { + InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} + +type FederationSenderJoinedHosts interface { + InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error + DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error + SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) + SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) + SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) +} + +type FederationSenderRooms interface { + InsertRoom(ctx context.Context, txn *sql.Tx, roomID string) error + SelectRoomForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (string, error) + UpdateRoom(ctx context.Context, txn *sql.Tx, roomID, lastEventID string) error +}