Add s&f database interactions

This commit is contained in:
Devon Hudson 2022-11-22 14:28:48 -07:00
parent b9d5fd942f
commit b237f2d62d
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
7 changed files with 934 additions and 0 deletions

View file

@ -0,0 +1,155 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queueTransactionsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The destination server that we will send the event to.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_transaction_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx
ON federationsender_queue_transactions (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx
ON federationsender_queue_transactions (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx
ON federationsender_queue_transactions (server_name);
`
const insertQueueTransactionSQL = "" +
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueTransactionsSQL = "" +
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid = ANY($2)"
const selectQueueTransactionsSQL = "" +
"SELECT json_nid FROM federationsender_queue_transactions" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueTransactionsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_transactions" +
" WHERE server_name = $1"
type queueTransactionsStatements struct {
db *sql.DB
insertQueueTransactionStmt *sql.Stmt
deleteQueueTransactionsStmt *sql.Stmt
selectQueueTransactionsStmt *sql.Stmt
selectQueueTransactionsCountStmt *sql.Stmt
}
func NewPostgresQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) {
s = &queueTransactionsStatements{
db: db,
}
_, err = s.db.Exec(queueTransactionsSchema)
if err != nil {
return
}
if s.insertQueueTransactionStmt, err = s.db.Prepare(insertQueueTransactionSQL); err != nil {
return
}
if s.deleteQueueTransactionsStmt, err = s.db.Prepare(deleteQueueTransactionsSQL); err != nil {
return
}
if s.selectQueueTransactionsStmt, err = s.db.Prepare(selectQueueTransactionsSQL); err != nil {
return
}
if s.selectQueueTransactionsCountStmt, err = s.db.Prepare(selectQueueTransactionsCountSQL); err != nil {
return
}
return
}
func (s *queueTransactionsStatements) InsertQueueTransaction(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
serverName, // destination server name
nid, // JSON blob NID
)
return err
}
func (s *queueTransactionsStatements) DeleteQueueTransactions(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionsStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
return err
}
func (s *queueTransactionsStatements) SelectQueueTransactions(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, rows.Err()
}
func (s *queueTransactionsStatements) SelectQueueTransactionCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}

View file

@ -0,0 +1,117 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const transactionJSONSchema = `
-- The federationsender_transaction_json table contains event contents that
-- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS federationsender_transaction_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid BIGSERIAL,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx
ON federationsender_transaction_json (json_nid);
`
const insertTransactionJSONSQL = "" +
"INSERT INTO federationsender_transaction_json (json_body)" +
" VALUES ($1)" +
" RETURNING json_nid"
const deleteTransactionJSONSQL = "" +
"DELETE FROM federationsender_transaction_json WHERE json_nid = ANY($1)"
const selectTransactionJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_transaction_json" +
" WHERE json_nid = ANY($1)"
type transactionJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
deleteJSONStmt *sql.Stmt
selectJSONStmt *sql.Stmt
}
func NewPostgresTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) {
s = &transactionJSONStatements{
db: db,
}
_, err = s.db.Exec(transactionJSONSchema)
if err != nil {
return
}
if s.insertJSONStmt, err = s.db.Prepare(insertTransactionJSONSQL); err != nil {
return
}
if s.deleteJSONStmt, err = s.db.Prepare(deleteTransactionJSONSQL); err != nil {
return
}
if s.selectJSONStmt, err = s.db.Prepare(selectTransactionJSONSQL); err != nil {
return
}
return
}
func (s *transactionJSONStatements) InsertTransactionJSON(
ctx context.Context, txn *sql.Tx, json string,
) (int64, error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
var lastid int64
if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil {
return 0, err
}
return lastid, nil
}
func (s *transactionJSONStatements) DeleteTransactionJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt)
_, err := stmt.ExecContext(ctx, pq.Int64Array(nids))
return err
}
func (s *transactionJSONStatements) SelectTransactionJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
blobs := map[int64][]byte{}
stmt := sqlutil.TxStmt(txn, s.selectJSONStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
for rows.Next() {
var nid int64
var blob []byte
if err = rows.Scan(&nid, &blob); err != nil {
return nil, err
}
blobs[nid] = blob
}
return blobs, err
}

View file

@ -0,0 +1,168 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queueTransactionsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_transactions_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx
ON federationsender_queue_transactions (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx
ON federationsender_queue_transactions (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx
ON federationsender_queue_transactions (server_name);
`
const insertQueueTransactionSQL = "" +
"INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueTransactionsSQL = "" +
"DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueTransactionsSQL = "" +
"SELECT json_nid FROM federationsender_queue_transactions" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueTransactionsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_transactions" +
" WHERE server_name = $1"
type queueTransactionsStatements struct {
db *sql.DB
insertQueueTransactionStmt *sql.Stmt
selectQueueTransactionsStmt *sql.Stmt
selectQueueTransactionsCountStmt *sql.Stmt
// deleteQueueTransactionsStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) {
s = &queueTransactionsStatements{
db: db,
}
_, err = db.Exec(queueTransactionsSchema)
if err != nil {
return
}
if s.insertQueueTransactionStmt, err = db.Prepare(insertQueueTransactionSQL); err != nil {
return
}
//if s.deleteQueueTransactionsStmt, err = db.Prepare(deleteQueueTransactionsSQL); err != nil {
// return
//}
if s.selectQueueTransactionsStmt, err = db.Prepare(selectQueueTransactionsSQL); err != nil {
return
}
if s.selectQueueTransactionsCountStmt, err = db.Prepare(selectQueueTransactionsCountSQL); err != nil {
return
}
return
}
func (s *queueTransactionsStatements) InsertQueueTransaction(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
serverName, // destination server name
nid, // JSON blob NID
)
return err
}
func (s *queueTransactionsStatements) DeleteQueueTransactions(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
deleteSQL := strings.Replace(deleteQueueTransactionsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteQueueTransactionJSON s.db.Prepare: %w", err)
}
params := make([]interface{}, len(jsonNIDs)+1)
params[0] = serverName
for k, v := range jsonNIDs {
params[k+1] = v
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *queueTransactionsStatements) SelectQueueTransactions(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, rows.Err()
}
func (s *queueTransactionsStatements) SelectQueueTransactionCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}

View file

@ -0,0 +1,137 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const transactionJSONSchema = `
-- The federationsender_transaction_json table contains event contents that
-- we are storing for future forwarding.
CREATE TABLE IF NOT EXISTS federationsender_transaction_json (
-- The JSON NID. This allows cross-referencing to find the JSON blob.
json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_transaction_json_json_nid_idx
ON federationsender_transaction_json (json_nid);
`
const insertTransactionJSONSQL = "" +
"INSERT INTO federationsender_transaction_json (json_body)" +
" VALUES ($1)"
const deleteTransactionJSONSQL = "" +
"DELETE FROM federationsender_transaction_json WHERE json_nid IN ($1)"
const selectTransactionJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_transaction_json" +
" WHERE json_nid IN ($1)"
type transactionJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteTransactionJSONTable(db *sql.DB) (s *transactionJSONStatements, err error) {
s = &transactionJSONStatements{
db: db,
}
_, err = db.Exec(transactionJSONSchema)
if err != nil {
return
}
if s.insertJSONStmt, err = db.Prepare(insertTransactionJSONSQL); err != nil {
return
}
return
}
func (s *transactionJSONStatements) InsertTransactionJSON(
ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
res, err := stmt.ExecContext(ctx, json)
if err != nil {
return 0, fmt.Errorf("stmt.QueryContext: %w", err)
}
lastid, err = res.LastInsertId()
if err != nil {
return 0, fmt.Errorf("res.LastInsertId: %w", err)
}
return
}
func (s *transactionJSONStatements) DeleteTransactionJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
deleteSQL := strings.Replace(deleteTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteTransactionJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(nids))
for k, v := range nids {
iNIDs[k] = v
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, iNIDs...)
return err
}
func (s *transactionJSONStatements) SelectTransactionJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
selectSQL := strings.Replace(selectTransactionJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectStmt, err := txn.Prepare(selectSQL)
if err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(jsonNIDs))
for k, v := range jsonNIDs {
iNIDs[k] = v
}
blobs := map[int64][]byte{}
stmt := sqlutil.TxStmt(txn, selectStmt)
rows, err := stmt.QueryContext(ctx, iNIDs...)
if err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON stmt.QueryContext: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
for rows.Next() {
var nid int64
var blob []byte
if err = rows.Scan(&nid, &blob); err != nil {
return nil, fmt.Errorf("s.selectTransactionJSON rows.Scan: %w", err)
}
blobs[nid] = blob
}
return blobs, err
}

View file

@ -51,6 +51,19 @@ type FederationQueueJSON interface {
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}
type FederationQueueTransactions interface {
InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
}
type FederationTransactionJSON interface {
InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}
type FederationJoinedHosts interface {
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error

View file

@ -0,0 +1,171 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
type QueueTransactionsDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationQueueTransactions
}
func mustCreateQueueTransactionsTable(t *testing.T, dbType test.DBType) (database QueueTransactionsDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationQueueTransactions
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresQueueTransactionsTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteQueueTransactionsTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = QueueTransactionsDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func TestShoudInsertQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
})
}
func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
retrievedNids, err := db.Table.SelectQueueTransactions(ctx, nil, serverName, 10)
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, retrievedNids[0], nid)
assert.Equal(t, len(retrievedNids), 1)
})
}
func TestShouldDeleteQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, count, int64(0))
})
}
func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateQueueTransactionsTable(t, dbType)
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
nid := int64(1)
transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano()))
serverName2 := gomatrixserverlib.ServerName("domain2")
nid2 := int64(2)
transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano()))
err := db.Table.InsertQueueTransaction(ctx, nil, transactionID, serverName, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID2, serverName2, nid)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertQueueTransaction(ctx, nil, transactionID3, serverName, nid2)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteQueueTransactions(ctx, txn, serverName, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
count, err := db.Table.SelectQueueTransactionCount(ctx, nil, serverName)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, count, int64(1))
count, err = db.Table.SelectQueueTransactionCount(ctx, nil, serverName2)
if err != nil {
t.Fatalf("Failed retrieving transaction count: %s", err.Error())
}
assert.Equal(t, count, int64(1))
})
}

View file

@ -0,0 +1,173 @@
package tables_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
)
func mustCreateTransaction(userID gomatrixserverlib.UserID) gomatrixserverlib.Transaction {
txn := gomatrixserverlib.Transaction{}
txn.PDUs = []json.RawMessage{
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
}
txn.Origin = testOrigin
txn.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
txn.Destination = testDestination
return txn
}
type TransactionJSONDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationTransactionJSON
}
func mustCreateTransactionJSONTable(t *testing.T, dbType test.DBType) (database TransactionJSONDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationTransactionJSON
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresTransactionJSONTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteTransactionJSONTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = TransactionJSONDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func TestShoudInsertTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType)
defer close()
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
transaction := mustCreateTransaction(*userID)
tx, err := json.Marshal(transaction)
if err != nil {
t.Fatalf("Invalid transaction: %s", err.Error())
}
_, err = db.Table.InsertTransactionJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
})
}
func TestShouldRetrieveInsertedTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType)
defer close()
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
transaction := mustCreateTransaction(*userID)
tx, err := json.Marshal(transaction)
if err != nil {
t.Fatalf("Invalid transaction: %s", err.Error())
}
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
var storedJSON map[int64][]byte
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, len(storedJSON), 1)
var storedTx gomatrixserverlib.Transaction
json.Unmarshal(storedJSON[1], &storedTx)
assert.Equal(t, transaction, storedTx)
})
}
func TestShouldDeleteTransaction(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateTransactionJSONTable(t, dbType)
defer close()
userID, err := gomatrixserverlib.NewUserID("@local:domain", false)
if err != nil {
t.Fatalf("Invalid userID: %s", err.Error())
}
transaction := mustCreateTransaction(*userID)
tx, err := json.Marshal(transaction)
if err != nil {
t.Fatalf("Invalid transaction: %s", err.Error())
}
nid, err := db.Table.InsertTransactionJSON(ctx, nil, string(tx))
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
storedJSON := map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
err = db.Table.DeleteTransactionJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed deleting transaction: %s", err.Error())
}
storedJSON = map[int64][]byte{}
_ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error {
storedJSON, err = db.Table.SelectTransactionJSON(ctx, txn, []int64{nid})
return err
})
if err != nil {
t.Fatalf("Failed retrieving transaction: %s", err.Error())
}
assert.Equal(t, len(storedJSON), 0)
})
}