- Update the FederationSender Config to use CosmosDB (#9)

- Implement the tables to use Cosmos
- Update the Storage to use Cosmos
This commit is contained in:
alexfca 2021-05-28 15:00:15 +10:00 committed by GitHub
parent 3ca96b13b3
commit db08aa6250
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 1782 additions and 665 deletions

View file

@ -202,7 +202,7 @@ federation_sender:
listen: http://localhost:7775 listen: http://localhost:7775
connect: http://localhost:7775 connect: http://localhost:7775
database: database:
connection_string: file:federationsender.db connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -17,54 +17,90 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const blacklistSchema = ` // const blacklistSchema = `
CREATE TABLE IF NOT EXISTS federationsender_blacklist ( // CREATE TABLE IF NOT EXISTS federationsender_blacklist (
-- The blacklisted server name // -- The blacklisted server name
server_name TEXT NOT NULL, // server_name TEXT NOT NULL,
UNIQUE (server_name) // UNIQUE (server_name)
); // );
` // `
const insertBlacklistSQL = "" + type BlacklistCosmos struct {
"INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" + ServerName string `json:"server_name"`
" ON CONFLICT DO NOTHING"
const selectBlacklistSQL = "" +
"SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
const deleteBlacklistSQL = "" +
"DELETE FROM federationsender_blacklist WHERE server_name = $1"
type blacklistStatements struct {
db *sql.DB
insertBlacklistStmt *sql.Stmt
selectBlacklistStmt *sql.Stmt
deleteBlacklistStmt *sql.Stmt
} }
func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { type BlacklistCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
Blacklist BlacklistCosmos `json:"mx_federationsender_blacklist"`
}
// const insertBlacklistSQL = "" +
// "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
// " ON CONFLICT DO NOTHING"
// const selectBlacklistSQL = "" +
// "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
// const deleteBlacklistSQL = "" +
// "DELETE FROM federationsender_blacklist WHERE server_name = $1"
type blacklistStatements struct {
db *Database
// insertBlacklistStmt *sql.Stmt
// selectBlacklistStmt *sql.Stmt
// deleteBlacklistStmt *sql.Stmt
tableName string
}
func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*BlacklistCosmosData, error) {
response := BlacklistCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, nil
}
return &response, err
}
func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData BlacklistCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBBlacklistTable(db *Database) (s *blacklistStatements, err error) {
s = &blacklistStatements{ s = &blacklistStatements{
db: db, db: db,
} }
_, err = db.Exec(blacklistSchema) s.tableName = "blacklists"
if err != nil {
return
}
if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil {
return
}
if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil {
return
}
if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil {
return
}
return return
} }
@ -73,8 +109,40 @@ func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
func (s *blacklistStatements) InsertBlacklist( func (s *blacklistStatements) InsertBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
_, err := stmt.ExecContext(ctx, serverName) // "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
// " ON CONFLICT DO NOTHING"
// stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name)
docId := fmt.Sprintf("%s", serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := BlacklistCosmos{
ServerName: string(serverName),
}
dbData := &BlacklistCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
Blacklist: data,
}
// _, err := stmt.ExecContext(ctx, serverName)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err return err
} }
@ -84,16 +152,24 @@ func (s *blacklistStatements) InsertBlacklist(
func (s *blacklistStatements) SelectBlacklist( func (s *blacklistStatements) SelectBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (bool, error) { ) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) // "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
res, err := stmt.QueryContext(ctx, serverName)
// stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name)
docId := fmt.Sprintf("%s", serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// res, err := stmt.QueryContext(ctx, serverName)
res, err := getBlacklist(s, ctx, pk, cosmosDocId)
if err != nil { if err != nil {
return false, err return false, err
} }
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is blacklisted, and // The query will return the server name if the server is blacklisted, and
// will return no rows if not. By calling Next, we find out if a row was // will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself. // returned or not - we don't care about the value itself.
return res.Next(), nil return res != nil, nil
} }
// updateRoom updates the last_event_id for the room. selectRoomForUpdate should // updateRoom updates the last_event_id for the room. selectRoomForUpdate should
@ -101,7 +177,18 @@ func (s *blacklistStatements) SelectBlacklist(
func (s *blacklistStatements) DeleteBlacklist( func (s *blacklistStatements) DeleteBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) // "DELETE FROM federationsender_blacklist WHERE server_name = $1"
_, err := stmt.ExecContext(ctx, serverName)
// stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name)
docId := fmt.Sprintf("%s", serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// _, err := stmt.ExecContext(ctx, serverName)
res, err := getBlacklist(s, ctx, pk, cosmosDocId)
if(res != nil) {
_ = deleteBlacklist(s, ctx, *res)
}
return err return err
} }

View file

@ -17,90 +17,198 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"time" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const inboundPeeksSchema = ` // const inboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks ( // CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks (
room_id TEXT NOT NULL, // room_id TEXT NOT NULL,
server_name TEXT NOT NULL, // server_name TEXT NOT NULL,
peek_id TEXT NOT NULL, // peek_id TEXT NOT NULL,
creation_ts INTEGER NOT NULL, // creation_ts INTEGER NOT NULL,
renewed_ts INTEGER NOT NULL, // renewed_ts INTEGER NOT NULL,
renewal_interval INTEGER NOT NULL, // renewal_interval INTEGER NOT NULL,
UNIQUE (room_id, server_name, peek_id) // UNIQUE (room_id, server_name, peek_id)
); // );
` // `
const insertInboundPeekSQL = "" + type InboundPeekCosmos struct {
"INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" RoomID string `json:"room_id"`
ServerName string `json:"server_name"`
const selectInboundPeekSQL = "" + PeekID string `json:"peek_id"`
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" CreationTimestamp int64 `json:"creation_ts"`
RenewedTimestamp int64 `json:"renewed_ts"`
const selectInboundPeeksSQL = "" + RenewalInterval int64 `json:"renewal_interval"`
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
const renewInboundPeekSQL = "" +
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteInboundPeekSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteInboundPeeksSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
type inboundPeeksStatements struct {
db *sql.DB
insertInboundPeekStmt *sql.Stmt
selectInboundPeekStmt *sql.Stmt
selectInboundPeeksStmt *sql.Stmt
renewInboundPeekStmt *sql.Stmt
deleteInboundPeekStmt *sql.Stmt
deleteInboundPeeksStmt *sql.Stmt
} }
func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err error) { type InboundPeekCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
InboundPeek InboundPeekCosmos `json:"mx_federationsender_inbound_peek"`
}
// const insertInboundPeekSQL = "" +
// "INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
// const selectInboundPeekSQL = "" +
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
const selectInboundPeeksSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_inbound_peek.room_id = @x2"
// const renewInboundPeekSQL = "" +
// "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteInboundPeekSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_inbound_peek.room_id = @x2" +
"and c.mx_federationsender_inbound_peek.server_name = @x3"
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
const deleteInboundPeeksSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_inbound_peek.room_id = @x2"
type inboundPeeksStatements struct {
db *Database
// insertInboundPeekStmt *sql.Stmt
// selectInboundPeekStmt *sql.Stmt
selectInboundPeeksStmt string
// renewInboundPeekStmt string
deleteInboundPeekStmt string
deleteInboundPeeksStmt string
tableName string
}
func queryInboundPeek(s *inboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InboundPeekCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []InboundPeekCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*InboundPeekCosmosData, error) {
response := InboundPeekCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, nil
}
return &response, err
}
func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek InboundPeekCosmosData) (*InboundPeekCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(inboundPeek.Pk, inboundPeek.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
inboundPeek.Id,
&inboundPeek,
optionsReplace)
return &inboundPeek, ex
}
func deleteInboundPeek(s *inboundPeeksStatements, ctx context.Context, dbData InboundPeekCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBInboundPeeksTable(db *Database) (s *inboundPeeksStatements, err error) {
s = &inboundPeeksStatements{ s = &inboundPeeksStatements{
db: db, db: db,
} }
_, err = db.Exec(inboundPeeksSchema) s.selectInboundPeeksStmt = selectInboundPeeksSQL
if err != nil { s.deleteInboundPeeksStmt = deleteInboundPeeksSQL
return s.deleteInboundPeekStmt = deleteInboundPeekSQL
} s.tableName = "inbound_peeks"
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
return
}
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
return
}
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
return
}
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
return
}
return return
} }
func (s *inboundPeeksStatements) InsertInboundPeek( func (s *inboundPeeksStatements) InsertInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) { ) (err error) {
// "INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt) // stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt)
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := InboundPeekCosmos{
RoomID: roomID,
ServerName: string(serverName),
PeekID: peekID,
CreationTimestamp: nowMilli,
RenewedTimestamp: nowMilli,
RenewalInterval: renewalInterval,
}
dbData := &InboundPeekCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
InboundPeek: data,
}
// _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return return
} }
@ -108,26 +216,58 @@ func (s *inboundPeeksStatements) RenewInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) { ) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) // "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
// _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
res, err := getInboundPeek(s, ctx, pk, cosmosDocId)
if err != nil {
return
}
if res == nil {
return
}
res.InboundPeek.RenewedTimestamp = nowMilli
res.InboundPeek.RenewalInterval = renewalInterval
_, err = setInboundPeek(s, ctx, *res)
return return
} }
func (s *inboundPeeksStatements) SelectInboundPeek( func (s *inboundPeeksStatements) SelectInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.InboundPeek, error) { ) (*types.InboundPeek, error) {
row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
inboundPeek := types.InboundPeek{} // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
err := row.Scan( var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
&inboundPeek.RoomID, // UNIQUE (room_id, server_name, peek_id)
&inboundPeek.ServerName, docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
&inboundPeek.PeekID, cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
&inboundPeek.CreationTimestamp, pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
&inboundPeek.RenewedTimestamp,
&inboundPeek.RenewalInterval, // row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
) row, err := getInboundPeek(s, ctx, pk, cosmosDocId)
if err == sql.ErrNoRows {
if row == nil {
return nil, nil return nil, nil
} }
inboundPeek := types.InboundPeek{}
inboundPeek.RoomID = row.InboundPeek.RoomID
inboundPeek.ServerName = gomatrixserverlib.ServerName(row.InboundPeek.ServerName)
inboundPeek.PeekID = row.InboundPeek.PeekID
inboundPeek.CreationTimestamp = row.InboundPeek.CreationTimestamp
inboundPeek.RenewedTimestamp = row.InboundPeek.RenewedTimestamp
inboundPeek.RenewalInterval = row.InboundPeek.RenewalInterval
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -137,40 +277,87 @@ func (s *inboundPeeksStatements) SelectInboundPeek(
func (s *inboundPeeksStatements) SelectInboundPeeks( func (s *inboundPeeksStatements) SelectInboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (inboundPeeks []types.InboundPeek, err error) { ) (inboundPeeks []types.InboundPeek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID) // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
rows, err := queryInboundPeek(s, ctx, s.selectInboundPeeksStmt, params)
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectInboundPeeks: rows.close() failed")
for rows.Next() { for _, item := range rows {
inboundPeek := types.InboundPeek{} inboundPeek := types.InboundPeek{}
if err = rows.Scan( inboundPeek.RoomID = item.InboundPeek.RoomID
&inboundPeek.RoomID, inboundPeek.ServerName = gomatrixserverlib.ServerName(item.InboundPeek.ServerName)
&inboundPeek.ServerName, inboundPeek.PeekID = item.InboundPeek.PeekID
&inboundPeek.PeekID, inboundPeek.CreationTimestamp = item.InboundPeek.CreationTimestamp
&inboundPeek.CreationTimestamp, inboundPeek.RenewedTimestamp = item.InboundPeek.RenewedTimestamp
&inboundPeek.RenewedTimestamp, inboundPeek.RenewalInterval = item.InboundPeek.RenewalInterval
&inboundPeek.RenewalInterval,
); err != nil {
return
}
inboundPeeks = append(inboundPeeks, inboundPeek) inboundPeeks = append(inboundPeeks, inboundPeek)
} }
return inboundPeeks, rows.Err() return inboundPeeks, nil
} }
func (s *inboundPeeksStatements) DeleteInboundPeek( func (s *inboundPeeksStatements) DeleteInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
"@x3": serverName,
}
// _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params)
if err != nil {
return
}
for _, item := range rows {
err = deleteInboundPeek(s, ctx, item)
if err != nil {
return
}
}
return return
} }
func (s *inboundPeeksStatements) DeleteInboundPeeks( func (s *inboundPeeksStatements) DeleteInboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID) // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// _, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID)
rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params)
if err != nil {
return
}
for _, item := range rows {
err = deleteInboundPeek(s, ctx, item)
if err != nil {
return
}
}
return return
} }

View file

@ -18,87 +18,155 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings" "fmt"
"time"
"github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const joinedHostsSchema = ` // const joinedHostsSchema = `
-- The joined_hosts table stores a list of m.room.member event ids in the // -- The joined_hosts table stores a list of m.room.member event ids in the
-- current state for each room where the membership is "join". // -- current state for each room where the membership is "join".
-- There will be an entry for every user that is joined to the room. // -- There will be an entry for every user that is joined to the room.
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts ( // CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
-- The string ID of the room. // -- The string ID of the room.
room_id TEXT NOT NULL, // room_id TEXT NOT NULL,
-- The event ID of the m.room.member join event. // -- The event ID of the m.room.member join event.
event_id TEXT NOT NULL, // event_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for. // -- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL // server_name TEXT NOT NULL
); // );
CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx // CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
ON federationsender_joined_hosts (event_id); // ON federationsender_joined_hosts (event_id);
CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx // CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
ON federationsender_joined_hosts (room_id) // ON federationsender_joined_hosts (room_id)
` // `
const insertJoinedHostsSQL = "" + type JoinedHostCosmos struct {
"INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" + RoomID string `json:"room_id"`
" VALUES ($1, $2, $3)" EventID string `json:"event_id"`
ServerName string `json:"server_name"`
const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
const deleteJoinedHostsForRoomSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1"
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
type joinedHostsStatements struct {
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
deleteJoinedHostsForRoomStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { type JoinedHostCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
JoinedHost JoinedHostCosmos `json:"mx_federationsender_joined_host"`
}
// const insertJoinedHostsSQL = "" +
// "INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
// " VALUES ($1, $2, $3)"
// "DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
const deleteJoinedHostsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_joined_host.event_id = @x2 "
// "DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
const deleteJoinedHostsForRoomSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_joined_host.room_id = @x2 "
// "SELECT event_id, server_name FROM federationsender_joined_hosts" +
// " WHERE room_id = $1"
const selectJoinedHostsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_joined_host.room_id = @x2 "
// "SELECT DISTINCT server_name FROM federationsender_joined_hosts"
const selectAllJoinedHostsSQL = "" +
"select distinct c.mx_federationsender_joined_host.server_name from c where c._cn = @x1 "
// "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
const selectJoinedHostsForRoomsSQL = "" +
"select distinct c.mx_federationsender_joined_host.server_name from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_federationsender_joined_host.room_id) "
type joinedHostsStatements struct {
db *Database
// insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt string
deleteJoinedHostsForRoomStmt string
selectJoinedHostsStmt string
selectAllJoinedHostsStmt string
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
tableName string
}
func queryJoinedHostDistinct(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []JoinedHostCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryJoinedHost(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []JoinedHostCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteJoinedHost(s *joinedHostsStatements, ctx context.Context, dbData JoinedHostCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBJoinedHostsTable(db *Database) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{ s = &joinedHostsStatements{
db: db, db: db,
} }
_, err = db.Exec(joinedHostsSchema) s.deleteJoinedHostsStmt = deleteJoinedHostsSQL
if err != nil { s.deleteJoinedHostsForRoomStmt = deleteJoinedHostsForRoomSQL
return s.selectJoinedHostsStmt = selectJoinedHostsSQL
} s.selectAllJoinedHostsStmt = selectAllJoinedHostsSQL
if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil { s.tableName = "joined_hosts"
return
}
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return
}
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
return return
} }
@ -108,8 +176,43 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
roomID, eventID string, roomID, eventID string,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName) // "INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
// " VALUES ($1, $2, $3)"
// stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
// ON federationsender_joined_hosts (event_id);
docId := fmt.Sprintf("%s", eventID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := JoinedHostCosmos{
EventID: eventID,
RoomID: roomID,
ServerName: string(serverName),
}
dbData := &JoinedHostCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
JoinedHost: data,
}
// _, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err return err
} }
@ -117,103 +220,146 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) error { ) error {
for _, eventID := range eventIDs { for _, eventID := range eventIDs {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) // "DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": eventID,
}
// stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params)
for _, item := range rows {
if err = deleteJoinedHost(s, ctx, item); err != nil {
return err return err
} }
} }
}
return nil return nil
} }
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) // "DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
_, err := stmt.ExecContext(ctx, roomID)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params)
// _, err := stmt.ExecContext(ctx, roomID)
for _, item := range rows {
if err = deleteJoinedHost(s, ctx, item); err != nil {
return err
}
}
return err return err
} }
func (s *joinedHostsStatements) SelectJoinedHostsWithTx( func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) { ) ([]types.JoinedHost, error) {
stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) // "SELECT event_id, server_name FROM federationsender_joined_hosts" +
return joinedHostsFromStmt(ctx, stmt, roomID) // " WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt)
rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params)
if err != nil {
return nil, err
}
return rowsToJoinedHosts(&rows), nil
} }
func (s *joinedHostsStatements) SelectJoinedHosts( func (s *joinedHostsStatements) SelectJoinedHosts(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) ([]types.JoinedHost, error) { ) ([]types.JoinedHost, error) {
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) return s.SelectJoinedHostsWithTx(ctx, nil, roomID)
} }
func (s *joinedHostsStatements) SelectAllJoinedHosts( func (s *joinedHostsStatements) SelectAllJoinedHosts(
ctx context.Context, ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) // "SELECT DISTINCT server_name FROM federationsender_joined_hosts"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
// rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed")
var result []gomatrixserverlib.ServerName var result []gomatrixserverlib.ServerName
for rows.Next() { for _, item := range rows {
var serverName string var serverName string
if err = rows.Scan(&serverName); err != nil { serverName = item.ServerName
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName)) result = append(result, gomatrixserverlib.ServerName(serverName))
} }
return result, rows.Err() return result, err
} }
func (s *joinedHostsStatements) SelectJoinedHostsForRooms( func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string, ctx context.Context, roomIDs []string,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) // iRoomIDs := make([]interface{}, len(roomIDs))
for i := range roomIDs { // for i := range roomIDs {
iRoomIDs[i] = roomIDs[i] // iRoomIDs[i] = roomIDs[i]
// }
// "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
// sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomIDs,
} }
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) // rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName var result []gomatrixserverlib.ServerName
for rows.Next() { for _, item := range rows {
var serverName string var serverName string
if err = rows.Scan(&serverName); err != nil { serverName = item.ServerName
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName)) result = append(result, gomatrixserverlib.ServerName(serverName))
} }
return result, rows.Err()
}
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed")
var result []types.JoinedHost
for rows.Next() {
var eventID, serverName string
if err = rows.Scan(&eventID, &serverName); err != nil {
return nil, err
}
result = append(result, types.JoinedHost{
MemberEventID: eventID,
ServerName: gomatrixserverlib.ServerName(serverName),
})
}
return result, nil return result, nil
} }
func rowsToJoinedHosts(rows *[]JoinedHostCosmosData) []types.JoinedHost {
var result []types.JoinedHost
if rows == nil {
return result
}
for _, item := range *rows {
result = append(result, types.JoinedHost{
MemberEventID: item.JoinedHost.EventID,
ServerName: gomatrixserverlib.ServerName(item.JoinedHost.ServerName),
})
}
return result
}

View file

@ -17,160 +17,350 @@ package cosmosdb
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"time" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/federationsender/types" "github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const outboundPeeksSchema = ` // const outboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks ( // CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks (
room_id TEXT NOT NULL, // room_id TEXT NOT NULL,
server_name TEXT NOT NULL, // server_name TEXT NOT NULL,
peek_id TEXT NOT NULL, // peek_id TEXT NOT NULL,
creation_ts INTEGER NOT NULL, // creation_ts INTEGER NOT NULL,
renewed_ts INTEGER NOT NULL, // renewed_ts INTEGER NOT NULL,
renewal_interval INTEGER NOT NULL, // renewal_interval INTEGER NOT NULL,
UNIQUE (room_id, server_name, peek_id) // UNIQUE (room_id, server_name, peek_id)
); // );
` // `
const insertOutboundPeekSQL = "" + type OutboundPeekCosmos struct {
"INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" RoomID string `json:"room_id"`
ServerName string `json:"server_name"`
const selectOutboundPeekSQL = "" + PeekID string `json:"peek_id"`
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" CreationTimestamp int64 `json:"creation_ts"`
RenewedTimestamp int64 `json:"renewed_ts"`
const selectOutboundPeeksSQL = "" + RenewalInterval int64 `json:"renewal_interval"`
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
const renewOutboundPeekSQL = "" +
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteOutboundPeekSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteOutboundPeeksSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
type outboundPeeksStatements struct {
db *sql.DB
insertOutboundPeekStmt *sql.Stmt
selectOutboundPeekStmt *sql.Stmt
selectOutboundPeeksStmt *sql.Stmt
renewOutboundPeekStmt *sql.Stmt
deleteOutboundPeekStmt *sql.Stmt
deleteOutboundPeeksStmt *sql.Stmt
} }
func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err error) { type OutboundPeekCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
OutboundPeek OutboundPeekCosmos `json:"mx_federationsender_outbound_peek"`
}
// const insertOutboundPeekSQL = "" +
// "INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
const selectOutboundPeeksSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_outbound_peek.room_id = @x2"
// const renewOutboundPeekSQL = "" +
// "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
// "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteOutboundPeekSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_outbound_peek.room_id = @x2" +
"and c.mx_federationsender_outbound_peek.server_name = @x3"
// "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
const deleteOutboundPeeksSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_outbound_peek.room_id = @x2"
type outboundPeeksStatements struct {
db *Database
// insertOutboundPeekStmt *sql.Stmt
// selectOutboundPeekStmt *sql.Stmt
selectOutboundPeeksStmt string
// renewOutboundPeekStmt *sql.Stmt
deleteOutboundPeekStmt string
deleteOutboundPeeksStmt string
tableName string
}
func queryOutboundPeek(s *outboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutboundPeekCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []OutboundPeekCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*OutboundPeekCosmosData, error) {
response := OutboundPeekCosmosData{}
err := cosmosdbapi.GetDocumentOrNil(
s.db.connection,
s.db.cosmosConfig,
ctx,
pk,
docId,
&response)
if response.Id == "" {
return nil, nil
}
return &response, err
}
func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPeek OutboundPeekCosmosData) (*OutboundPeekCosmosData, error) {
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outboundPeek.Pk, outboundPeek.ETag)
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
outboundPeek.Id,
&outboundPeek,
optionsReplace)
return &outboundPeek, ex
}
func deleteOutboundPeek(s *outboundPeeksStatements, ctx context.Context, dbData OutboundPeekCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBOutboundPeeksTable(db *Database) (s *outboundPeeksStatements, err error) {
s = &outboundPeeksStatements{ s = &outboundPeeksStatements{
db: db, db: db,
} }
_, err = db.Exec(outboundPeeksSchema) s.selectOutboundPeeksStmt = selectOutboundPeeksSQL
if err != nil { s.deleteOutboundPeeksStmt = deleteOutboundPeeksSQL
return s.deleteOutboundPeekStmt = deleteOutboundPeekSQL
} s.tableName = "outbound_peeks"
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
return
}
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
return
}
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
return
}
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
return
}
return return
} }
func (s *outboundPeeksStatements) InsertOutboundPeek( func (s *outboundPeeksStatements) InsertOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) { ) (err error) {
// "INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
// stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt)
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval) // UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := OutboundPeekCosmos{
RoomID: roomID,
ServerName: string(serverName),
PeekID: peekID,
CreationTimestamp: nowMilli,
RenewedTimestamp: nowMilli,
RenewalInterval: renewalInterval,
}
dbData := &OutboundPeekCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
OutboundPeek: data,
}
// _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return return
} }
func (s *outboundPeeksStatements) RenewOutboundPeek( func (s *outboundPeeksStatements) RenewOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) { ) (err error) {
// "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
nowMilli := time.Now().UnixNano() / int64(time.Millisecond) nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (room_id, server_name, peek_id)
docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
// _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
res, err := getOutboundPeek(s, ctx, pk, cosmosDocId)
if err != nil {
return
}
if res == nil {
return
}
res.OutboundPeek.RenewedTimestamp = nowMilli
res.OutboundPeek.RenewalInterval = renewalInterval
_, err = setOutboundPeek(s, ctx, *res)
return return
} }
func (s *outboundPeeksStatements) SelectOutboundPeek( func (s *outboundPeeksStatements) SelectOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.OutboundPeek, error) { ) (*types.OutboundPeek, error) {
row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
outboundPeek := types.OutboundPeek{} // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
err := row.Scan(
&outboundPeek.RoomID, var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
&outboundPeek.ServerName, // UNIQUE (room_id, server_name, peek_id)
&outboundPeek.PeekID, docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID)
&outboundPeek.CreationTimestamp, cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
&outboundPeek.RenewedTimestamp, pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
&outboundPeek.RenewalInterval,
) // row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
if err == sql.ErrNoRows { row, err := getOutboundPeek(s, ctx, pk, cosmosDocId)
return nil, nil
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
if row == nil {
return nil, nil
}
outboundPeek := types.OutboundPeek{}
outboundPeek.RoomID = row.OutboundPeek.RoomID
outboundPeek.ServerName = gomatrixserverlib.ServerName(row.OutboundPeek.ServerName)
outboundPeek.PeekID = row.OutboundPeek.PeekID
outboundPeek.CreationTimestamp = row.OutboundPeek.CreationTimestamp
outboundPeek.RenewedTimestamp = row.OutboundPeek.RenewedTimestamp
outboundPeek.RenewalInterval = row.OutboundPeek.RenewalInterval
return &outboundPeek, nil return &outboundPeek, nil
} }
func (s *outboundPeeksStatements) SelectOutboundPeeks( func (s *outboundPeeksStatements) SelectOutboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (outboundPeeks []types.OutboundPeek, err error) { ) (outboundPeeks []types.OutboundPeek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectOutboundPeeks: rows.close() failed") var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
for rows.Next() { // rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
outboundPeek := types.OutboundPeek{} rows, err := queryOutboundPeek(s, ctx, s.selectOutboundPeeksStmt, params)
if err = rows.Scan(
&outboundPeek.RoomID, if err != nil {
&outboundPeek.ServerName,
&outboundPeek.PeekID,
&outboundPeek.CreationTimestamp,
&outboundPeek.RenewedTimestamp,
&outboundPeek.RenewalInterval,
); err != nil {
return return
} }
for _, item := range rows {
outboundPeek := types.OutboundPeek{}
outboundPeek.RoomID = item.OutboundPeek.RoomID
outboundPeek.ServerName = gomatrixserverlib.ServerName(item.OutboundPeek.ServerName)
outboundPeek.PeekID = item.OutboundPeek.PeekID
outboundPeek.CreationTimestamp = item.OutboundPeek.CreationTimestamp
outboundPeek.RenewedTimestamp = item.OutboundPeek.RenewedTimestamp
outboundPeek.RenewalInterval = item.OutboundPeek.RenewalInterval
outboundPeeks = append(outboundPeeks, outboundPeek) outboundPeeks = append(outboundPeeks, outboundPeek)
} }
return outboundPeeks, rows.Err() return outboundPeeks, nil
} }
func (s *outboundPeeksStatements) DeleteOutboundPeek( func (s *outboundPeeksStatements) DeleteOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
"@x3": serverName,
}
// _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeekStmt, params)
if err != nil {
return
}
for _, item := range rows {
err = deleteOutboundPeek(s, ctx, item)
if err != nil {
return
}
}
return return
} }
func (s *outboundPeeksStatements) DeleteOutboundPeeks( func (s *outboundPeeksStatements) DeleteOutboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": roomID,
}
// _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeeksStmt, params)
if err != nil {
return
}
for _, item := range rows {
err = deleteOutboundPeek(s, ctx, item)
if err != nil {
return
}
}
return return
} }

View file

@ -18,82 +18,176 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const queueEDUsSchema = ` // const queueEDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( // CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
-- The type of the event (informational). // -- The type of the event (informational).
edu_type TEXT NOT NULL, // edu_type TEXT NOT NULL,
-- The domain part of the user ID the EDU event is for. // -- The domain part of the user ID the EDU event is for.
server_name TEXT NOT NULL, // server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_edus_json table. // -- The JSON NID from the federationsender_queue_edus_json table.
json_nid BIGINT NOT NULL // json_nid BIGINT NOT NULL
); // );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx // CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name); // ON federationsender_queue_edus (json_nid, server_name);
` // `
const insertQueueEDUSQL = "" + type QueueEDUCosmos struct {
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + EDUType string `json:"edu_type"`
" VALUES ($1, $2, $3)" ServerName string `json:"server_name"`
JSONNID int64 `json:"json_nid"`
const deleteQueueEDUsSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueEDUSQL = "" +
"SELECT json_nid FROM federationsender_queue_edus" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueEDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE json_nid = $1"
const selectQueueEDUCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE server_name = $1"
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
type queueEDUsStatements struct {
db *sql.DB
insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
} }
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { type QueueEDUCosmosNumber struct {
Number int64 `json:"number"`
}
type QueueEDUCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
QueueEDU QueueEDUCosmos `json:"mx_federationsender_queue_edu"`
}
// const insertQueueEDUSQL = "" +
// "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
// " VALUES ($1, $2, $3)"
// "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
const deleteQueueEDUsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_edu.server_name = @x2" +
"and ARRAY_CONTAINS(@x3, c.mx_federationsender_queue_edu.json_nid) "
// "SELECT json_nid FROM federationsender_queue_edus" +
// " WHERE server_name = $1" +
// " LIMIT $2"
const selectQueueEDUSQL = "" +
"select top @x3 * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_edu.server_name = @x2"
// "SELECT COUNT(*) FROM federationsender_queue_edus" +
// " WHERE json_nid = $1"
const selectQueueEDUReferenceJSONCountSQL = "" +
"select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_edu.json_nid = @x2"
// "SELECT COUNT(*) FROM federationsender_queue_edus" +
// " WHERE server_name = $1"
const selectQueueEDUCountSQL = "" +
"select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_edu.server_name = @x2"
// "SELECT DISTINCT server_name FROM federationsender_queue_edus"
const selectQueueServerNamesSQL = "" +
"select distinct c.mx_federationsender_queue_edu.server_name from c where c._cn = @x1 "
type queueEDUsStatements struct {
db *Database
// insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt string
selectQueueEDUReferenceJSONCountStmt string
selectQueueEDUCountStmt string
selectQueueEDUServerNamesStmt string
tableName string
}
func queryQueueEDUC(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueueEDUCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryQueueEDUCDistinct(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueueEDUCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryQueueEDUCNumber(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueueEDUCosmosNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteQueueEDUC(s *queueEDUsStatements, ctx context.Context, dbData QueueEDUCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBQueueEDUsTable(db *Database) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{ s = &queueEDUsStatements{
db: db, db: db,
} }
_, err = db.Exec(queueEDUsSchema) s.selectQueueEDUStmt = selectQueueEDUSQL
if err != nil { s.selectQueueEDUReferenceJSONCountStmt = selectQueueEDUReferenceJSONCountSQL
return s.selectQueueEDUCountStmt = selectQueueEDUCountSQL
} s.selectQueueEDUServerNamesStmt = selectQueueServerNamesSQL
if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil { s.tableName = "queue_edus"
return
}
if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil {
return
}
if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil {
return
}
return return
} }
@ -104,13 +198,47 @@ func (s *queueEDUsStatements) InsertQueueEDU(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nid int64, nid int64,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext( // "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
// stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
// ON federationsender_queue_edus (json_nid, server_name);
docId := fmt.Sprintf("%d_%s", nid, eduType)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := QueueEDUCosmos{
EDUType: eduType,
JSONNID: nid,
ServerName: string(serverName),
}
dbData := &QueueEDUCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
QueueEDU: data,
}
// _, err := stmt.ExecContext(
// ctx,
// eduType, // the EDU type
// serverName, // destination server name
// nid, // JSON blob NID
// )
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx, ctx,
eduType, // the EDU type s.db.cosmosConfig.DatabaseName,
serverName, // destination server name s.db.cosmosConfig.ContainerName,
nid, // JSON blob NID &dbData,
) options)
return err return err
} }
@ -119,20 +247,33 @@ func (s *queueEDUsStatements) DeleteQueueEDUs(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
jsonNIDs []int64, jsonNIDs []int64,
) error { ) error {
deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL) // "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
// deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
"@x3": jsonNIDs,
}
// stmt := sqlutil.TxStmt(txn, deleteStmt)
// _, err = stmt.ExecContext(ctx, params...)
rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params)
if err != nil { if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) return err
} }
params := make([]interface{}, len(jsonNIDs)+1) for _, item := range rows {
params[0] = serverName err = deleteQueueEDUC(s, ctx, item)
for k, v := range jsonNIDs { if err != nil {
params[k+1] = v return err
}
} }
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, params...)
return err return err
} }
@ -141,18 +282,28 @@ func (s *queueEDUsStatements) SelectQueueEDUs(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) ([]int64, error) { ) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit) // "SELECT json_nid FROM federationsender_queue_edus" +
// " WHERE server_name = $1" +
// " LIMIT $2"
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
"@x3": limit,
}
// rows, err := stmt.QueryContext(ctx, serverName, limit)
rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64 var result []int64
for rows.Next() { for _, item := range rows {
var nid int64 var nid int64
if err = rows.Scan(&nid); err != nil { nid = item.QueueEDU.JSONNID
return nil, err
}
result = append(result, nid) result = append(result, nid)
} }
return result, nil return result, nil
@ -162,11 +313,23 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64, ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) { ) (int64, error) {
var count int64 var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) // "SELECT COUNT(*) FROM federationsender_queue_edus" +
if err == sql.ErrNoRows { // " WHERE json_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": jsonNID,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUReferenceJSONCountStmt, params)
if len(rows) == 0 {
return -1, nil return -1, nil
} }
// err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
count = rows[0].Number
return count, err return count, err
} }
@ -174,34 +337,52 @@ func (s *queueEDUsStatements) SelectQueueEDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) { ) (int64, error) {
var count int64 var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count) // "SELECT COUNT(*) FROM federationsender_queue_edus" +
if err == sql.ErrNoRows { // " WHERE server_name = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUCountStmt, params)
if len(rows) == 0 {
// It's acceptable for there to be no rows referencing a given // It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if // JSON NID but it's not an error condition. Just return as if
// there's a zero count. // there's a zero count.
return 0, nil return 0, nil
} }
// err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
count = rows[0].Number
return count, err return count, err
} }
func (s *queueEDUsStatements) SelectQueueEDUServerNames( func (s *queueEDUsStatements) SelectQueueEDUServerNames(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
rows, err := stmt.QueryContext(ctx) // "SELECT DISTINCT server_name FROM federationsender_queue_edus"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
// rows, err := stmt.QueryContext(ctx)
rows, err := queryQueueEDUCDistinct(s, ctx, s.selectQueueEDUServerNamesStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName var result []gomatrixserverlib.ServerName
for rows.Next() { for _, item := range rows {
var serverName gomatrixserverlib.ServerName var serverName gomatrixserverlib.ServerName
if err = rows.Scan(&serverName); err != nil { serverName = gomatrixserverlib.ServerName(item.ServerName)
return nil, err
}
result = append(result, serverName) result = append(result, serverName)
} }
return result, rows.Err() return result, nil
} }

View file

@ -19,97 +19,205 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
const queueJSONSchema = ` // const queueJSONSchema = `
-- The queue_retry_json table contains event contents that // -- The queue_retry_json table contains event contents that
-- we failed to send. // -- we failed to send.
CREATE TABLE IF NOT EXISTS federationsender_queue_json ( // CREATE TABLE IF NOT EXISTS federationsender_queue_json (
-- The JSON NID. This allows the federationsender_queue_retry table to // -- The JSON NID. This allows the federationsender_queue_retry table to
-- cross-reference to find the JSON blob. // -- cross-reference to find the JSON blob.
json_nid INTEGER PRIMARY KEY AUTOINCREMENT, // json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
-- The JSON body. Text so that we preserve UTF-8. // -- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL // json_body TEXT NOT NULL
); // );
` // `
const insertJSONSQL = "" + type QueueJSONCosmos struct {
"INSERT INTO federationsender_queue_json (json_body)" + JSONNID int64 `json:"json_nid"`
" VALUES ($1)" JSONBody []byte `json:"json_body"`
const deleteJSONSQL = "" +
"DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
const selectJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_queue_json" +
" WHERE json_nid IN ($1)"
type queueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { type QueueJSONCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
QueueJSON QueueJSONCosmos `json:"mx_federationsender_queue_json"`
}
// const insertJSONSQL = "" +
// "INSERT INTO federationsender_queue_json (json_body)" +
// " VALUES ($1)"
// "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
const deleteJSONSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_federationsender_queue_json.json_nid) "
// "SELECT json_nid, json_body FROM federationsender_queue_json" +
// " WHERE json_nid IN ($1)"
const selectJSONSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_federationsender_queue_json.json_nid) "
type queueJSONStatements struct {
db *Database
// insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
tableName string
}
func queryQueueJSON(s *queueJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueJSONCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueueJSONCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteQueueJSON(s *queueJSONStatements, ctx context.Context, dbData QueueJSONCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBQueueJSONTable(db *Database) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{ s = &queueJSONStatements{
db: db, db: db,
} }
_, err = db.Exec(queueJSONSchema) s.tableName = "queue_jsons"
if err != nil {
return
}
if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil {
return
}
return return
} }
func (s *queueJSONStatements) InsertQueueJSON( func (s *queueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string, ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) { ) (lastid int64, err error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
res, err := stmt.ExecContext(ctx, json) // "INSERT INTO federationsender_queue_json (json_body)" +
// " VALUES ($1)"
// json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
idSeq, err := GetNextQueueJSONNID(s, ctx)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
docId := fmt.Sprintf("%d", idSeq)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
//Convert to byte
jsonData := []byte(json)
data := QueueJSONCosmos{
JSONNID: idSeq,
JSONBody: jsonData,
}
dbData := &QueueJSONCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
QueueJSON: data,
}
// stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
// res, err := stmt.ExecContext(ctx, json)
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
if err != nil { if err != nil {
return 0, fmt.Errorf("stmt.QueryContext: %w", err) return 0, fmt.Errorf("stmt.QueryContext: %w", err)
} }
lastid, err = res.LastInsertId() lastid = idSeq
if err != nil {
return 0, fmt.Errorf("res.LastInsertId: %w", err)
}
return return
} }
func (s *queueJSONStatements) DeleteQueueJSON( func (s *queueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64, ctx context.Context, txn *sql.Tx, nids []int64,
) error { ) error {
deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteStmt, err := txn.Prepare(deleteSQL) // "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": nids,
}
// deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
// deleteStmt, err := txn.Prepare(deleteSQL)
// stmt := sqlutil.TxStmt(txn, deleteStmt)
rows, err := queryQueueJSON(s, ctx, deleteJSONSQL, params)
if err != nil { if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) return err
} }
iNIDs := make([]interface{}, len(nids)) // iNIDs := make([]interface{}, len(nids))
for k, v := range nids { // for k, v := range nids {
iNIDs[k] = v // iNIDs[k] = v
} // }
stmt := sqlutil.TxStmt(txn, deleteStmt) for _, item := range rows {
_, err = stmt.ExecContext(ctx, iNIDs...) err = deleteQueueJSON(s, ctx, item)
}
return err return err
} }
func (s *queueJSONStatements) SelectQueueJSON( func (s *queueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) { ) (map[int64][]byte, error) {
selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectStmt, err := txn.Prepare(selectSQL) // "SELECT json_nid, json_body FROM federationsender_queue_json" +
// " WHERE json_nid IN ($1)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": jsonNIDs,
}
// selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
// selectStmt, err := txn.Prepare(selectSQL)
rows, err := queryQueueJSON(s, ctx, selectJSONSQL, params)
if err != nil { if err != nil {
return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
} }
iNIDs := make([]interface{}, len(jsonNIDs)) iNIDs := make([]interface{}, len(jsonNIDs))
@ -118,18 +226,11 @@ func (s *queueJSONStatements) SelectQueueJSON(
} }
blobs := map[int64][]byte{} blobs := map[int64][]byte{}
stmt := sqlutil.TxStmt(txn, selectStmt) for _, item := range rows {
rows, err := stmt.QueryContext(ctx, iNIDs...)
if err != nil {
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
for rows.Next() {
var nid int64 var nid int64
var blob []byte var blob []byte
if err = rows.Scan(&nid, &blob); err != nil { nid = item.QueueJSON.JSONNID
return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) blob = item.QueueJSON.JSONBody
}
blobs[nid] = blob blobs[nid] = blob
} }
return blobs, err return blobs, err

View file

@ -0,0 +1,12 @@
package cosmosdb
import (
"context"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
)
func GetNextQueueJSONNID(s *queueJSONStatements, ctx context.Context) (int64, error) {
const docId = "json_nid_seq"
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
}

View file

@ -19,96 +19,188 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const queuePDUsSchema = ` // const queuePDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_pdus ( // CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
-- The transaction ID that was generated before persisting the event. // -- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL, // transaction_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for. // -- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL, // server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_pdus_json table. // -- The JSON NID from the federationsender_queue_pdus_json table.
json_nid BIGINT NOT NULL // json_nid BIGINT NOT NULL
); // );
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx // CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid, server_name); // ON federationsender_queue_pdus (json_nid, server_name);
` // `
const insertQueuePDUSQL = "" + type QueuePDUCosmos struct {
"INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" + TransactionID string `json:"transaction_id"`
" VALUES ($1, $2, $3)" ServerName string `json:"server_name"`
JSONNID int64 `json:"json_nid"`
const deleteQueuePDUsSQL = "" +
"DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueNextTransactionIDSQL = "" +
"SELECT transaction_id FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" ORDER BY transaction_id ASC" +
" LIMIT 1"
const selectQueuePDUsSQL = "" +
"SELECT json_nid FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueuePDUsReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE json_nid = $1"
const selectQueuePDUsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE server_name = $1"
const selectQueuePDUsServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
type queuePDUsStatements struct {
db *sql.DB
insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsStmt *sql.Stmt
selectQueueReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt
selectQueueServerNamesStmt *sql.Stmt
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
} }
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { type QueuePDUCosmosNumber struct {
Number int64 `json:"number"`
}
type QueuePDUCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
QueuePDU QueuePDUCosmos `json:"mx_federationsender_queue_pdu"`
}
// const insertQueuePDUSQL = "" +
// "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
// " VALUES ($1, $2, $3)"
// "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"
const deleteQueuePDUsSQL = "" +
"select * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.server_name = @x2 " +
"and ARRAY_CONTAINS(@x3, c.mx_federationsender_queue_pdu.json_nid) "
// "SELECT transaction_id FROM federationsender_queue_pdus" +
// " WHERE server_name = $1" +
// " ORDER BY transaction_id ASC" +
// " LIMIT 1"
const selectQueueNextTransactionIDSQL = "" +
"select top 1 * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.server_name = @x2 " +
"order by c.mx_federationsender_queue_pdu.transaction_id asc "
// "SELECT json_nid FROM federationsender_queue_pdus" +
// " WHERE server_name = $1" +
// " LIMIT $2"
const selectQueuePDUsSQL = "" +
"select top @x3 * from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.server_name = @x2 "
// "SELECT COUNT(*) FROM federationsender_queue_pdus" +
// " WHERE json_nid = $1"
const selectQueuePDUsReferenceJSONCountSQL = "" +
"select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.json_nid = @x2 "
// "SELECT COUNT(*) FROM federationsender_queue_pdus" +
// " WHERE server_name = $1"
const selectQueuePDUsCountSQL = "" +
"select count(c._ts) as number from c where c._cn = @x1 " +
"and c.mx_federationsender_queue_pdu.server_name = @x2 "
// "SELECT DISTINCT server_name FROM federationsender_queue_pdus"
const selectQueuePDUsServerNamesSQL = "" +
"select distinct c.mx_federationsender_queue_pdu.server_name from c where c._cn = @x1 "
type queuePDUsStatements struct {
db *Database
// insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt string
selectQueuePDUsStmt string
selectQueueReferenceJSONCountStmt string
selectQueuePDUsCountStmt string
selectQueueServerNamesStmt string
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
tableName string
}
func queryQueuePDU(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueuePDUCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryQueuePDUDistinct(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmos, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueuePDUCosmos
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func queryQueuePDUNumber(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosNumber, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []QueuePDUCosmosNumber
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func deleteQueuePDU(s *queuePDUsStatements, ctx context.Context, dbData QueuePDUCosmosData) error {
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
dbData.Id,
options)
if err != nil {
return err
}
return err
}
func NewCosmosDBQueuePDUsTable(db *Database) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{ s = &queuePDUsStatements{
db: db, db: db,
} }
_, err = db.Exec(queuePDUsSchema) s.selectQueueNextTransactionIDStmt = selectQueueNextTransactionIDSQL
if err != nil { s.selectQueuePDUsStmt = selectQueuePDUsSQL
return s.selectQueueReferenceJSONCountStmt = selectQueuePDUsReferenceJSONCountSQL
} s.selectQueuePDUsCountStmt = selectQueuePDUsCountSQL
if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil { s.selectQueueServerNamesStmt = selectQueuePDUsServerNamesSQL
return s.tableName = "queue_pdus"
}
//if s.deleteQueuePDUsStmt, err = db.Prepare(deleteQueuePDUsSQL); err != nil {
// return
//}
if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil {
return
}
if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil {
return
}
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
return
}
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
return
}
return return
} }
@ -119,13 +211,47 @@ func (s *queuePDUsStatements) InsertQueuePDU(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nid int64, nid int64,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
_, err := stmt.ExecContext( // "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
// " VALUES ($1, $2, $3)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
// ON federationsender_queue_pdus (json_nid, server_name);
docId := fmt.Sprintf("%d_%s", nid, serverName)
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := QueuePDUCosmos{
JSONNID: nid,
ServerName: string(serverName),
TransactionID: string(transactionID),
}
dbData := &QueuePDUCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
QueuePDU: data,
}
// stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
// _, err := stmt.ExecContext(
// ctx,
// transactionID, // the transaction ID that we initially attempted
// serverName, // destination server name
// nid, // JSON blob NID
// )
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx, ctx,
transactionID, // the transaction ID that we initially attempted s.db.cosmosConfig.DatabaseName,
serverName, // destination server name s.db.cosmosConfig.ContainerName,
nid, // JSON blob NID &dbData,
) options)
return err return err
} }
@ -134,20 +260,31 @@ func (s *queuePDUsStatements) DeleteQueuePDUs(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
jsonNIDs []int64, jsonNIDs []int64,
) error { ) error {
deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL) // "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
"@x3": jsonNIDs,
}
// deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
// deleteStmt, err := txn.Prepare(deleteSQL)
rows, err := queryQueuePDU(s, ctx, deleteQueuePDUsSQL, params)
if err != nil { if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) return err
} }
params := make([]interface{}, len(jsonNIDs)+1) for _, item := range rows {
params[0] = serverName // stmt := sqlutil.TxStmt(txn, deleteStmt)
for k, v := range jsonNIDs { err = deleteQueuePDU(s, ctx, item)
params[k+1] = v if err != nil {
return err
}
} }
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, params...)
return err return err
} }
@ -155,11 +292,30 @@ func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (gomatrixserverlib.TransactionID, error) { ) (gomatrixserverlib.TransactionID, error) {
var transactionID gomatrixserverlib.TransactionID var transactionID gomatrixserverlib.TransactionID
stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) // "SELECT transaction_id FROM federationsender_queue_pdus" +
if err == sql.ErrNoRows { // " WHERE server_name = $1" +
// " ORDER BY transaction_id ASC" +
// " LIMIT 1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt)
rows, err := queryQueuePDU(s, ctx, s.selectQueueNextTransactionIDStmt, params)
if err != nil {
return "", err
}
if len(rows) == 0 {
return "", nil return "", nil
} }
// err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID)
transactionID = gomatrixserverlib.TransactionID(rows[0].QueuePDU.TransactionID)
return transactionID, err return transactionID, err
} }
@ -167,11 +323,28 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64, ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) { ) (int64, error) {
var count int64 var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) // "SELECT COUNT(*) FROM federationsender_queue_pdus" +
if err == sql.ErrNoRows { // " WHERE json_nid = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": jsonNID,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
rows, err := queryQueuePDUNumber(s, ctx, s.selectQueueReferenceJSONCountStmt, params)
if err != nil {
return -1, err
}
if len(rows) == 0 {
return -1, nil return -1, nil
} }
// err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
count = rows[0].Number
return count, err return count, err
} }
@ -179,14 +352,31 @@ func (s *queuePDUsStatements) SelectQueuePDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) { ) (int64, error) {
var count int64 var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count) // "SELECT COUNT(*) FROM federationsender_queue_pdus" +
if err == sql.ErrNoRows { // " WHERE server_name = $1"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
rows, err := queryQueuePDUNumber(s, ctx, s.selectQueuePDUsCountStmt, params)
if err != nil {
return 0, err
}
if len(rows) == 0 {
// It's acceptable for there to be no rows referencing a given // It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if // JSON NID but it's not an error condition. Just return as if
// there's a zero count. // there's a zero count.
return 0, nil return 0, nil
} }
// err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
count = rows[0].Number
return count, err return count, err
} }
@ -195,41 +385,58 @@ func (s *queuePDUsStatements) SelectQueuePDUs(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) ([]int64, error) { ) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit) // "SELECT json_nid FROM federationsender_queue_pdus" +
// " WHERE server_name = $1" +
// " LIMIT $2"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": serverName,
"@x3": limit,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
// rows, err := stmt.QueryContext(ctx, serverName, limit)
rows, err := queryQueuePDU(s, ctx, s.selectQueuePDUsStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64 var result []int64
for rows.Next() { for _, item := range rows {
var nid int64 var nid int64
if err = rows.Scan(&nid); err != nil { nid = item.QueuePDU.JSONNID
return nil, err
}
result = append(result, nid) result = append(result, nid)
} }
return result, rows.Err() return result, nil
} }
func (s *queuePDUsStatements) SelectQueuePDUServerNames( func (s *queuePDUsStatements) SelectQueuePDUServerNames(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) { ) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt)
rows, err := stmt.QueryContext(ctx) // "SELECT DISTINCT server_name FROM federationsender_queue_pdus"
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
}
// stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt)
// rows, err := stmt.QueryContext(ctx)
rows, err := queryQueuePDUDistinct(s, ctx, s.selectQueueServerNamesStmt, params)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName var result []gomatrixserverlib.ServerName
for rows.Next() { for _, item := range rows {
var serverName gomatrixserverlib.ServerName var serverName gomatrixserverlib.ServerName
if err = rows.Scan(&serverName); err != nil { serverName = gomatrixserverlib.ServerName(item.ServerName)
return nil, err
}
result = append(result, serverName) result = append(result, serverName)
} }
return result, rows.Err() return result, nil
} }

View file

@ -16,68 +16,74 @@
package cosmosdb package cosmosdb
import ( import (
"database/sql" "github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/matrix-org/dendrite/federationsender/storage/shared" "github.com/matrix-org/dendrite/federationsender/storage/shared"
"github.com/matrix-org/dendrite/federationsender/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
// Database stores information needed by the federation sender // Database stores information needed by the federation sender
type Database struct { type Database struct {
shared.Database shared.Database
sqlutil.PartitionOffsetStatements cosmosdbutil.PartitionOffsetStatements
db *sql.DB database cosmosdbutil.Database
writer sqlutil.Writer writer cosmosdbutil.Writer
connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
} }
// NewDatabase opens a new database // NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) {
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
var d Database var d Database
d.connection = conn
d.cosmosConfig = configCosmos
d.databaseName = "federationsender"
d.database = cosmosdbutil.Database{
Connection: conn,
CosmosConfig: configCosmos,
DatabaseName: d.databaseName,
}
var err error var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err d.writer = cosmosdbutil.NewExclusiveWriterFake()
} joinedHosts, err := NewCosmosDBJoinedHostsTable(&d)
d.writer = sqlutil.NewExclusiveWriter()
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
queuePDUs, err := NewSQLiteQueuePDUsTable(d.db) queuePDUs, err := NewCosmosDBQueuePDUsTable(&d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
queueEDUs, err := NewSQLiteQueueEDUsTable(d.db) queueEDUs, err := NewCosmosDBQueueEDUsTable(&d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
queueJSON, err := NewSQLiteQueueJSONTable(d.db) queueJSON, err := NewCosmosDBQueueJSONTable(&d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
blacklist, err := NewSQLiteBlacklistTable(d.db) blacklist, err := NewCosmosDBBlacklistTable(&d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) outboundPeeks, err := NewCosmosDBOutboundPeeksTable(&d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inboundPeeks, err := NewSQLiteInboundPeeksTable(d.db) inboundPeeks, err := NewCosmosDBInboundPeeksTable(&d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadRemoveRoomsTable(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: nil,
Cache: cache, Cache: cache,
Writer: d.writer, Writer: d.writer,
FederationSenderJoinedHosts: joinedHosts, FederationSenderJoinedHosts: joinedHosts,
@ -88,7 +94,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationS
FederationSenderOutboundPeeks: outboundPeeks, FederationSenderOutboundPeeks: outboundPeeks,
FederationSenderInboundPeeks: inboundPeeks, FederationSenderInboundPeeks: inboundPeeks,
} }
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil { if err = d.PartitionOffsetStatements.Prepare(&d.database, d.writer, "federationsender"); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil