- Update the FederationSender Config to use CosmosDB (#9)

- Implement the tables to use Cosmos
- Update the Storage to use Cosmos
This commit is contained in:
alexfca 2021-05-28 15:00:15 +10:00 committed by GitHub
parent 3ca96b13b3
commit db08aa6250
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 1782 additions and 665 deletions

View file

@ -202,7 +202,7 @@ federation_sender:
listen: http://localhost:7775
connect: http://localhost:7775
database:
connection_string: file:federationsender.db
connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1

View file

@ -17,54 +17,90 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const blacklistSchema = `
CREATE TABLE IF NOT EXISTS federationsender_blacklist (
-- The blacklisted server name
server_name TEXT NOT NULL,
UNIQUE (server_name)
);
`
// const blacklistSchema = `
// CREATE TABLE IF NOT EXISTS federationsender_blacklist (
// -- The blacklisted server name
// server_name TEXT NOT NULL,
// UNIQUE (server_name)
// );
// `
const insertBlacklistSQL = "" +
"INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"
const selectBlacklistSQL = "" +
"SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
const deleteBlacklistSQL = "" +
"DELETE FROM federationsender_blacklist WHERE server_name = $1"
type blacklistStatements struct {
db *sql.DB
insertBlacklistStmt *sql.Stmt
selectBlacklistStmt *sql.Stmt
deleteBlacklistStmt *sql.Stmt
type BlacklistCosmos struct {
ServerName string `json:"server_name"`
}
func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
type BlacklistCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Blacklist BlacklistCosmos `json:"mx_federationsender_blacklist"`
}
// const insertBlacklistSQL = "" +
// "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
// " ON CONFLICT DO NOTHING"
// const selectBlacklistSQL = "" +
// "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
// const deleteBlacklistSQL = "" +
// "DELETE FROM federationsender_blacklist WHERE server_name = $1"
type blacklistStatements struct {
db *Database
// insertBlacklistStmt *sql.Stmt
// selectBlacklistStmt *sql.Stmt
// deleteBlacklistStmt *sql.Stmt
tableName string
}
func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*BlacklistCosmosData, error) {
response := BlacklistCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, nil
}
return &response, err
}
func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData BlacklistCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBBlacklistTable(db *Database) (s *blacklistStatements, err error) {
s = &blacklistStatements{
db: db,
}
_, err = db.Exec(blacklistSchema)
if err != nil {
return
}
if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil {
return
}
if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil {
return
}
if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil {
return
}
s.tableName = "blacklists"
return
}
@ -73,8 +109,40 @@ func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
func (s *blacklistStatements) InsertBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
_, err := stmt.ExecContext(ctx, serverName)
// "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
// " ON CONFLICT DO NOTHING"
// stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name)
docId := fmt.Sprintf("%s", serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := BlacklistCosmos{
ServerName: string(serverName),
}
dbData := &BlacklistCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Blacklist: data,
}
// _, err := stmt.ExecContext(ctx, serverName)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
@ -84,16 +152,24 @@ func (s *blacklistStatements) InsertBlacklist(
func (s *blacklistStatements) SelectBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt)
res, err := stmt.QueryContext(ctx, serverName)
// "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
// stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name)
docId := fmt.Sprintf("%s", serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// res, err := stmt.QueryContext(ctx, serverName)
res, err := getBlacklist(s, ctx, pk, cosmosDocId)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is blacklisted, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
return res != nil, nil
}
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
@ -101,7 +177,18 @@ func (s *blacklistStatements) SelectBlacklist(
func (s *blacklistStatements) DeleteBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
_, err := stmt.ExecContext(ctx, serverName)
// "DELETE FROM federationsender_blacklist WHERE server_name = $1"
// stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name)
docId := fmt.Sprintf("%s", serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// _, err := stmt.ExecContext(ctx, serverName)
res, err := getBlacklist(s, ctx, pk, cosmosDocId)
if(res != nil) {
_ = deleteBlacklist(s, ctx, *res)
}
return err
}

View file

@ -17,90 +17,198 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const inboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks (
room_id TEXT NOT NULL,
server_name TEXT NOT NULL,
peek_id TEXT NOT NULL,
creation_ts INTEGER NOT NULL,
renewed_ts INTEGER NOT NULL,
renewal_interval INTEGER NOT NULL,
UNIQUE (room_id, server_name, peek_id)
);
`
// const inboundPeeksSchema = `
// CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks (
// room_id TEXT NOT NULL,
// server_name TEXT NOT NULL,
// peek_id TEXT NOT NULL,
// creation_ts INTEGER NOT NULL,
// renewed_ts INTEGER NOT NULL,
// renewal_interval INTEGER NOT NULL,
// UNIQUE (room_id, server_name, peek_id)
// );
// `
const insertInboundPeekSQL = "" +
"INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
const selectInboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectInboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
const renewInboundPeekSQL = "" +
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteInboundPeekSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteInboundPeeksSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
type inboundPeeksStatements struct {
db *sql.DB
insertInboundPeekStmt *sql.Stmt
selectInboundPeekStmt *sql.Stmt
selectInboundPeeksStmt *sql.Stmt
renewInboundPeekStmt *sql.Stmt
deleteInboundPeekStmt *sql.Stmt
deleteInboundPeeksStmt *sql.Stmt
type InboundPeekCosmos struct {
RoomID string `json:"room_id"`
ServerName string `json:"server_name"`
PeekID string `json:"peek_id"`
CreationTimestamp int64 `json:"creation_ts"`
RenewedTimestamp int64 `json:"renewed_ts"`
RenewalInterval int64 `json:"renewal_interval"`
}
func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err error) {
type InboundPeekCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
InboundPeek InboundPeekCosmos `json:"mx_federationsender_inbound_peek"`
}
// const insertInboundPeekSQL = "" +
// "INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
// const selectInboundPeekSQL = "" +
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
const selectInboundPeeksSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_inbound_peek.room_id = @x2"
// const renewInboundPeekSQL = "" +
// "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteInboundPeekSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_inbound_peek.room_id = @x2" +
"and c.mx_federationsender_inbound_peek.server_name = @x3"
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
const deleteInboundPeeksSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_inbound_peek.room_id = @x2"
type inboundPeeksStatements struct {
db *Database
// insertInboundPeekStmt *sql.Stmt
// selectInboundPeekStmt *sql.Stmt
selectInboundPeeksStmt string
// renewInboundPeekStmt string
deleteInboundPeekStmt string
deleteInboundPeeksStmt string
tableName string
}
func queryInboundPeek(s *inboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InboundPeekCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []InboundPeekCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*InboundPeekCosmosData, error) {
response := InboundPeekCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, nil
}
return &response, err
}
func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek InboundPeekCosmosData) (*InboundPeekCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(inboundPeek.Pk, inboundPeek.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
inboundPeek.Id,
&inboundPeek,
optionsReplace)
return &inboundPeek, ex
}
func deleteInboundPeek(s *inboundPeeksStatements, ctx context.Context, dbData InboundPeekCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBInboundPeeksTable(db *Database) (s *inboundPeeksStatements, err error) {
s = &inboundPeeksStatements{
db: db,
}
_, err = db.Exec(inboundPeeksSchema)
if err != nil {
return
}
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
return
}
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
return
}
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
return
}
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
return
}
s.selectInboundPeeksStmt = selectInboundPeeksSQL
s.deleteInboundPeeksStmt = deleteInboundPeeksSQL
s.deleteInboundPeekStmt = deleteInboundPeekSQL
s.tableName = "inbound_peeks"
return
}
func (s *inboundPeeksStatements) InsertInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
// "INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt)
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
// stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := InboundPeekCosmos{
RoomID: roomID,
ServerName: string(serverName),
PeekID: peekID,
CreationTimestamp: nowMilli,
RenewedTimestamp: nowMilli,
RenewalInterval: renewalInterval,
}
dbData := &InboundPeekCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
InboundPeek: data,
}
// _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return
}
@ -108,26 +216,58 @@ func (s *inboundPeeksStatements) RenewInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
// "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
// _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
res, err := getInboundPeek(s, ctx, pk, cosmosDocId)
if err != nil {
return
}
if res == nil {
return
}
res.InboundPeek.RenewedTimestamp = nowMilli
res.InboundPeek.RenewalInterval = renewalInterval
_, err = setInboundPeek(s, ctx, *res)
return
}
func (s *inboundPeeksStatements) SelectInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.InboundPeek, error) {
row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
inboundPeek := types.InboundPeek{}
err := row.Scan(
&inboundPeek.RoomID,
&inboundPeek.ServerName,
&inboundPeek.PeekID,
&inboundPeek.CreationTimestamp,
&inboundPeek.RenewedTimestamp,
&inboundPeek.RenewalInterval,
)
if err == sql.ErrNoRows {
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
row, err := getInboundPeek(s, ctx, pk, cosmosDocId)
if row == nil {
return nil, nil
}
inboundPeek := types.InboundPeek{}
inboundPeek.RoomID = row.InboundPeek.RoomID
inboundPeek.ServerName = gomatrixserverlib.ServerName(row.InboundPeek.ServerName)
inboundPeek.PeekID = row.InboundPeek.PeekID
inboundPeek.CreationTimestamp = row.InboundPeek.CreationTimestamp
inboundPeek.RenewedTimestamp = row.InboundPeek.RenewedTimestamp
inboundPeek.RenewalInterval = row.InboundPeek.RenewalInterval
if err != nil {
return nil, err
}
@ -137,40 +277,87 @@ func (s *inboundPeeksStatements) SelectInboundPeek(
func (s *inboundPeeksStatements) SelectInboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (inboundPeeks []types.InboundPeek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
rows, err := queryInboundPeek(s, ctx, s.selectInboundPeeksStmt, params)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectInboundPeeks: rows.close() failed")
for rows.Next() {
for _, item := range rows {
inboundPeek := types.InboundPeek{}
if err = rows.Scan(
&inboundPeek.RoomID,
&inboundPeek.ServerName,
&inboundPeek.PeekID,
&inboundPeek.CreationTimestamp,
&inboundPeek.RenewedTimestamp,
&inboundPeek.RenewalInterval,
); err != nil {
return
}
inboundPeek.RoomID = item.InboundPeek.RoomID
inboundPeek.ServerName = gomatrixserverlib.ServerName(item.InboundPeek.ServerName)
inboundPeek.PeekID = item.InboundPeek.PeekID
inboundPeek.CreationTimestamp = item.InboundPeek.CreationTimestamp
inboundPeek.RenewedTimestamp = item.InboundPeek.RenewedTimestamp
inboundPeek.RenewalInterval = item.InboundPeek.RenewalInterval
inboundPeeks = append(inboundPeeks, inboundPeek)
}
return inboundPeeks, rows.Err()
return inboundPeeks, nil
}
func (s *inboundPeeksStatements) DeleteInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
"@x3": serverName,
}
// _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params)
if err != nil {
return
}
for _, item := range rows {
err = deleteInboundPeek(s, ctx, item)
if err != nil {
return
}
}
return
}
func (s *inboundPeeksStatements) DeleteInboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID)
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// _, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID)
rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params)
if err != nil {
return
}
for _, item := range rows {
err = deleteInboundPeek(s, ctx, item)
if err != nil {
return
}
}
return
}

View file

@ -18,87 +18,155 @@ package cosmosdb
import (
"context"
"database/sql"
"strings"
"fmt"
"time"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/gomatrixserverlib"
)
const joinedHostsSchema = `
-- The joined_hosts table stores a list of m.room.member event ids in the
-- current state for each room where the membership is "join".
-- There will be an entry for every user that is joined to the room.
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
-- The string ID of the room.
room_id TEXT NOT NULL,
-- The event ID of the m.room.member join event.
event_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL
);
// const joinedHostsSchema = `
// -- The joined_hosts table stores a list of m.room.member event ids in the
// -- current state for each room where the membership is "join".
// -- There will be an entry for every user that is joined to the room.
// CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
// -- The string ID of the room.
// room_id TEXT NOT NULL,
// -- The event ID of the m.room.member join event.
// event_id TEXT NOT NULL,
// -- The domain part of the user ID the m.room.member event is for.
// server_name TEXT NOT NULL
// );
CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
ON federationsender_joined_hosts (event_id);
// CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
// ON federationsender_joined_hosts (event_id);
CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
ON federationsender_joined_hosts (room_id)
`
// CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
// ON federationsender_joined_hosts (room_id)
// `
const insertJoinedHostsSQL = "" +
"INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
" VALUES ($1, $2, $3)"
const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
const deleteJoinedHostsForRoomSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1"
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
type joinedHostsStatements struct {
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
deleteJoinedHostsForRoomStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
type JoinedHostCosmos struct {
RoomID string `json:"room_id"`
EventID string `json:"event_id"`
ServerName string `json:"server_name"`
}
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
type JoinedHostCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
JoinedHost JoinedHostCosmos `json:"mx_federationsender_joined_host"`
}
// const insertJoinedHostsSQL = "" +
// "INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
// " VALUES ($1, $2, $3)"
// "DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
const deleteJoinedHostsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_joined_host.event_id = @x2 "
// "DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
const deleteJoinedHostsForRoomSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_joined_host.room_id = @x2 "
// "SELECT event_id, server_name FROM federationsender_joined_hosts" +
// " WHERE room_id = $1"
const selectJoinedHostsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_joined_host.room_id = @x2 "
// "SELECT DISTINCT server_name FROM federationsender_joined_hosts"
const selectAllJoinedHostsSQL = "" +
"select distinct c.mx_federationsender_joined_host.server_name from c where c._cn = @x1 "
// "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
const selectJoinedHostsForRoomsSQL = "" +
"select distinct c.mx_federationsender_joined_host.server_name from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_federationsender_joined_host.room_id) "
type joinedHostsStatements struct {
db *Database
// insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt string
deleteJoinedHostsForRoomStmt string
selectJoinedHostsStmt string
selectAllJoinedHostsStmt string
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
tableName string
}
func queryJoinedHostDistinct(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []JoinedHostCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryJoinedHost(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []JoinedHostCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteJoinedHost(s *joinedHostsStatements, ctx context.Context, dbData JoinedHostCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBJoinedHostsTable(db *Database) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{
db: db,
}
_, err = db.Exec(joinedHostsSchema)
if err != nil {
return
}
if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return
}
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
s.deleteJoinedHostsStmt = deleteJoinedHostsSQL
s.deleteJoinedHostsForRoomStmt = deleteJoinedHostsForRoomSQL
s.selectJoinedHostsStmt = selectJoinedHostsSQL
s.selectAllJoinedHostsStmt = selectAllJoinedHostsSQL
s.tableName = "joined_hosts"
return
}
@ -108,8 +176,43 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
roomID, eventID string,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
// "INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
// " VALUES ($1, $2, $3)"
// stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
// ON federationsender_joined_hosts (event_id);
docId := fmt.Sprintf("%s", eventID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := JoinedHostCosmos{
EventID: eventID,
RoomID: roomID,
ServerName: string(serverName),
}
dbData := &JoinedHostCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
JoinedHost: data,
}
// _, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
@ -117,9 +220,21 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
for _, eventID := range eventIDs {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
return err
// "DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventID,
}
// stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params)
for _, item := range rows {
if err = deleteJoinedHost(s, ctx, item); err != nil {
return err
}
}
}
return nil
@ -128,92 +243,123 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
_, err := stmt.ExecContext(ctx, roomID)
// "DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params)
// _, err := stmt.ExecContext(ctx, roomID)
for _, item := range rows {
if err = deleteJoinedHost(s, ctx, item); err != nil {
return err
}
}
return err
}
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt)
return joinedHostsFromStmt(ctx, stmt, roomID)
// "SELECT event_id, server_name FROM federationsender_joined_hosts" +
// " WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt)
rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params)
if err != nil {
return nil, err
}
return rowsToJoinedHosts(&rows), nil
}
func (s *joinedHostsStatements) SelectJoinedHosts(
ctx context.Context, roomID string,
) ([]types.JoinedHost, error) {
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
return s.SelectJoinedHostsWithTx(ctx, nil, roomID)
}
func (s *joinedHostsStatements) SelectAllJoinedHosts(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
// "SELECT DISTINCT server_name FROM federationsender_joined_hosts"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
// rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
for _, item := range rows {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
serverName = item.ServerName
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
return result, err
}
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string,
) ([]gomatrixserverlib.ServerName, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i := range roomIDs {
iRoomIDs[i] = roomIDs[i]
// iRoomIDs := make([]interface{}, len(roomIDs))
// for i := range roomIDs {
// iRoomIDs[i] = roomIDs[i]
// }
// "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
// sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomIDs,
}
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
// rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
for _, item := range rows {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
serverName = item.ServerName
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
}
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed")
var result []types.JoinedHost
for rows.Next() {
var eventID, serverName string
if err = rows.Scan(&eventID, &serverName); err != nil {
return nil, err
}
result = append(result, types.JoinedHost{
MemberEventID: eventID,
ServerName: gomatrixserverlib.ServerName(serverName),
})
}
return result, nil
}
func rowsToJoinedHosts(rows *[]JoinedHostCosmosData) []types.JoinedHost {
var result []types.JoinedHost
if rows == nil {
return result
}
for _, item := range *rows {
result = append(result, types.JoinedHost{
MemberEventID: item.JoinedHost.EventID,
ServerName: gomatrixserverlib.ServerName(item.JoinedHost.ServerName),
})
}
return result
}

View file

@ -17,160 +17,350 @@ package cosmosdb
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const outboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks (
room_id TEXT NOT NULL,
server_name TEXT NOT NULL,
peek_id TEXT NOT NULL,
creation_ts INTEGER NOT NULL,
renewed_ts INTEGER NOT NULL,
renewal_interval INTEGER NOT NULL,
UNIQUE (room_id, server_name, peek_id)
);
`
// const outboundPeeksSchema = `
// CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks (
// room_id TEXT NOT NULL,
// server_name TEXT NOT NULL,
// peek_id TEXT NOT NULL,
// creation_ts INTEGER NOT NULL,
// renewed_ts INTEGER NOT NULL,
// renewal_interval INTEGER NOT NULL,
// UNIQUE (room_id, server_name, peek_id)
// );
// `
const insertOutboundPeekSQL = "" +
"INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
const selectOutboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectOutboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
const renewOutboundPeekSQL = "" +
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteOutboundPeekSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteOutboundPeeksSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
type outboundPeeksStatements struct {
db *sql.DB
insertOutboundPeekStmt *sql.Stmt
selectOutboundPeekStmt *sql.Stmt
selectOutboundPeeksStmt *sql.Stmt
renewOutboundPeekStmt *sql.Stmt
deleteOutboundPeekStmt *sql.Stmt
deleteOutboundPeeksStmt *sql.Stmt
type OutboundPeekCosmos struct {
RoomID string `json:"room_id"`
ServerName string `json:"server_name"`
PeekID string `json:"peek_id"`
CreationTimestamp int64 `json:"creation_ts"`
RenewedTimestamp int64 `json:"renewed_ts"`
RenewalInterval int64 `json:"renewal_interval"`
}
func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err error) {
type OutboundPeekCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
OutboundPeek OutboundPeekCosmos `json:"mx_federationsender_outbound_peek"`
}
// const insertOutboundPeekSQL = "" +
// "INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
const selectOutboundPeeksSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_outbound_peek.room_id = @x2"
// const renewOutboundPeekSQL = "" +
// "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
// "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteOutboundPeekSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_outbound_peek.room_id = @x2" +
"and c.mx_federationsender_outbound_peek.server_name = @x3"
// "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
const deleteOutboundPeeksSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_outbound_peek.room_id = @x2"
type outboundPeeksStatements struct {
db *Database
// insertOutboundPeekStmt *sql.Stmt
// selectOutboundPeekStmt *sql.Stmt
selectOutboundPeeksStmt string
// renewOutboundPeekStmt *sql.Stmt
deleteOutboundPeekStmt string
deleteOutboundPeeksStmt string
tableName string
}
func queryOutboundPeek(s *outboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutboundPeekCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []OutboundPeekCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*OutboundPeekCosmosData, error) {
response := OutboundPeekCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, nil
}
return &response, err
}
func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPeek OutboundPeekCosmosData) (*OutboundPeekCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outboundPeek.Pk, outboundPeek.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
outboundPeek.Id,
&outboundPeek,
optionsReplace)
return &outboundPeek, ex
}
func deleteOutboundPeek(s *outboundPeeksStatements, ctx context.Context, dbData OutboundPeekCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBOutboundPeeksTable(db *Database) (s *outboundPeeksStatements, err error) {
s = &outboundPeeksStatements{
db: db,
}
_, err = db.Exec(outboundPeeksSchema)
if err != nil {
return
}
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
return
}
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
return
}
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
return
}
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
return
}
s.selectOutboundPeeksStmt = selectOutboundPeeksSQL
s.deleteOutboundPeeksStmt = deleteOutboundPeeksSQL
s.deleteOutboundPeekStmt = deleteOutboundPeekSQL
s.tableName = "outbound_peeks"
return
}
func (s *outboundPeeksStatements) InsertOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
// "INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
// stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt)
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt)
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := OutboundPeekCosmos{
RoomID: roomID,
ServerName: string(serverName),
PeekID: peekID,
CreationTimestamp: nowMilli,
RenewedTimestamp: nowMilli,
RenewalInterval: renewalInterval,
}
dbData := &OutboundPeekCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
OutboundPeek: data,
}
// _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return
}
func (s *outboundPeeksStatements) RenewOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
// "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
res, err := getOutboundPeek(s, ctx, pk, cosmosDocId)
if err != nil {
return
}
if res == nil {
return
}
res.OutboundPeek.RenewedTimestamp = nowMilli
res.OutboundPeek.RenewalInterval = renewalInterval
_, err = setOutboundPeek(s, ctx, *res)
return
}
func (s *outboundPeeksStatements) SelectOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.OutboundPeek, error) {
row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
outboundPeek := types.OutboundPeek{}
err := row.Scan(
&outboundPeek.RoomID,
&outboundPeek.ServerName,
&outboundPeek.PeekID,
&outboundPeek.CreationTimestamp,
&outboundPeek.RenewedTimestamp,
&outboundPeek.RenewalInterval,
)
if err == sql.ErrNoRows {
return nil, nil
}
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
row, err := getOutboundPeek(s, ctx, pk, cosmosDocId)
if err != nil {
return nil, err
}
if row == nil {
return nil, nil
}
outboundPeek := types.OutboundPeek{}
outboundPeek.RoomID = row.OutboundPeek.RoomID
outboundPeek.ServerName = gomatrixserverlib.ServerName(row.OutboundPeek.ServerName)
outboundPeek.PeekID = row.OutboundPeek.PeekID
outboundPeek.CreationTimestamp = row.OutboundPeek.CreationTimestamp
outboundPeek.RenewedTimestamp = row.OutboundPeek.RenewedTimestamp
outboundPeek.RenewalInterval = row.OutboundPeek.RenewalInterval
return &outboundPeek, nil
}
func (s *outboundPeeksStatements) SelectOutboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (outboundPeeks []types.OutboundPeek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectOutboundPeeks: rows.close() failed")
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
for rows.Next() {
// rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
rows, err := queryOutboundPeek(s, ctx, s.selectOutboundPeeksStmt, params)
if err != nil {
return
}
for _, item := range rows {
outboundPeek := types.OutboundPeek{}
if err = rows.Scan(
&outboundPeek.RoomID,
&outboundPeek.ServerName,
&outboundPeek.PeekID,
&outboundPeek.CreationTimestamp,
&outboundPeek.RenewedTimestamp,
&outboundPeek.RenewalInterval,
); err != nil {
return
}
outboundPeek.RoomID = item.OutboundPeek.RoomID
outboundPeek.ServerName = gomatrixserverlib.ServerName(item.OutboundPeek.ServerName)
outboundPeek.PeekID = item.OutboundPeek.PeekID
outboundPeek.CreationTimestamp = item.OutboundPeek.CreationTimestamp
outboundPeek.RenewedTimestamp = item.OutboundPeek.RenewedTimestamp
outboundPeek.RenewalInterval = item.OutboundPeek.RenewalInterval
outboundPeeks = append(outboundPeeks, outboundPeek)
}
return outboundPeeks, rows.Err()
return outboundPeeks, nil
}
func (s *outboundPeeksStatements) DeleteOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
"@x3": serverName,
}
// _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeekStmt, params)
if err != nil {
return
}
for _, item := range rows {
err = deleteOutboundPeek(s, ctx, item)
if err != nil {
return
}
}
return
}
func (s *outboundPeeksStatements) DeleteOutboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeeksStmt, params)
if err != nil {
return
}
for _, item := range rows {
err = deleteOutboundPeek(s, ctx, item)
if err != nil {
return
}
}
return
}

View file

@ -18,82 +18,176 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queueEDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
-- The type of the event (informational).
edu_type TEXT NOT NULL,
-- The domain part of the user ID the EDU event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_edus_json table.
json_nid BIGINT NOT NULL
);
// const queueEDUsSchema = `
// CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
// -- The type of the event (informational).
// edu_type TEXT NOT NULL,
// -- The domain part of the user ID the EDU event is for.
// server_name TEXT NOT NULL,
// -- The JSON NID from the federationsender_queue_edus_json table.
// json_nid BIGINT NOT NULL
// );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name);
`
// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
// ON federationsender_queue_edus (json_nid, server_name);
// `
const insertQueueEDUSQL = "" +
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueEDUsSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueEDUSQL = "" +
"SELECT json_nid FROM federationsender_queue_edus" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueEDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE json_nid = $1"
const selectQueueEDUCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE server_name = $1"
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
type queueEDUsStatements struct {
db *sql.DB
insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
type QueueEDUCosmos struct {
EDUType string `json:"edu_type"`
ServerName string `json:"server_name"`
JSONNID int64 `json:"json_nid"`
}
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
type QueueEDUCosmosNumber struct {
Number int64 `json:"number"`
}
type QueueEDUCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
QueueEDU QueueEDUCosmos `json:"mx_federationsender_queue_edu"`
}
// const insertQueueEDUSQL = "" +
// "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
// " VALUES ($1, $2, $3)"
// "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
const deleteQueueEDUsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_edu.server_name = @x2" +
"and ARRAY_CONTAINS(@x3, c.mx_federationsender_queue_edu.json_nid) "
// "SELECT json_nid FROM federationsender_queue_edus" +
// " WHERE server_name = $1" +
// " LIMIT $2"
const selectQueueEDUSQL = "" +
"select top @x3 * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_edu.server_name = @x2"
// "SELECT COUNT(*) FROM federationsender_queue_edus" +
// " WHERE json_nid = $1"
const selectQueueEDUReferenceJSONCountSQL = "" +
"select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_edu.json_nid = @x2"
// "SELECT COUNT(*) FROM federationsender_queue_edus" +
// " WHERE server_name = $1"
const selectQueueEDUCountSQL = "" +
"select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_edu.server_name = @x2"
// "SELECT DISTINCT server_name FROM federationsender_queue_edus"
const selectQueueServerNamesSQL = "" +
"select distinct c.mx_federationsender_queue_edu.server_name from c where c._cn = @x1 "
type queueEDUsStatements struct {
db *Database
// insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt string
selectQueueEDUReferenceJSONCountStmt string
selectQueueEDUCountStmt string
selectQueueEDUServerNamesStmt string
tableName string
}
func queryQueueEDUC(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueueEDUCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryQueueEDUCDistinct(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueueEDUCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryQueueEDUCNumber(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueueEDUCosmosNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteQueueEDUC(s *queueEDUsStatements, ctx context.Context, dbData QueueEDUCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBQueueEDUsTable(db *Database) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{
db: db,
}
_, err = db.Exec(queueEDUsSchema)
if err != nil {
return
}
if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil {
return
}
if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil {
return
}
s.selectQueueEDUStmt = selectQueueEDUSQL
s.selectQueueEDUReferenceJSONCountStmt = selectQueueEDUReferenceJSONCountSQL
s.selectQueueEDUCountStmt = selectQueueEDUCountSQL
s.selectQueueEDUServerNamesStmt = selectQueueServerNamesSQL
s.tableName = "queue_edus"
return
}
@ -104,13 +198,47 @@ func (s *queueEDUsStatements) InsertQueueEDU(
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext(
// "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
// stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
// ON federationsender_queue_edus (json_nid, server_name);
docId := fmt.Sprintf("%d_%s", nid, eduType)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := QueueEDUCosmos{
EDUType: eduType,
JSONNID: nid,
ServerName: string(serverName),
}
dbData := &QueueEDUCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
QueueEDU: data,
}
// _, err := stmt.ExecContext(
// ctx,
// eduType, // the EDU type
// serverName, // destination server name
// nid, // JSON blob NID
// )
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
eduType, // the EDU type
serverName, // destination server name
nid, // JSON blob NID
)
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
@ -119,20 +247,33 @@ func (s *queueEDUsStatements) DeleteQueueEDUs(
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
// "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
// deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
"@x3": jsonNIDs,
}
// stmt := sqlutil.TxStmt(txn, deleteStmt)
// _, err = stmt.ExecContext(ctx, params...)
rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params)
if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
return err
}
params := make([]interface{}, len(jsonNIDs)+1)
params[0] = serverName
for k, v := range jsonNIDs {
params[k+1] = v
for _, item := range rows {
err = deleteQueueEDUC(s, ctx, item)
if err != nil {
return err
}
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, params...)
return err
}
@ -141,18 +282,28 @@ func (s *queueEDUsStatements) SelectQueueEDUs(
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
// "SELECT json_nid FROM federationsender_queue_edus" +
// " WHERE server_name = $1" +
// " LIMIT $2"
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
"@x3": limit,
}
// rows, err := stmt.QueryContext(ctx, serverName, limit)
rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
for _, item := range rows {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
nid = item.QueueEDU.JSONNID
result = append(result, nid)
}
return result, nil
@ -162,11 +313,23 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
if err == sql.ErrNoRows {
// "SELECT COUNT(*) FROM federationsender_queue_edus" +
// " WHERE json_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": jsonNID,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUReferenceJSONCountStmt, params)
if len(rows) == 0 {
return -1, nil
}
// err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
count = rows[0].Number
return count, err
}
@ -174,34 +337,52 @@ func (s *queueEDUsStatements) SelectQueueEDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// "SELECT COUNT(*) FROM federationsender_queue_edus" +
// " WHERE server_name = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUCountStmt, params)
if len(rows) == 0 {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
// err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
count = rows[0].Number
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
rows, err := stmt.QueryContext(ctx)
// "SELECT DISTINCT server_name FROM federationsender_queue_edus"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
// rows, err := stmt.QueryContext(ctx)
rows, err := queryQueueEDUCDistinct(s, ctx, s.selectQueueEDUServerNamesStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
for _, item := range rows {
var serverName gomatrixserverlib.ServerName
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
serverName = gomatrixserverlib.ServerName(item.ServerName)
result = append(result, serverName)
}
return result, rows.Err()
return result, nil
}

View file

@ -19,97 +19,205 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
)
const queueJSONSchema = `
-- The queue_retry_json table contains event contents that
-- we failed to send.
CREATE TABLE IF NOT EXISTS federationsender_queue_json (
-- The JSON NID. This allows the federationsender_queue_retry table to
-- cross-reference to find the JSON blob.
json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
`
// const queueJSONSchema = `
// -- The queue_retry_json table contains event contents that
// -- we failed to send.
// CREATE TABLE IF NOT EXISTS federationsender_queue_json (
// -- The JSON NID. This allows the federationsender_queue_retry table to
// -- cross-reference to find the JSON blob.
// json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
// -- The JSON body. Text so that we preserve UTF-8.
// json_body TEXT NOT NULL
// );
// `
const insertJSONSQL = "" +
"INSERT INTO federationsender_queue_json (json_body)" +
" VALUES ($1)"
const deleteJSONSQL = "" +
"DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
const selectJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_queue_json" +
" WHERE json_nid IN ($1)"
type queueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
type QueueJSONCosmos struct {
JSONNID int64 `json:"json_nid"`
JSONBody []byte `json:"json_body"`
}
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
type QueueJSONCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
QueueJSON QueueJSONCosmos `json:"mx_federationsender_queue_json"`
}
// const insertJSONSQL = "" +
// "INSERT INTO federationsender_queue_json (json_body)" +
// " VALUES ($1)"
// "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
const deleteJSONSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_federationsender_queue_json.json_nid) "
// "SELECT json_nid, json_body FROM federationsender_queue_json" +
// " WHERE json_nid IN ($1)"
const selectJSONSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_federationsender_queue_json.json_nid) "
type queueJSONStatements struct {
db *Database
// insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
tableName string
}
func queryQueueJSON(s *queueJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueJSONCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueueJSONCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteQueueJSON(s *queueJSONStatements, ctx context.Context, dbData QueueJSONCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBQueueJSONTable(db *Database) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{
db: db,
}
_, err = db.Exec(queueJSONSchema)
if err != nil {
return
}
if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil {
return
}
s.tableName = "queue_jsons"
return
}
func (s *queueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
res, err := stmt.ExecContext(ctx, json)
// "INSERT INTO federationsender_queue_json (json_body)" +
// " VALUES ($1)"
// json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
idSeq, err := GetNextQueueJSONNID(s, ctx)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
docId := fmt.Sprintf("%d", idSeq)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
//Convert to byte
jsonData := []byte(json)
data := QueueJSONCosmos{
JSONNID: idSeq,
JSONBody: jsonData,
}
dbData := &QueueJSONCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
QueueJSON: data,
}
// stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
// res, err := stmt.ExecContext(ctx, json)
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil {
return 0, fmt.Errorf("stmt.QueryContext: %w", err)
}
lastid, err = res.LastInsertId()
if err != nil {
return 0, fmt.Errorf("res.LastInsertId: %w", err)
}
lastid = idSeq
return
}
func (s *queueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
// "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": nids,
}
// deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
// deleteStmt, err := txn.Prepare(deleteSQL)
// stmt := sqlutil.TxStmt(txn, deleteStmt)
rows, err := queryQueueJSON(s, ctx, deleteJSONSQL, params)
if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
return err
}
iNIDs := make([]interface{}, len(nids))
for k, v := range nids {
iNIDs[k] = v
}
// iNIDs := make([]interface{}, len(nids))
// for k, v := range nids {
// iNIDs[k] = v
// }
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, iNIDs...)
for _, item := range rows {
err = deleteQueueJSON(s, ctx, item)
}
return err
}
func (s *queueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectStmt, err := txn.Prepare(selectSQL)
// "SELECT json_nid, json_body FROM federationsender_queue_json" +
// " WHERE json_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": jsonNIDs,
}
// selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
// selectStmt, err := txn.Prepare(selectSQL)
rows, err := queryQueueJSON(s, ctx, selectJSONSQL, params)
if err != nil {
return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
}
iNIDs := make([]interface{}, len(jsonNIDs))
@ -118,18 +226,11 @@ func (s *queueJSONStatements) SelectQueueJSON(
}
blobs := map[int64][]byte{}
stmt := sqlutil.TxStmt(txn, selectStmt)
rows, err := stmt.QueryContext(ctx, iNIDs...)
if err != nil {
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
for rows.Next() {
for _, item := range rows {
var nid int64
var blob []byte
if err = rows.Scan(&nid, &blob); err != nil {
return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
}
nid = item.QueueJSON.JSONNID
blob = item.QueueJSON.JSONBody
blobs[nid] = blob
}
return blobs, err

View file

@ -0,0 +1,12 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
func GetNextQueueJSONNID(s *queueJSONStatements, ctx context.Context) (int64, error) {
const docId = "json_nid_seq"
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -19,96 +19,188 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queuePDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_pdus_json table.
json_nid BIGINT NOT NULL
);
// const queuePDUsSchema = `
// CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
// -- The transaction ID that was generated before persisting the event.
// transaction_id TEXT NOT NULL,
// -- The domain part of the user ID the m.room.member event is for.
// server_name TEXT NOT NULL,
// -- The JSON NID from the federationsender_queue_pdus_json table.
// json_nid BIGINT NOT NULL
// );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid, server_name);
`
// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
// ON federationsender_queue_pdus (json_nid, server_name);
// `
const insertQueuePDUSQL = "" +
"INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueuePDUsSQL = "" +
"DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueNextTransactionIDSQL = "" +
"SELECT transaction_id FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" ORDER BY transaction_id ASC" +
" LIMIT 1"
const selectQueuePDUsSQL = "" +
"SELECT json_nid FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueuePDUsReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE json_nid = $1"
const selectQueuePDUsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE server_name = $1"
const selectQueuePDUsServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
type queuePDUsStatements struct {
db *sql.DB
insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsStmt *sql.Stmt
selectQueueReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt
selectQueueServerNamesStmt *sql.Stmt
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
type QueuePDUCosmos struct {
TransactionID string `json:"transaction_id"`
ServerName string `json:"server_name"`
JSONNID int64 `json:"json_nid"`
}
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
type QueuePDUCosmosNumber struct {
Number int64 `json:"number"`
}
type QueuePDUCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
QueuePDU QueuePDUCosmos `json:"mx_federationsender_queue_pdu"`
}
// const insertQueuePDUSQL = "" +
// "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
// " VALUES ($1, $2, $3)"
// "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"
const deleteQueuePDUsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.server_name = @x2 " +
"and ARRAY_CONTAINS(@x3, c.mx_federationsender_queue_pdu.json_nid) "
// "SELECT transaction_id FROM federationsender_queue_pdus" +
// " WHERE server_name = $1" +
// " ORDER BY transaction_id ASC" +
// " LIMIT 1"
const selectQueueNextTransactionIDSQL = "" +
"select top 1 * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.server_name = @x2 " +
"order by c.mx_federationsender_queue_pdu.transaction_id asc "
// "SELECT json_nid FROM federationsender_queue_pdus" +
// " WHERE server_name = $1" +
// " LIMIT $2"
const selectQueuePDUsSQL = "" +
"select top @x3 * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.server_name = @x2 "
// "SELECT COUNT(*) FROM federationsender_queue_pdus" +
// " WHERE json_nid = $1"
const selectQueuePDUsReferenceJSONCountSQL = "" +
"select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.json_nid = @x2 "
// "SELECT COUNT(*) FROM federationsender_queue_pdus" +
// " WHERE server_name = $1"
const selectQueuePDUsCountSQL = "" +
"select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.server_name = @x2 "
// "SELECT DISTINCT server_name FROM federationsender_queue_pdus"
const selectQueuePDUsServerNamesSQL = "" +
"select distinct c.mx_federationsender_queue_pdu.server_name from c where c._cn = @x1 "
type queuePDUsStatements struct {
db *Database
// insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt string
selectQueuePDUsStmt string
selectQueueReferenceJSONCountStmt string
selectQueuePDUsCountStmt string
selectQueueServerNamesStmt string
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
tableName string
}
func queryQueuePDU(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueuePDUCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryQueuePDUDistinct(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueuePDUCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryQueuePDUNumber(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueuePDUCosmosNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteQueuePDU(s *queuePDUsStatements, ctx context.Context, dbData QueuePDUCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBQueuePDUsTable(db *Database) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{
db: db,
}
_, err = db.Exec(queuePDUsSchema)
if err != nil {
return
}
if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil {
return
}
//if s.deleteQueuePDUsStmt, err = db.Prepare(deleteQueuePDUsSQL); err != nil {
// return
//}
if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil {
return
}
if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil {
return
}
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
return
}
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
return
}
s.selectQueueNextTransactionIDStmt = selectQueueNextTransactionIDSQL
s.selectQueuePDUsStmt = selectQueuePDUsSQL
s.selectQueueReferenceJSONCountStmt = selectQueuePDUsReferenceJSONCountSQL
s.selectQueuePDUsCountStmt = selectQueuePDUsCountSQL
s.selectQueueServerNamesStmt = selectQueuePDUsServerNamesSQL
s.tableName = "queue_pdus"
return
}
@ -119,13 +211,47 @@ func (s *queuePDUsStatements) InsertQueuePDU(
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
_, err := stmt.ExecContext(
// "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
// " VALUES ($1, $2, $3)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
// ON federationsender_queue_pdus (json_nid, server_name);
docId := fmt.Sprintf("%d_%s", nid, serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := QueuePDUCosmos{
JSONNID: nid,
ServerName: string(serverName),
TransactionID: string(transactionID),
}
dbData := &QueuePDUCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
QueuePDU: data,
}
// stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
// _, err := stmt.ExecContext(
// ctx,
// transactionID, // the transaction ID that we initially attempted
// serverName, // destination server name
// nid, // JSON blob NID
// )
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
transactionID, // the transaction ID that we initially attempted
serverName, // destination server name
nid, // JSON blob NID
)
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
}
@ -134,20 +260,31 @@ func (s *queuePDUsStatements) DeleteQueuePDUs(
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
// "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
"@x3": jsonNIDs,
}
// deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
// deleteStmt, err := txn.Prepare(deleteSQL)
rows, err := queryQueuePDU(s, ctx, deleteQueuePDUsSQL, params)
if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
return err
}
params := make([]interface{}, len(jsonNIDs)+1)
params[0] = serverName
for k, v := range jsonNIDs {
params[k+1] = v
for _, item := range rows {
// stmt := sqlutil.TxStmt(txn, deleteStmt)
err = deleteQueuePDU(s, ctx, item)
if err != nil {
return err
}
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, params...)
return err
}
@ -155,11 +292,30 @@ func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (gomatrixserverlib.TransactionID, error) {
var transactionID gomatrixserverlib.TransactionID
stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID)
if err == sql.ErrNoRows {
// "SELECT transaction_id FROM federationsender_queue_pdus" +
// " WHERE server_name = $1" +
// " ORDER BY transaction_id ASC" +
// " LIMIT 1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt)
rows, err := queryQueuePDU(s, ctx, s.selectQueueNextTransactionIDStmt, params)
if err != nil {
return "", err
}
if len(rows) == 0 {
return "", nil
}
// err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID)
transactionID = gomatrixserverlib.TransactionID(rows[0].QueuePDU.TransactionID)
return transactionID, err
}
@ -167,11 +323,28 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
if err == sql.ErrNoRows {
// "SELECT COUNT(*) FROM federationsender_queue_pdus" +
// " WHERE json_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": jsonNID,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
rows, err := queryQueuePDUNumber(s, ctx, s.selectQueueReferenceJSONCountStmt, params)
if err != nil {
return -1, err
}
if len(rows) == 0 {
return -1, nil
}
// err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
count = rows[0].Number
return count, err
}
@ -179,14 +352,31 @@ func (s *queuePDUsStatements) SelectQueuePDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// "SELECT COUNT(*) FROM federationsender_queue_pdus" +
// " WHERE server_name = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
rows, err := queryQueuePDUNumber(s, ctx, s.selectQueuePDUsCountStmt, params)
if err != nil {
return 0, err
}
if len(rows) == 0 {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
// err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
count = rows[0].Number
return count, err
}
@ -195,41 +385,58 @@ func (s *queuePDUsStatements) SelectQueuePDUs(
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
// "SELECT json_nid FROM federationsender_queue_pdus" +
// " WHERE server_name = $1" +
// " LIMIT $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
"@x3": limit,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
// rows, err := stmt.QueryContext(ctx, serverName, limit)
rows, err := queryQueuePDU(s, ctx, s.selectQueuePDUsStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
for _, item := range rows {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
nid = item.QueuePDU.JSONNID
result = append(result, nid)
}
return result, rows.Err()
return result, nil
}
func (s *queuePDUsStatements) SelectQueuePDUServerNames(
ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt)
rows, err := stmt.QueryContext(ctx)
// "SELECT DISTINCT server_name FROM federationsender_queue_pdus"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt)
// rows, err := stmt.QueryContext(ctx)
rows, err := queryQueuePDUDistinct(s, ctx, s.selectQueueServerNamesStmt, params)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
for _, item := range rows {
var serverName gomatrixserverlib.ServerName
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
serverName = gomatrixserverlib.ServerName(item.ServerName)
result = append(result, serverName)
}
return result, rows.Err()
return result, nil
}

View file

@ -16,68 +16,74 @@
package cosmosdb
import (
"database/sql"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
_ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/federationsender/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
)
// Database stores information needed by the federation sender
type Database struct {
shared.Database
sqlutil.PartitionOffsetStatements
db *sql.DB
writer sqlutil.Writer
cosmosdbutil.PartitionOffsetStatements
database cosmosdbutil.Database
writer cosmosdbutil.Writer
connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
}
// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
var d Database
d.connection = conn
d.cosmosConfig = configCosmos
d.databaseName = "federationsender"
d.database = cosmosdbutil.Database{
Connection: conn,
CosmosConfig: configCosmos,
DatabaseName: d.databaseName,
}
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
d.writer = sqlutil.NewExclusiveWriter()
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
d.writer = cosmosdbutil.NewExclusiveWriterFake()
joinedHosts, err := NewCosmosDBJoinedHostsTable(&d)
if err != nil {
return nil, err
}
queuePDUs, err := NewSQLiteQueuePDUsTable(d.db)
queuePDUs, err := NewCosmosDBQueuePDUsTable(&d)
if err != nil {
return nil, err
}
queueEDUs, err := NewSQLiteQueueEDUsTable(d.db)
queueEDUs, err := NewCosmosDBQueueEDUsTable(&d)
if err != nil {
return nil, err
}
queueJSON, err := NewSQLiteQueueJSONTable(d.db)
queueJSON, err := NewCosmosDBQueueJSONTable(&d)
if err != nil {
return nil, err
}
blacklist, err := NewSQLiteBlacklistTable(d.db)
blacklist, err := NewCosmosDBBlacklistTable(&d)
if err != nil {
return nil, err
}
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
outboundPeeks, err := NewCosmosDBOutboundPeeksTable(&d)
if err != nil {
return nil, err
}
inboundPeeks, err := NewSQLiteInboundPeeksTable(d.db)
inboundPeeks, err := NewCosmosDBInboundPeeksTable(&d)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadRemoveRoomsTable(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
DB: nil,
Cache: cache,
Writer: d.writer,
FederationSenderJoinedHosts: joinedHosts,
@ -88,7 +94,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationS
FederationSenderOutboundPeeks: outboundPeeks,
FederationSenderInboundPeeks: inboundPeeks,
}
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
if err = d.PartitionOffsetStatements.Prepare(&d.database, d.writer, "federationsender"); err != nil {
return nil, err
}
return &d, nil