dendrite/federationapi/storage/postgres/queue_transactions_table.go
2022-11-22 14:28:48 -07:00

156 lines
4.9 KiB
Go

// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
// queueTransactionsSchema is the DDL for the
// federationsender_queue_transactions table and its indexes. Each row ties a
// transaction ID and a destination server name to a JSON NID.
//
// Every statement uses IF NOT EXISTS, so the schema can be executed on each
// startup without failing when the objects are already present.
//
// NOTE(review): the single-column json_nid index looks redundant with the
// unique (json_nid, server_name) index, whose leading column is also
// json_nid — confirm before dropping it.
const queueTransactionsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_transactions (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The destination server that we will send the event to.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_transaction_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_transactions_transaction_json_nid_idx
ON federationsender_queue_transactions (json_nid, server_name);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_json_nid_idx
ON federationsender_queue_transactions (json_nid);
CREATE INDEX IF NOT EXISTS federationsender_queue_transactions_server_name_idx
ON federationsender_queue_transactions (server_name);
`
// SQL statements run against federationsender_queue_transactions.
const (
	// insertQueueTransactionSQL adds one queue entry.
	// $1 = transaction ID, $2 = destination server name, $3 = JSON NID.
	insertQueueTransactionSQL = "INSERT INTO federationsender_queue_transactions (transaction_id, server_name, json_nid)" +
		" VALUES ($1, $2, $3)"

	// deleteQueueTransactionsSQL removes the entries of one destination whose
	// JSON NID is in the given array.
	// $1 = destination server name, $2 = array of JSON NIDs.
	deleteQueueTransactionsSQL = "DELETE FROM federationsender_queue_transactions WHERE server_name = $1 AND json_nid = ANY($2)"

	// selectQueueTransactionsSQL lists JSON NIDs queued for one destination.
	// There is no ORDER BY, so which rows are returned under the limit is up
	// to the database.
	// $1 = destination server name, $2 = maximum number of rows.
	selectQueueTransactionsSQL = "SELECT json_nid FROM federationsender_queue_transactions" +
		" WHERE server_name = $1" +
		" LIMIT $2"

	// selectQueueTransactionsCountSQL counts the entries queued for one
	// destination.
	// $1 = destination server name.
	selectQueueTransactionsCountSQL = "SELECT COUNT(*) FROM federationsender_queue_transactions" +
		" WHERE server_name = $1"
)
// queueTransactionsStatements holds the database handle and the prepared
// statements for the federationsender_queue_transactions table. It is
// populated by NewPostgresQueueTransactionsTable.
type queueTransactionsStatements struct {
// db is the handle used to create the schema and prepare the statements.
db *sql.DB
// insertQueueTransactionStmt is prepared from insertQueueTransactionSQL.
insertQueueTransactionStmt *sql.Stmt
// deleteQueueTransactionsStmt is prepared from deleteQueueTransactionsSQL.
deleteQueueTransactionsStmt *sql.Stmt
// selectQueueTransactionsStmt is prepared from selectQueueTransactionsSQL.
selectQueueTransactionsStmt *sql.Stmt
// selectQueueTransactionsCountStmt is prepared from
// selectQueueTransactionsCountSQL.
selectQueueTransactionsCountStmt *sql.Stmt
}
// NewPostgresQueueTransactionsTable executes queueTransactionsSchema against
// db and then prepares every statement used by the methods of the returned
// value.
//
// On failure it returns a nil table together with the error. Statements that
// had already been prepared before the failure are closed, so an aborted
// construction does not leave prepared statements behind on db.
func NewPostgresQueueTransactionsTable(db *sql.DB) (s *queueTransactionsStatements, err error) {
	if _, err = db.Exec(queueTransactionsSchema); err != nil {
		return nil, err
	}

	s = &queueTransactionsStatements{db: db}

	// Each entry pairs the struct field to fill with the SQL to prepare.
	statements := []struct {
		target **sql.Stmt
		query  string
	}{
		{&s.insertQueueTransactionStmt, insertQueueTransactionSQL},
		{&s.deleteQueueTransactionsStmt, deleteQueueTransactionsSQL},
		{&s.selectQueueTransactionsStmt, selectQueueTransactionsSQL},
		{&s.selectQueueTransactionsCountStmt, selectQueueTransactionsCountSQL},
	}
	for i, st := range statements {
		if *st.target, err = db.Prepare(st.query); err != nil {
			// Release everything prepared so far; the caller only gets
			// the Prepare error, so a Close failure here is deliberately
			// ignored.
			for _, done := range statements[:i] {
				_ = (*done.target).Close()
			}
			return nil, err
		}
	}
	return s, nil
}
// InsertQueueTransaction adds a row to federationsender_queue_transactions
// linking transactionID and the destination serverName to the JSON NID nid.
// The prepared insert statement is obtained through sqlutil.TxStmt with txn
// before being executed. It returns any error from executing the statement.
func (s *queueTransactionsStatements) InsertQueueTransaction(
	ctx context.Context,
	txn *sql.Tx,
	transactionID gomatrixserverlib.TransactionID,
	serverName gomatrixserverlib.ServerName,
	nid int64,
) error {
	insertStmt := sqlutil.TxStmt(txn, s.insertQueueTransactionStmt)
	// Arguments follow the statement's column order:
	// transaction_id, server_name, json_nid.
	if _, err := insertStmt.ExecContext(ctx, transactionID, serverName, nid); err != nil {
		return err
	}
	return nil
}
// DeleteQueueTransactions removes the rows for serverName whose JSON NID is
// one of jsonNIDs. The prepared delete statement is obtained through
// sqlutil.TxStmt with txn before being executed. It returns any error from
// executing the statement.
func (s *queueTransactionsStatements) DeleteQueueTransactions(
	ctx context.Context, txn *sql.Tx,
	serverName gomatrixserverlib.ServerName,
	jsonNIDs []int64,
) error {
	// The slice is wrapped in pq.Int64Array so it can be bound to the
	// ANY($2) array parameter.
	nidArray := pq.Int64Array(jsonNIDs)
	deleteStmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionsStmt)
	if _, err := deleteStmt.ExecContext(ctx, serverName, nidArray); err != nil {
		return err
	}
	return nil
}
// SelectQueueTransactions returns up to limit JSON NIDs queued for
// serverName. The underlying query has no ORDER BY, so when more than limit
// rows exist the database decides which ones are returned.
//
// The prepared select statement is obtained through sqlutil.TxStmt with txn.
// On any error (query, scan or row iteration) the returned slice is nil. When
// no rows match, the slice is also nil and the error is nil.
func (s *queueTransactionsStatements) SelectQueueTransactions(
	ctx context.Context, txn *sql.Tx,
	serverName gomatrixserverlib.ServerName,
	limit int,
) ([]int64, error) {
	stmt := sqlutil.TxStmt(txn, s.selectQueueTransactionsStmt)
	rows, err := stmt.QueryContext(ctx, serverName, limit)
	if err != nil {
		return nil, err
	}
	// The message names this method so a close failure can be traced back
	// to the right call site in the logs.
	defer internal.CloseAndLogIfError(ctx, rows, "SelectQueueTransactions: rows.close() failed")

	var nids []int64
	for rows.Next() {
		var nid int64
		if err = rows.Scan(&nid); err != nil {
			return nil, err
		}
		nids = append(nids, nid)
	}
	// An iteration error means the list may be incomplete; do not hand a
	// partial result back alongside the error.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return nids, nil
}
// SelectQueueTransactionCount returns the number of rows in
// federationsender_queue_transactions for serverName. The prepared count
// statement is obtained through sqlutil.TxStmt with txn.
//
// sql.ErrNoRows is reported as a count of zero with a nil error; any other
// error is returned unchanged.
func (s *queueTransactionsStatements) SelectQueueTransactionCount(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
	var total int64
	row := sqlutil.TxStmt(txn, s.selectQueueTransactionsCountStmt).QueryRowContext(ctx, serverName)
	switch err := row.Scan(&total); err {
	case sql.ErrNoRows:
		// Having nothing queued for this server is not an error
		// condition, so report it as zero. (COUNT(*) without GROUP BY is
		// expected to always yield a row, which makes this a defensive
		// guard.)
		return 0, nil
	default:
		return total, err
	}
}