From b237f2d62d94d10dcb42e6a121f0a72b7fcc652e Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 22 Nov 2022 14:28:48 -0700 Subject: [PATCH] Add s&f database interactions --- .../postgres/queue_transactions_table.go | 155 ++++++++++++++++ .../postgres/transaction_json_table.go | 117 ++++++++++++ .../sqlite3/queue_transactions_table.go | 168 +++++++++++++++++ .../storage/sqlite3/transaction_json_table.go | 137 ++++++++++++++ federationapi/storage/tables/interface.go | 13 ++ .../tables/queue_transactions_table_test.go | 171 +++++++++++++++++ .../tables/transaction_json_table_test.go | 173 ++++++++++++++++++ 7 files changed, 934 insertions(+) create mode 100644 federationapi/storage/postgres/queue_transactions_table.go create mode 100644 federationapi/storage/postgres/transaction_json_table.go create mode 100644 federationapi/storage/sqlite3/queue_transactions_table.go create mode 100644 federationapi/storage/sqlite3/transaction_json_table.go create mode 100644 federationapi/storage/tables/queue_transactions_table_test.go create mode 100644 federationapi/storage/tables/transaction_json_table_test.go diff --git a/federationapi/storage/postgres/queue_transactions_table.go b/federationapi/storage/postgres/queue_transactions_table.go new file mode 100644 index 000000000..fe67ef9cd --- /dev/null +++ b/federationapi/storage/postgres/queue_transactions_table.go @@ -0,0 +1,155 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const queueTransactionsSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_queue_transactions ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The destination server that we will send the event to. + server_name TEXT NOT NULL, + -- The JSON NID from the federationsender_transaction_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx + ON federationsender_queue_transactions (json_nid, server_name); +CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx + ON federationsender_queue_transactions (json_nid); +CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx + ON federationsender_queue_transactions (server_name); +` + +const insertQueueTransactionSQL = "" + + "INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueTransactionsSQL = "" + + "DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid = ANY($2)" + +const selectQueueTransactionsSQL = "" + + "SELECT json_nid FROM federationsender_queue_transactions" + + " WHERE server_name = $1" + + " LIMIT $2" + +const selectQueueTransactionsCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_transactions" + + " WHERE server_name = $1" + +type queueTransactionsStatements struct { + db *sql.DB + insertQueueTransactionStmt *sql.Stmt + deleteQueueTransactionsStmt *sql.Stmt + selectQueueTransactionsStmt *sql.Stmt + selectQueueTransactionsCountStmt *sql.Stmt +} + +func NewPostgresQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) { + s = &queueTransactionsStatements{ + db: db, + } + _, err = s.db.Exec(queueTransactionsSchema) + if err != nil { + return + } + if s.insertQueueTransactionStmt, err = s.db.Prepare(insertQueueTransactionSQL); err != nil { + return + } + if s.deleteQueueTransactionsStmt, err = s.db.Prepare(deleteQueueTransactionsSQL); err != nil { + return + } + if s.selectQueueTransactionsStmt, err = s.db.Prepare(selectQueueTransactionsSQL); err != nil { + return + } + if s.selectQueueTransactionsCountStmt, err = s.db.Prepare(selectQueueTransactionsCountSQL); err != nil { + return + } + return +} + +func (s *queueTransactionsStatements) InsertQueueTransaction( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *queueTransactionsStatements) DeleteQueueTransactions( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionsStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) + return err +} + +func (s *queueTransactionsStatements) SelectQueueTransactions( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *queueTransactionsStatements) SelectQueueTransactionCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/federationapi/storage/postgres/transaction_json_table.go b/federationapi/storage/postgres/transaction_json_table.go new file mode 100644 index 000000000..507120edb --- /dev/null +++ b/federationapi/storage/postgres/transaction_json_table.go @@ -0,0 +1,117 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const transactionJSONSchema = ` +-- The federationsender_transaction_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS federationsender_transaction_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid BIGSERIAL, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx + ON federationsender_transaction_json (json_nid); +` + +const insertTransactionJSONSQL = "" + + "INSERT INTO federationsender_transaction_json (json_body)" + + " VALUES ($1)" + + " RETURNING json_nid" + +const deleteTransactionJSONSQL = "" + + "DELETE FROM federationsender_transaction_json WHERE json_nid = ANY($1)" + +const selectTransactionJSONSQL = "" + + "SELECT json_nid, json_body FROM federationsender_transaction_json" + + " WHERE json_nid = ANY($1)" + +type transactionJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + deleteJSONStmt *sql.Stmt + selectJSONStmt *sql.Stmt +} + +func NewPostgresTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) { + s = &transactionJSONStatements{ + db: db, + } + _, err = s.db.Exec(transactionJSONSchema) + if err != nil { + return + } + if s.insertJSONStmt, err = s.db.Prepare(insertTransactionJSONSQL); err != nil { + return + } + if s.deleteJSONStmt, err = s.db.Prepare(deleteTransactionJSONSQL); err != nil { + return + } + if s.selectJSONStmt, err = s.db.Prepare(selectTransactionJSONSQL); err != nil { + return + } + return +} + +func (s *transactionJSONStatements) InsertTransactionJSON( + ctx context.Context, txn *sql.Tx, json string, +) (int64, error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + var lastid int64 + if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil { + return 0, err + } + return lastid, nil +} + +func (s *transactionJSONStatements) DeleteTransactionJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) + _, err := stmt.ExecContext(ctx, pq.Int64Array(nids)) + return err +} + +func (s *transactionJSONStatements) SelectTransactionJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, s.selectJSONStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, err + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/federationapi/storage/sqlite3/queue_transactions_table.go b/federationapi/storage/sqlite3/queue_transactions_table.go new file mode 100644 index 000000000..e616abe78 --- /dev/null +++ b/federationapi/storage/sqlite3/queue_transactions_table.go @@ -0,0 +1,168 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const queueTransactionsSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_queue_transactions ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The domain part of the user ID the m.room.member event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the federationsender_queue_transactions_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx + ON federationsender_queue_transactions (json_nid, server_name); +CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx + ON federationsender_queue_transactions (json_nid); +CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx + ON federationsender_queue_transactions (server_name); +` + +const insertQueueTransactionSQL = "" + + "INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueTransactionsSQL = "" + + "DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid IN ($2)" + +const selectQueueTransactionsSQL = "" + + "SELECT json_nid FROM federationsender_queue_transactions" + + " WHERE server_name = $1" + + " LIMIT $2" + +const selectQueueTransactionsCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_transactions" + + " WHERE server_name = $1" + +type queueTransactionsStatements struct { + db *sql.DB + insertQueueTransactionStmt *sql.Stmt + selectQueueTransactionsStmt *sql.Stmt + selectQueueTransactionsCountStmt *sql.Stmt + // deleteQueueTransactionsStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) { + s = &queueTransactionsStatements{ + db: db, + } + _, err = db.Exec(queueTransactionsSchema) + if err != nil { + return + } + if s.insertQueueTransactionStmt, err = db.Prepare(insertQueueTransactionSQL); err != nil { + return + } + //if s.deleteQueueTransactionsStmt, err = db.Prepare(deleteQueueTransactionsSQL); err != nil { + // return + //} + if s.selectQueueTransactionsStmt, err = db.Prepare(selectQueueTransactionsSQL); err != nil { + return + } + if s.selectQueueTransactionsCountStmt, err = db.Prepare(selectQueueTransactionsCountSQL); err != nil { + return + } + return +} + +func (s *queueTransactionsStatements) InsertQueueTransaction( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *queueTransactionsStatements) DeleteQueueTransactions( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + deleteSQL := strings.Replace(deleteQueueTransactionsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueTransactionJSON s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *queueTransactionsStatements) SelectQueueTransactions( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *queueTransactionsStatements) SelectQueueTransactionCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/federationapi/storage/sqlite3/transaction_json_table.go b/federationapi/storage/sqlite3/transaction_json_table.go new file mode 100644 index 000000000..30ad297ac --- /dev/null +++ b/federationapi/storage/sqlite3/transaction_json_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const transactionJSONSchema = ` +-- The federationsender_transaction_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS federationsender_transaction_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx + ON federationsender_transaction_json (json_nid); +` + +const insertTransactionJSONSQL = "" + + "INSERT INTO federationsender_transaction_json (json_body)" + + " VALUES ($1)" + +const deleteTransactionJSONSQL = "" + + "DELETE FROM federationsender_transaction_json WHERE json_nid IN ($1)" + +const selectTransactionJSONSQL = "" + + "SELECT json_nid, json_body FROM federationsender_transaction_json" + + " WHERE json_nid IN ($1)" + +type transactionJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic + //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) { + s = &transactionJSONStatements{ + db: db, + } + _, err = db.Exec(transactionJSONSchema) + if err != nil { + return + } + if s.insertJSONStmt, err = db.Prepare(insertTransactionJSONSQL); err != nil { + return + } + return +} + +func (s *transactionJSONStatements) InsertTransactionJSON( + ctx context.Context, txn *sql.Tx, json string, +) (lastid int64, err error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("res.LastInsertId: %w", err) + } + return +} + +func (s *transactionJSONStatements) DeleteTransactionJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + deleteSQL := strings.Replace(deleteTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteTransactionJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(nids)) + for k, v := range nids { + iNIDs[k] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err +} + +func (s *transactionJSONStatements) SelectTransactionJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + selectSQL := strings.Replace(selectTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, fmt.Errorf("s.selectTransactionJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(jsonNIDs)) + for k, v := range jsonNIDs { + iNIDs[k] = v + } + + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, selectStmt) + rows, err := stmt.QueryContext(ctx, iNIDs...) + if err != nil { + return nil, fmt.Errorf("s.selectTransactionJSON stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, fmt.Errorf("s.selectTransactionJSON rows.Scan: %w", err) + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 3c116a1d0..37c7bb299 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -51,6 +51,19 @@ type FederationQueueJSON interface { SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) } +type FederationQueueTransactions interface { + InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +type FederationTransactionJSON interface { + InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} + type FederationJoinedHosts interface { InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error diff --git a/federationapi/storage/tables/queue_transactions_table_test.go b/federationapi/storage/tables/queue_transactions_table_test.go new file mode 100644 index 000000000..46d8a3bf3 --- /dev/null +++ b/federationapi/storage/tables/queue_transactions_table_test.go @@ -0,0 +1,171 @@ +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type QueueTransactionsDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.FederationQueueTransactions +} + +func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (database QueueTransactionsDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.FederationQueueTransactions + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresQueueTransactionsTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteQueueTransactionsTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = QueueTransactionsDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTransactionsTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTransactionsTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueTransactions(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, retrievedNids[0], nid) + assert.Equal(t, len(retrievedNids), 1) + }) +} + +func TestShouldDeleteQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTransactionsTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, count, int64(0)) + }) +} + +func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTransactionsTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano())) + serverName2 := gomatrixserverlib.ServerName("domain2") + nid2 := int64(2) + transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano())) + + err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueTransaction(ctx, nil, transactionID2, serverName2, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueTransaction(ctx, nil, transactionID3, serverName, nid2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, count, int64(1)) + + count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, count, int64(1)) + }) +} diff --git a/federationapi/storage/tables/transaction_json_table_test.go b/federationapi/storage/tables/transaction_json_table_test.go new file mode 100644 index 000000000..6ebeff508 --- /dev/null +++ b/federationapi/storage/tables/transaction_json_table_test.go @@ -0,0 +1,173 @@ +package tables_test + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testDestination = gomatrixserverlib.ServerName("white.orchard") +) + +func mustCreateTransaction(userID gomatrixserverlib.UserID) gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + txn.Destination = testDestination + + return txn +} + +type TransactionJSONDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.FederationTransactionJSON +} + +func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database TransactionJSONDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.FederationTransactionJSON + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresTransactionJSONTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteTransactionJSONTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = TransactionJSONDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateTransactionJSONTable(t, dbType) + defer close() + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + if err != nil { + t.Fatalf("Invalid userID: %s", err.Error()) + } + transaction := mustCreateTransaction(*userID) + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + _, err = db.Table.InsertTransactionJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateTransactionJSONTable(t, dbType) + defer close() + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + if err != nil { + t.Fatalf("Invalid userID: %s", err.Error()) + } + transaction := mustCreateTransaction(*userID) + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + var storedJSON map[int64][]byte + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, len(storedJSON), 1) + + var storedTx gomatrixserverlib.Transaction + json.Unmarshal(storedJSON[1], &storedTx) + + assert.Equal(t, transaction, storedTx) + }) +} + +func TestShouldDeleteTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateTransactionJSONTable(t, dbType) + defer close() + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + if err != nil { + t.Fatalf("Invalid userID: %s", err.Error()) + } + transaction := mustCreateTransaction(*userID) + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + storedJSON := map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteTransactionJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + storedJSON = map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, len(storedJSON), 0) + }) +}