From db08aa625059a24586f74517678100409834aaca Mon Sep 17 00:00:00 2001 From: alexfca <75228224+alexfca@users.noreply.github.com> Date: Fri, 28 May 2021 15:00:15 +1000 Subject: [PATCH] - Update the FederationSender Config to use CosmosDB (#9) - Implement the tables to use Cosmos - Update the Storage to use Cosmos --- dendrite-config-cosmosdb.yaml | 2 +- .../storage/cosmosdb/blacklist_table.go | 179 +++++-- .../storage/cosmosdb/inbound_peeks_table.go | 373 +++++++++++---- .../storage/cosmosdb/joined_hosts_table.go | 384 ++++++++++----- .../storage/cosmosdb/outbound_peeks_table.go | 380 +++++++++++---- .../storage/cosmosdb/queue_edus_table.go | 383 +++++++++++---- .../storage/cosmosdb/queue_json_table.go | 233 ++++++--- .../cosmosdb/queue_json_table_json_nid_seq.go | 12 + .../storage/cosmosdb/queue_pdus_table.go | 447 +++++++++++++----- federationsender/storage/cosmosdb/storage.go | 54 ++- 10 files changed, 1782 insertions(+), 665 deletions(-) create mode 100644 federationsender/storage/cosmosdb/queue_json_table_json_nid_seq.go diff --git a/dendrite-config-cosmosdb.yaml b/dendrite-config-cosmosdb.yaml index c25f2ee9d..014a2ec88 100644 --- a/dendrite-config-cosmosdb.yaml +++ b/dendrite-config-cosmosdb.yaml @@ -202,7 +202,7 @@ federation_sender: listen: http://localhost:7775 connect: http://localhost:7775 database: - connection_string: file:federationsender.db + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 diff --git a/federationsender/storage/cosmosdb/blacklist_table.go b/federationsender/storage/cosmosdb/blacklist_table.go index f4488a8e8..71c1475c3 100644 --- a/federationsender/storage/cosmosdb/blacklist_table.go +++ b/federationsender/storage/cosmosdb/blacklist_table.go @@ -17,54 +17,90 @@ package cosmosdb import ( "context" "database/sql" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) -const blacklistSchema = ` -CREATE TABLE IF NOT EXISTS federationsender_blacklist ( - -- The blacklisted server name - server_name TEXT NOT NULL, - UNIQUE (server_name) -); -` +// const blacklistSchema = ` +// CREATE TABLE IF NOT EXISTS federationsender_blacklist ( +// -- The blacklisted server name +// server_name TEXT NOT NULL, +// UNIQUE (server_name) +// ); +// ` -const insertBlacklistSQL = "" + - "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" + - " ON CONFLICT DO NOTHING" - -const selectBlacklistSQL = "" + - "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1" - -const deleteBlacklistSQL = "" + - "DELETE FROM federationsender_blacklist WHERE server_name = $1" - -type blacklistStatements struct { - db *sql.DB - insertBlacklistStmt *sql.Stmt - selectBlacklistStmt *sql.Stmt - deleteBlacklistStmt *sql.Stmt +type BlacklistCosmos struct { + ServerName string `json:"server_name"` } -func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { +type BlacklistCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Blacklist BlacklistCosmos `json:"mx_federationsender_blacklist"` +} + +// const insertBlacklistSQL = "" + +// "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" + +// " ON CONFLICT DO NOTHING" + +// const selectBlacklistSQL = "" + +// "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1" + +// const deleteBlacklistSQL = "" + +// "DELETE FROM federationsender_blacklist WHERE server_name = $1" + +type blacklistStatements struct { + db *Database + // insertBlacklistStmt *sql.Stmt + // selectBlacklistStmt *sql.Stmt + // deleteBlacklistStmt *sql.Stmt + tableName string +} + +func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*BlacklistCosmosData, error) { + response := BlacklistCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, nil + } + + return &response, err +} + +func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData BlacklistCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewCosmosDBBlacklistTable(db *Database) (s *blacklistStatements, err error) { s = &blacklistStatements{ db: db, } - _, err = db.Exec(blacklistSchema) - if err != nil { - return - } - - if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil { - return - } - if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil { - return - } - if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil { - return - } + s.tableName = "blacklists" return } @@ -73,8 +109,40 @@ func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { func (s *blacklistStatements) InsertBlacklist( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) error { - stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) - _, err := stmt.ExecContext(ctx, serverName) + + // "INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" + + // " ON CONFLICT DO NOTHING" + + // stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (server_name) + docId := fmt.Sprintf("%s", serverName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := BlacklistCosmos{ + ServerName: string(serverName), + } + + dbData := &BlacklistCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Blacklist: data, + } + + // _, err := stmt.ExecContext(ctx, serverName) + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } @@ -84,16 +152,24 @@ func (s *blacklistStatements) InsertBlacklist( func (s *blacklistStatements) SelectBlacklist( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (bool, error) { - stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) - res, err := stmt.QueryContext(ctx, serverName) + // "SELECT server_name FROM federationsender_blacklist WHERE server_name = $1" + + // stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (server_name) + docId := fmt.Sprintf("%s", serverName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + // res, err := stmt.QueryContext(ctx, serverName) + res, err := getBlacklist(s, ctx, pk, cosmosDocId) if err != nil { return false, err } - defer res.Close() // nolint:errcheck // The query will return the server name if the server is blacklisted, and // will return no rows if not. By calling Next, we find out if a row was // returned or not - we don't care about the value itself. - return res.Next(), nil + return res != nil, nil } // updateRoom updates the last_event_id for the room. selectRoomForUpdate should @@ -101,7 +177,18 @@ func (s *blacklistStatements) SelectBlacklist( func (s *blacklistStatements) DeleteBlacklist( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) - _, err := stmt.ExecContext(ctx, serverName) + // "DELETE FROM federationsender_blacklist WHERE server_name = $1" + + // stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (server_name) + docId := fmt.Sprintf("%s", serverName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + // _, err := stmt.ExecContext(ctx, serverName) + res, err := getBlacklist(s, ctx, pk, cosmosDocId) + if(res != nil) { + _ = deleteBlacklist(s, ctx, *res) + } return err } diff --git a/federationsender/storage/cosmosdb/inbound_peeks_table.go b/federationsender/storage/cosmosdb/inbound_peeks_table.go index 88d9b4a86..644e4ef15 100644 --- a/federationsender/storage/cosmosdb/inbound_peeks_table.go +++ b/federationsender/storage/cosmosdb/inbound_peeks_table.go @@ -17,90 +17,198 @@ package cosmosdb import ( "context" "database/sql" + "fmt" "time" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/federationsender/types" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) -const inboundPeeksSchema = ` -CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks ( - room_id TEXT NOT NULL, - server_name TEXT NOT NULL, - peek_id TEXT NOT NULL, - creation_ts INTEGER NOT NULL, - renewed_ts INTEGER NOT NULL, - renewal_interval INTEGER NOT NULL, - UNIQUE (room_id, server_name, peek_id) -); -` +// const inboundPeeksSchema = ` +// CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks ( +// room_id TEXT NOT NULL, +// server_name TEXT NOT NULL, +// peek_id TEXT NOT NULL, +// creation_ts INTEGER NOT NULL, +// renewed_ts INTEGER NOT NULL, +// renewal_interval INTEGER NOT NULL, +// UNIQUE (room_id, server_name, peek_id) +// ); +// ` -const insertInboundPeekSQL = "" + - "INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" - -const selectInboundPeekSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" - -const selectInboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" - -const renewInboundPeekSQL = "" + - "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" - -const deleteInboundPeekSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" - -const deleteInboundPeeksSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" - -type inboundPeeksStatements struct { - db *sql.DB - insertInboundPeekStmt *sql.Stmt - selectInboundPeekStmt *sql.Stmt - selectInboundPeeksStmt *sql.Stmt - renewInboundPeekStmt *sql.Stmt - deleteInboundPeekStmt *sql.Stmt - deleteInboundPeeksStmt *sql.Stmt +type InboundPeekCosmos struct { + RoomID string `json:"room_id"` + ServerName string `json:"server_name"` + PeekID string `json:"peek_id"` + CreationTimestamp int64 `json:"creation_ts"` + RenewedTimestamp int64 `json:"renewed_ts"` + RenewalInterval int64 `json:"renewal_interval"` } -func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err error) { +type InboundPeekCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + InboundPeek InboundPeekCosmos `json:"mx_federationsender_inbound_peek"` +} + +// const insertInboundPeekSQL = "" + +// "INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" + +// const selectInboundPeekSQL = "" + +// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" + +// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" +const selectInboundPeeksSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_inbound_peek.room_id = @x2" + +// const renewInboundPeekSQL = "" + +// "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" + +// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" +const deleteInboundPeekSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_inbound_peek.room_id = @x2" + + "and c.mx_federationsender_inbound_peek.server_name = @x3" + +// "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" +const deleteInboundPeeksSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_inbound_peek.room_id = @x2" + +type inboundPeeksStatements struct { + db *Database + // insertInboundPeekStmt *sql.Stmt + // selectInboundPeekStmt *sql.Stmt + selectInboundPeeksStmt string + // renewInboundPeekStmt string + deleteInboundPeekStmt string + deleteInboundPeeksStmt string + tableName string +} + +func queryInboundPeek(s *inboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InboundPeekCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []InboundPeekCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func getInboundPeek(s *inboundPeeksStatements, ctx context.Context, pk string, docId string) (*InboundPeekCosmosData, error) { + response := InboundPeekCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, nil + } + + return &response, err +} + +func setInboundPeek(s *inboundPeeksStatements, ctx context.Context, inboundPeek InboundPeekCosmosData) (*InboundPeekCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(inboundPeek.Pk, inboundPeek.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + inboundPeek.Id, + &inboundPeek, + optionsReplace) + return &inboundPeek, ex +} + +func deleteInboundPeek(s *inboundPeeksStatements, ctx context.Context, dbData InboundPeekCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewCosmosDBInboundPeeksTable(db *Database) (s *inboundPeeksStatements, err error) { s = &inboundPeeksStatements{ db: db, } - _, err = db.Exec(inboundPeeksSchema) - if err != nil { - return - } - - if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { - return - } - if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { - return - } - if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { - return - } - if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { - return - } + s.selectInboundPeeksStmt = selectInboundPeeksSQL + s.deleteInboundPeeksStmt = deleteInboundPeeksSQL + s.deleteInboundPeekStmt = deleteInboundPeekSQL + s.tableName = "inbound_peeks" return } func (s *inboundPeeksStatements) InsertInboundPeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { + + // "INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt) - _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval) + // stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (room_id, server_name, peek_id) + docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := InboundPeekCosmos{ + RoomID: roomID, + ServerName: string(serverName), + PeekID: peekID, + CreationTimestamp: nowMilli, + RenewedTimestamp: nowMilli, + RenewalInterval: renewalInterval, + } + + dbData := &InboundPeekCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + InboundPeek: data, + } + + // _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval) + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return } @@ -108,26 +216,58 @@ func (s *inboundPeeksStatements) RenewInboundPeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) + // "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" + + // _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (room_id, server_name, peek_id) + docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + // _, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) + res, err := getInboundPeek(s, ctx, pk, cosmosDocId) + + if err != nil { + return + } + + if res == nil { + return + } + + res.InboundPeek.RenewedTimestamp = nowMilli + res.InboundPeek.RenewalInterval = renewalInterval + + _, err = setInboundPeek(s, ctx, *res) + return } func (s *inboundPeeksStatements) SelectInboundPeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, ) (*types.InboundPeek, error) { - row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) - inboundPeek := types.InboundPeek{} - err := row.Scan( - &inboundPeek.RoomID, - &inboundPeek.ServerName, - &inboundPeek.PeekID, - &inboundPeek.CreationTimestamp, - &inboundPeek.RenewedTimestamp, - &inboundPeek.RenewalInterval, - ) - if err == sql.ErrNoRows { + + // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (room_id, server_name, peek_id) + docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + // row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID) + row, err := getInboundPeek(s, ctx, pk, cosmosDocId) + + if row == nil { return nil, nil } + inboundPeek := types.InboundPeek{} + inboundPeek.RoomID = row.InboundPeek.RoomID + inboundPeek.ServerName = gomatrixserverlib.ServerName(row.InboundPeek.ServerName) + inboundPeek.PeekID = row.InboundPeek.PeekID + inboundPeek.CreationTimestamp = row.InboundPeek.CreationTimestamp + inboundPeek.RenewedTimestamp = row.InboundPeek.RenewedTimestamp + inboundPeek.RenewalInterval = row.InboundPeek.RenewalInterval if err != nil { return nil, err } @@ -137,40 +277,87 @@ func (s *inboundPeeksStatements) SelectInboundPeek( func (s *inboundPeeksStatements) SelectInboundPeeks( ctx context.Context, txn *sql.Tx, roomID string, ) (inboundPeeks []types.InboundPeek, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID) + // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + // rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID) + rows, err := queryInboundPeek(s, ctx, s.selectInboundPeeksStmt, params) + if err != nil { return } - defer internal.CloseAndLogIfError(ctx, rows, "SelectInboundPeeks: rows.close() failed") - for rows.Next() { + for _, item := range rows { inboundPeek := types.InboundPeek{} - if err = rows.Scan( - &inboundPeek.RoomID, - &inboundPeek.ServerName, - &inboundPeek.PeekID, - &inboundPeek.CreationTimestamp, - &inboundPeek.RenewedTimestamp, - &inboundPeek.RenewalInterval, - ); err != nil { - return - } + inboundPeek.RoomID = item.InboundPeek.RoomID + inboundPeek.ServerName = gomatrixserverlib.ServerName(item.InboundPeek.ServerName) + inboundPeek.PeekID = item.InboundPeek.PeekID + inboundPeek.CreationTimestamp = item.InboundPeek.CreationTimestamp + inboundPeek.RenewedTimestamp = item.InboundPeek.RenewedTimestamp + inboundPeek.RenewalInterval = item.InboundPeek.RenewalInterval inboundPeeks = append(inboundPeeks, inboundPeek) } - return inboundPeeks, rows.Err() + return inboundPeeks, nil } func (s *inboundPeeksStatements) DeleteInboundPeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) + // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": serverName, + } + + // _, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) + rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params) + + if err != nil { + return + } + + for _, item := range rows { + err = deleteInboundPeek(s, ctx, item) + if err != nil { + return + } + } + return } func (s *inboundPeeksStatements) DeleteInboundPeeks( ctx context.Context, txn *sql.Tx, roomID string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID) + // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + // _, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID) + rows, err := queryInboundPeek(s, ctx, s.deleteInboundPeekStmt, params) + + if err != nil { + return + } + + for _, item := range rows { + err = deleteInboundPeek(s, ctx, item) + if err != nil { + return + } + } return } diff --git a/federationsender/storage/cosmosdb/joined_hosts_table.go b/federationsender/storage/cosmosdb/joined_hosts_table.go index b903d1b7b..10315f30e 100644 --- a/federationsender/storage/cosmosdb/joined_hosts_table.go +++ b/federationsender/storage/cosmosdb/joined_hosts_table.go @@ -18,87 +18,155 @@ package cosmosdb import ( "context" "database/sql" - "strings" + "fmt" + "time" "github.com/matrix-org/dendrite/federationsender/types" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/gomatrixserverlib" ) -const joinedHostsSchema = ` --- The joined_hosts table stores a list of m.room.member event ids in the --- current state for each room where the membership is "join". --- There will be an entry for every user that is joined to the room. -CREATE TABLE IF NOT EXISTS federationsender_joined_hosts ( - -- The string ID of the room. - room_id TEXT NOT NULL, - -- The event ID of the m.room.member join event. - event_id TEXT NOT NULL, - -- The domain part of the user ID the m.room.member event is for. - server_name TEXT NOT NULL -); +// const joinedHostsSchema = ` +// -- The joined_hosts table stores a list of m.room.member event ids in the +// -- current state for each room where the membership is "join". +// -- There will be an entry for every user that is joined to the room. +// CREATE TABLE IF NOT EXISTS federationsender_joined_hosts ( +// -- The string ID of the room. +// room_id TEXT NOT NULL, +// -- The event ID of the m.room.member join event. +// event_id TEXT NOT NULL, +// -- The domain part of the user ID the m.room.member event is for. +// server_name TEXT NOT NULL +// ); -CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx - ON federationsender_joined_hosts (event_id); +// CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx +// ON federationsender_joined_hosts (event_id); -CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx - ON federationsender_joined_hosts (room_id) -` +// CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx +// ON federationsender_joined_hosts (room_id) +// ` -const insertJoinedHostsSQL = "" + - "INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" + - " VALUES ($1, $2, $3)" - -const deleteJoinedHostsSQL = "" + - "DELETE FROM federationsender_joined_hosts WHERE event_id = $1" - -const deleteJoinedHostsForRoomSQL = "" + - "DELETE FROM federationsender_joined_hosts WHERE room_id = $1" - -const selectJoinedHostsSQL = "" + - "SELECT event_id, server_name FROM federationsender_joined_hosts" + - " WHERE room_id = $1" - -const selectAllJoinedHostsSQL = "" + - "SELECT DISTINCT server_name FROM federationsender_joined_hosts" - -const selectJoinedHostsForRoomsSQL = "" + - "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" - -type joinedHostsStatements struct { - db *sql.DB - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - deleteJoinedHostsForRoomStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt - // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic +type JoinedHostCosmos struct { + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + ServerName string `json:"server_name"` } -func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { +type JoinedHostCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + JoinedHost JoinedHostCosmos `json:"mx_federationsender_joined_host"` +} + +// const insertJoinedHostsSQL = "" + +// "INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" + +// " VALUES ($1, $2, $3)" + +// "DELETE FROM federationsender_joined_hosts WHERE event_id = $1" +const deleteJoinedHostsSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_joined_host.event_id = @x2 " + +// "DELETE FROM federationsender_joined_hosts WHERE room_id = $1" +const deleteJoinedHostsForRoomSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_joined_host.room_id = @x2 " + +// "SELECT event_id, server_name FROM federationsender_joined_hosts" + +// " WHERE room_id = $1" +const selectJoinedHostsSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_joined_host.room_id = @x2 " + +// "SELECT DISTINCT server_name FROM federationsender_joined_hosts" +const selectAllJoinedHostsSQL = "" + + "select distinct c.mx_federationsender_joined_host.server_name from c where c._cn = @x1 " + +// "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" +const selectJoinedHostsForRoomsSQL = "" + + "select distinct c.mx_federationsender_joined_host.server_name from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_federationsender_joined_host.room_id) " + +type joinedHostsStatements struct { + db *Database + // insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt string + deleteJoinedHostsForRoomStmt string + selectJoinedHostsStmt string + selectAllJoinedHostsStmt string + // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic + tableName string +} + +func queryJoinedHostDistinct(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmos, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []JoinedHostCosmos + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryJoinedHost(s *joinedHostsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]JoinedHostCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []JoinedHostCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func deleteJoinedHost(s *joinedHostsStatements, ctx context.Context, dbData JoinedHostCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewCosmosDBJoinedHostsTable(db *Database) (s *joinedHostsStatements, err error) { s = &joinedHostsStatements{ db: db, } - _, err = db.Exec(joinedHostsSchema) - if err != nil { - return - } - if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil { - return - } - if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil { - return - } - if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil { - return - } - if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { - return - } - if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { - return - } + s.deleteJoinedHostsStmt = deleteJoinedHostsSQL + s.deleteJoinedHostsForRoomStmt = deleteJoinedHostsForRoomSQL + s.selectJoinedHostsStmt = selectJoinedHostsSQL + s.selectAllJoinedHostsStmt = selectAllJoinedHostsSQL + s.tableName = "joined_hosts" return } @@ -108,8 +176,43 @@ func (s *joinedHostsStatements) InsertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) - _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) + + // "INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" + + // " VALUES ($1, $2, $3)" + + // stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx + // ON federationsender_joined_hosts (event_id); + docId := fmt.Sprintf("%s", eventID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := JoinedHostCosmos{ + EventID: eventID, + RoomID: roomID, + ServerName: string(serverName), + } + + dbData := &JoinedHostCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + JoinedHost: data, + } + + // _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } @@ -117,9 +220,21 @@ func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { for _, eventID := range eventIDs { - stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) - if _, err := stmt.ExecContext(ctx, eventID); err != nil { - return err + // "DELETE FROM federationsender_joined_hosts WHERE event_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventID, + } + // stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) + + rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params) + + for _, item := range rows { + if err = deleteJoinedHost(s, ctx, item); err != nil { + return err + } } } return nil @@ -128,92 +243,123 @@ func (s *joinedHostsStatements) DeleteJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) - _, err := stmt.ExecContext(ctx, roomID) + // "DELETE FROM federationsender_joined_hosts WHERE room_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + // stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) + rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params) + + // _, err := stmt.ExecContext(ctx, roomID) + for _, item := range rows { + if err = deleteJoinedHost(s, ctx, item); err != nil { + return err + } + } return err } func (s *joinedHostsStatements) SelectJoinedHostsWithTx( ctx context.Context, txn *sql.Tx, roomID string, ) ([]types.JoinedHost, error) { - stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) - return joinedHostsFromStmt(ctx, stmt, roomID) + // "SELECT event_id, server_name FROM federationsender_joined_hosts" + + // " WHERE room_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + // stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt) + rows, err := queryJoinedHost(s, ctx, s.deleteJoinedHostsStmt, params) + + if err != nil { + return nil, err + } + + return rowsToJoinedHosts(&rows), nil } func (s *joinedHostsStatements) SelectJoinedHosts( ctx context.Context, roomID string, ) ([]types.JoinedHost, error) { - return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) + return s.SelectJoinedHostsWithTx(ctx, nil, roomID) } func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, ) ([]gomatrixserverlib.ServerName, error) { - rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) + // "SELECT DISTINCT server_name FROM federationsender_joined_hosts" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + // rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) + rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") var result []gomatrixserverlib.ServerName - for rows.Next() { + for _, item := range rows { var serverName string - if err = rows.Scan(&serverName); err != nil { - return nil, err - } + serverName = item.ServerName result = append(result, gomatrixserverlib.ServerName(serverName)) } - return result, rows.Err() + return result, err } func (s *joinedHostsStatements) SelectJoinedHostsForRooms( ctx context.Context, roomIDs []string, ) ([]gomatrixserverlib.ServerName, error) { - iRoomIDs := make([]interface{}, len(roomIDs)) - for i := range roomIDs { - iRoomIDs[i] = roomIDs[i] + // iRoomIDs := make([]interface{}, len(roomIDs)) + // for i := range roomIDs { + // iRoomIDs[i] = roomIDs[i] + // } + + // "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" + + // sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomIDs, } - sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) - rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) + // rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) + rows, err := queryJoinedHostDistinct(s, ctx, s.selectAllJoinedHostsStmt, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed") var result []gomatrixserverlib.ServerName - for rows.Next() { + for _, item := range rows { var serverName string - if err = rows.Scan(&serverName); err != nil { - return nil, err - } + serverName = item.ServerName result = append(result, gomatrixserverlib.ServerName(serverName)) } - return result, rows.Err() -} - -func joinedHostsFromStmt( - ctx context.Context, stmt *sql.Stmt, roomID string, -) ([]types.JoinedHost, error) { - rows, err := stmt.QueryContext(ctx, roomID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed") - - var result []types.JoinedHost - for rows.Next() { - var eventID, serverName string - if err = rows.Scan(&eventID, &serverName); err != nil { - return nil, err - } - result = append(result, types.JoinedHost{ - MemberEventID: eventID, - ServerName: gomatrixserverlib.ServerName(serverName), - }) - } - return result, nil } + +func rowsToJoinedHosts(rows *[]JoinedHostCosmosData) []types.JoinedHost { + var result []types.JoinedHost + if rows == nil { + return result + } + for _, item := range *rows { + result = append(result, types.JoinedHost{ + MemberEventID: item.JoinedHost.EventID, + ServerName: gomatrixserverlib.ServerName(item.JoinedHost.ServerName), + }) + } + return result +} diff --git a/federationsender/storage/cosmosdb/outbound_peeks_table.go b/federationsender/storage/cosmosdb/outbound_peeks_table.go index 0da9344d2..61a5d4c29 100644 --- a/federationsender/storage/cosmosdb/outbound_peeks_table.go +++ b/federationsender/storage/cosmosdb/outbound_peeks_table.go @@ -17,160 +17,350 @@ package cosmosdb import ( "context" "database/sql" + "fmt" "time" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/federationsender/types" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) -const outboundPeeksSchema = ` -CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks ( - room_id TEXT NOT NULL, - server_name TEXT NOT NULL, - peek_id TEXT NOT NULL, - creation_ts INTEGER NOT NULL, - renewed_ts INTEGER NOT NULL, - renewal_interval INTEGER NOT NULL, - UNIQUE (room_id, server_name, peek_id) -); -` +// const outboundPeeksSchema = ` +// CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks ( +// room_id TEXT NOT NULL, +// server_name TEXT NOT NULL, +// peek_id TEXT NOT NULL, +// creation_ts INTEGER NOT NULL, +// renewed_ts INTEGER NOT NULL, +// renewal_interval INTEGER NOT NULL, +// UNIQUE (room_id, server_name, peek_id) +// ); +// ` -const insertOutboundPeekSQL = "" + - "INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" - -const selectOutboundPeekSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" - -const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" - -const renewOutboundPeekSQL = "" + - "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" - -const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" - -const deleteOutboundPeeksSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" - -type outboundPeeksStatements struct { - db *sql.DB - insertOutboundPeekStmt *sql.Stmt - selectOutboundPeekStmt *sql.Stmt - selectOutboundPeeksStmt *sql.Stmt - renewOutboundPeekStmt *sql.Stmt - deleteOutboundPeekStmt *sql.Stmt - deleteOutboundPeeksStmt *sql.Stmt +type OutboundPeekCosmos struct { + RoomID string `json:"room_id"` + ServerName string `json:"server_name"` + PeekID string `json:"peek_id"` + CreationTimestamp int64 `json:"creation_ts"` + RenewedTimestamp int64 `json:"renewed_ts"` + RenewalInterval int64 `json:"renewal_interval"` } -func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err error) { +type OutboundPeekCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + OutboundPeek OutboundPeekCosmos `json:"mx_federationsender_outbound_peek"` +} + +// const insertOutboundPeekSQL = "" + +// "INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" + +// "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" +const selectOutboundPeeksSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_outbound_peek.room_id = @x2" + +// const renewOutboundPeekSQL = "" + +// "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" + +// "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" +const deleteOutboundPeekSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_outbound_peek.room_id = @x2" + + "and c.mx_federationsender_outbound_peek.server_name = @x3" + +// "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" +const deleteOutboundPeeksSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_outbound_peek.room_id = @x2" + +type outboundPeeksStatements struct { + db *Database + // insertOutboundPeekStmt *sql.Stmt + // selectOutboundPeekStmt *sql.Stmt + selectOutboundPeeksStmt string + // renewOutboundPeekStmt *sql.Stmt + deleteOutboundPeekStmt string + deleteOutboundPeeksStmt string + tableName string +} + +func queryOutboundPeek(s *outboundPeeksStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutboundPeekCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []OutboundPeekCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func getOutboundPeek(s *outboundPeeksStatements, ctx context.Context, pk string, docId string) (*OutboundPeekCosmosData, error) { + response := OutboundPeekCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, nil + } + + return &response, err +} + +func setOutboundPeek(s *outboundPeeksStatements, ctx context.Context, outboundPeek OutboundPeekCosmosData) (*OutboundPeekCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outboundPeek.Pk, outboundPeek.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + outboundPeek.Id, + &outboundPeek, + optionsReplace) + return &outboundPeek, ex +} + +func deleteOutboundPeek(s *outboundPeeksStatements, ctx context.Context, dbData OutboundPeekCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewCosmosDBOutboundPeeksTable(db *Database) (s *outboundPeeksStatements, err error) { s = &outboundPeeksStatements{ db: db, } - _, err = db.Exec(outboundPeeksSchema) - if err != nil { - return - } - - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } + s.selectOutboundPeeksStmt = selectOutboundPeeksSQL + s.deleteOutboundPeeksStmt = deleteOutboundPeeksSQL + s.deleteOutboundPeekStmt = deleteOutboundPeekSQL + s.tableName = "outbound_peeks" return } func (s *outboundPeeksStatements) InsertOutboundPeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { + // "INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)" + + // stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt) nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt) - _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (room_id, server_name, peek_id) + docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := OutboundPeekCosmos{ + RoomID: roomID, + ServerName: string(serverName), + PeekID: peekID, + CreationTimestamp: nowMilli, + RenewedTimestamp: nowMilli, + RenewalInterval: renewalInterval, + } + + dbData := &OutboundPeekCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + OutboundPeek: data, + } + + // _, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval) + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return } func (s *outboundPeeksStatements) RenewOutboundPeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64, ) (err error) { + // "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" + nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (room_id, server_name, peek_id) + docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + // _, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID) + res, err := getOutboundPeek(s, ctx, pk, cosmosDocId) + + if err != nil { + return + } + + if res == nil { + return + } + + res.OutboundPeek.RenewedTimestamp = nowMilli + res.OutboundPeek.RenewalInterval = renewalInterval + + _, err = setOutboundPeek(s, ctx, *res) return } func (s *outboundPeeksStatements) SelectOutboundPeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, ) (*types.OutboundPeek, error) { - row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) - outboundPeek := types.OutboundPeek{} - err := row.Scan( - &outboundPeek.RoomID, - &outboundPeek.ServerName, - &outboundPeek.PeekID, - &outboundPeek.CreationTimestamp, - &outboundPeek.RenewedTimestamp, - &outboundPeek.RenewalInterval, - ) - if err == sql.ErrNoRows { - return nil, nil - } + + // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (room_id, server_name, peek_id) + docId := fmt.Sprintf("%s_%s_%s", roomID, serverName, peekID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + // row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID) + row, err := getOutboundPeek(s, ctx, pk, cosmosDocId) + if err != nil { return nil, err } + + if row == nil { + return nil, nil + } + outboundPeek := types.OutboundPeek{} + outboundPeek.RoomID = row.OutboundPeek.RoomID + outboundPeek.ServerName = gomatrixserverlib.ServerName(row.OutboundPeek.ServerName) + outboundPeek.PeekID = row.OutboundPeek.PeekID + outboundPeek.CreationTimestamp = row.OutboundPeek.CreationTimestamp + outboundPeek.RenewedTimestamp = row.OutboundPeek.RenewedTimestamp + outboundPeek.RenewalInterval = row.OutboundPeek.RenewalInterval return &outboundPeek, nil } func (s *outboundPeeksStatements) SelectOutboundPeeks( ctx context.Context, txn *sql.Tx, roomID string, ) (outboundPeeks []types.OutboundPeek, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID) + + // "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + if err != nil { return } - defer internal.CloseAndLogIfError(ctx, rows, "SelectOutboundPeeks: rows.close() failed") + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } - for rows.Next() { + // rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID) + rows, err := queryOutboundPeek(s, ctx, s.selectOutboundPeeksStmt, params) + + if err != nil { + return + } + + for _, item := range rows { outboundPeek := types.OutboundPeek{} - if err = rows.Scan( - &outboundPeek.RoomID, - &outboundPeek.ServerName, - &outboundPeek.PeekID, - &outboundPeek.CreationTimestamp, - &outboundPeek.RenewedTimestamp, - &outboundPeek.RenewalInterval, - ); err != nil { - return - } + outboundPeek.RoomID = item.OutboundPeek.RoomID + outboundPeek.ServerName = gomatrixserverlib.ServerName(item.OutboundPeek.ServerName) + outboundPeek.PeekID = item.OutboundPeek.PeekID + outboundPeek.CreationTimestamp = item.OutboundPeek.CreationTimestamp + outboundPeek.RenewedTimestamp = item.OutboundPeek.RenewedTimestamp + outboundPeek.RenewalInterval = item.OutboundPeek.RenewalInterval outboundPeeks = append(outboundPeeks, outboundPeek) } - return outboundPeeks, rows.Err() + return outboundPeeks, nil } func (s *outboundPeeksStatements) DeleteOutboundPeek( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) + + // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": serverName, + } + + // _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID) + rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeekStmt, params) + + if err != nil { + return + } + + for _, item := range rows { + err = deleteOutboundPeek(s, ctx, item) + if err != nil { + return + } + } + return } func (s *outboundPeeksStatements) DeleteOutboundPeeks( ctx context.Context, txn *sql.Tx, roomID string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID) + + // "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + // _, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID) + rows, err := queryOutboundPeek(s, ctx, s.deleteOutboundPeeksStmt, params) + + if err != nil { + return + } + + for _, item := range rows { + err = deleteOutboundPeek(s, ctx, item) + if err != nil { + return + } + } + return } diff --git a/federationsender/storage/cosmosdb/queue_edus_table.go b/federationsender/storage/cosmosdb/queue_edus_table.go index 530e0c088..79c99f897 100644 --- a/federationsender/storage/cosmosdb/queue_edus_table.go +++ b/federationsender/storage/cosmosdb/queue_edus_table.go @@ -18,82 +18,176 @@ import ( "context" "database/sql" "fmt" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) -const queueEDUsSchema = ` -CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( - -- The type of the event (informational). - edu_type TEXT NOT NULL, - -- The domain part of the user ID the EDU event is for. - server_name TEXT NOT NULL, - -- The JSON NID from the federationsender_queue_edus_json table. - json_nid BIGINT NOT NULL -); +// const queueEDUsSchema = ` +// CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( +// -- The type of the event (informational). +// edu_type TEXT NOT NULL, +// -- The domain part of the user ID the EDU event is for. +// server_name TEXT NOT NULL, +// -- The JSON NID from the federationsender_queue_edus_json table. +// json_nid BIGINT NOT NULL +// ); -CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx - ON federationsender_queue_edus (json_nid, server_name); -` +// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx +// ON federationsender_queue_edus (json_nid, server_name); +// ` -const insertQueueEDUSQL = "" + - "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + - " VALUES ($1, $2, $3)" - -const deleteQueueEDUsSQL = "" + - "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)" - -const selectQueueEDUSQL = "" + - "SELECT json_nid FROM federationsender_queue_edus" + - " WHERE server_name = $1" + - " LIMIT $2" - -const selectQueueEDUReferenceJSONCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE json_nid = $1" - -const selectQueueEDUCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE server_name = $1" - -const selectQueueServerNamesSQL = "" + - "SELECT DISTINCT server_name FROM federationsender_queue_edus" - -type queueEDUsStatements struct { - db *sql.DB - insertQueueEDUStmt *sql.Stmt - selectQueueEDUStmt *sql.Stmt - selectQueueEDUReferenceJSONCountStmt *sql.Stmt - selectQueueEDUCountStmt *sql.Stmt - selectQueueEDUServerNamesStmt *sql.Stmt +type QueueEDUCosmos struct { + EDUType string `json:"edu_type"` + ServerName string `json:"server_name"` + JSONNID int64 `json:"json_nid"` } -func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { +type QueueEDUCosmosNumber struct { + Number int64 `json:"number"` +} + +type QueueEDUCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + QueueEDU QueueEDUCosmos `json:"mx_federationsender_queue_edu"` +} + +// const insertQueueEDUSQL = "" + +// "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + +// " VALUES ($1, $2, $3)" + +// "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)" +const deleteQueueEDUsSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_queue_edu.server_name = @x2" + + "and ARRAY_CONTAINS(@x3, c.mx_federationsender_queue_edu.json_nid) " + +// "SELECT json_nid FROM federationsender_queue_edus" + +// " WHERE server_name = $1" + +// " LIMIT $2" +const selectQueueEDUSQL = "" + + "select top @x3 * from c where c._cn = @x1 " + + "and c.mx_federationsender_queue_edu.server_name = @x2" + +// "SELECT COUNT(*) FROM federationsender_queue_edus" + +// " WHERE json_nid = $1" +const selectQueueEDUReferenceJSONCountSQL = "" + + "select count(c._ts) as number from c where c._cn = @x1 " + + "and c.mx_federationsender_queue_edu.json_nid = @x2" + +// "SELECT COUNT(*) FROM federationsender_queue_edus" + +// " WHERE server_name = $1" +const selectQueueEDUCountSQL = "" + + "select count(c._ts) as number from c where c._cn = @x1 " + + "and c.mx_federationsender_queue_edu.server_name = @x2" + +// "SELECT DISTINCT server_name FROM federationsender_queue_edus" +const selectQueueServerNamesSQL = "" + + "select distinct c.mx_federationsender_queue_edu.server_name from c where c._cn = @x1 " + +type queueEDUsStatements struct { + db *Database + // insertQueueEDUStmt *sql.Stmt + selectQueueEDUStmt string + selectQueueEDUReferenceJSONCountStmt string + selectQueueEDUCountStmt string + selectQueueEDUServerNamesStmt string + tableName string +} + +func queryQueueEDUC(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []QueueEDUCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryQueueEDUCDistinct(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmos, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []QueueEDUCosmos + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryQueueEDUCNumber(s *queueEDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueEDUCosmosNumber, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []QueueEDUCosmosNumber + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func deleteQueueEDUC(s *queueEDUsStatements, ctx context.Context, dbData QueueEDUCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewCosmosDBQueueEDUsTable(db *Database) (s *queueEDUsStatements, err error) { s = &queueEDUsStatements{ db: db, } - _, err = db.Exec(queueEDUsSchema) - if err != nil { - return - } - if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil { - return - } - if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { - return - } + s.selectQueueEDUStmt = selectQueueEDUSQL + s.selectQueueEDUReferenceJSONCountStmt = selectQueueEDUReferenceJSONCountSQL + s.selectQueueEDUCountStmt = selectQueueEDUCountSQL + s.selectQueueEDUServerNamesStmt = selectQueueServerNamesSQL + s.tableName = "queue_edus" return } @@ -104,13 +198,47 @@ func (s *queueEDUsStatements) InsertQueueEDU( serverName gomatrixserverlib.ServerName, nid int64, ) error { - stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) - _, err := stmt.ExecContext( + + // "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + + + // stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx + // ON federationsender_queue_edus (json_nid, server_name); + docId := fmt.Sprintf("%d_%s", nid, eduType) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := QueueEDUCosmos{ + EDUType: eduType, + JSONNID: nid, + ServerName: string(serverName), + } + + dbData := &QueueEDUCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + QueueEDU: data, + } + + // _, err := stmt.ExecContext( + // ctx, + // eduType, // the EDU type + // serverName, // destination server name + // nid, // JSON blob NID + // ) + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - eduType, // the EDU type - serverName, // destination server name - nid, // JSON blob NID - ) + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } @@ -119,20 +247,33 @@ func (s *queueEDUsStatements) DeleteQueueEDUs( serverName gomatrixserverlib.ServerName, jsonNIDs []int64, ) error { - deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) - deleteStmt, err := txn.Prepare(deleteSQL) + + // "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)" + + // deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": serverName, + "@x3": jsonNIDs, + } + + // stmt := sqlutil.TxStmt(txn, deleteStmt) + // _, err = stmt.ExecContext(ctx, params...) + rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params) + if err != nil { - return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + return err } - params := make([]interface{}, len(jsonNIDs)+1) - params[0] = serverName - for k, v := range jsonNIDs { - params[k+1] = v + for _, item := range rows { + err = deleteQueueEDUC(s, ctx, item) + if err != nil { + return err + } } - stmt := sqlutil.TxStmt(txn, deleteStmt) - _, err = stmt.ExecContext(ctx, params...) return err } @@ -141,18 +282,28 @@ func (s *queueEDUsStatements) SelectQueueEDUs( serverName gomatrixserverlib.ServerName, limit int, ) ([]int64, error) { - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) - rows, err := stmt.QueryContext(ctx, serverName, limit) + + // "SELECT json_nid FROM federationsender_queue_edus" + + // " WHERE server_name = $1" + + // " LIMIT $2" + + // stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": serverName, + "@x3": limit, + } + + // rows, err := stmt.QueryContext(ctx, serverName, limit) + rows, err := queryQueueEDUC(s, ctx, deleteQueueEDUsSQL, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") var result []int64 - for rows.Next() { + for _, item := range rows { var nid int64 - if err = rows.Scan(&nid); err != nil { - return nil, err - } + nid = item.QueueEDU.JSONNID result = append(result, nid) } return result, nil @@ -162,11 +313,23 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( ctx context.Context, txn *sql.Tx, jsonNID int64, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt) - err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) - if err == sql.ErrNoRows { + + // "SELECT COUNT(*) FROM federationsender_queue_edus" + + // " WHERE json_nid = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": jsonNID, + } + + // stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt) + rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUReferenceJSONCountStmt, params) + if len(rows) == 0 { return -1, nil } + // err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + count = rows[0].Number return count, err } @@ -174,34 +337,52 @@ func (s *queueEDUsStatements) SelectQueueEDUCount( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { + + // "SELECT COUNT(*) FROM federationsender_queue_edus" + + // " WHERE server_name = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": serverName, + } + + // stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) + rows, err := queryQueueEDUCNumber(s, ctx, s.selectQueueEDUCountStmt, params) + if len(rows) == 0 { // It's acceptable for there to be no rows referencing a given // JSON NID but it's not an error condition. Just return as if // there's a zero count. return 0, nil } + // err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + count = rows[0].Number return count, err } func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) - rows, err := stmt.QueryContext(ctx) + + // "SELECT DISTINCT server_name FROM federationsender_queue_edus" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + // stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt) + // rows, err := stmt.QueryContext(ctx) + rows, err := queryQueueEDUCDistinct(s, ctx, s.selectQueueEDUServerNamesStmt, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") var result []gomatrixserverlib.ServerName - for rows.Next() { + for _, item := range rows { var serverName gomatrixserverlib.ServerName - if err = rows.Scan(&serverName); err != nil { - return nil, err - } + serverName = gomatrixserverlib.ServerName(item.ServerName) result = append(result, serverName) } - return result, rows.Err() + return result, nil } diff --git a/federationsender/storage/cosmosdb/queue_json_table.go b/federationsender/storage/cosmosdb/queue_json_table.go index 74cee2b17..8441e8e84 100644 --- a/federationsender/storage/cosmosdb/queue_json_table.go +++ b/federationsender/storage/cosmosdb/queue_json_table.go @@ -19,97 +19,205 @@ import ( "context" "database/sql" "fmt" - "strings" + "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" ) -const queueJSONSchema = ` --- The queue_retry_json table contains event contents that --- we failed to send. -CREATE TABLE IF NOT EXISTS federationsender_queue_json ( - -- The JSON NID. This allows the federationsender_queue_retry table to - -- cross-reference to find the JSON blob. - json_nid INTEGER PRIMARY KEY AUTOINCREMENT, - -- The JSON body. Text so that we preserve UTF-8. - json_body TEXT NOT NULL -); -` +// const queueJSONSchema = ` +// -- The queue_retry_json table contains event contents that +// -- we failed to send. +// CREATE TABLE IF NOT EXISTS federationsender_queue_json ( +// -- The JSON NID. This allows the federationsender_queue_retry table to +// -- cross-reference to find the JSON blob. +// json_nid INTEGER PRIMARY KEY AUTOINCREMENT, +// -- The JSON body. Text so that we preserve UTF-8. +// json_body TEXT NOT NULL +// ); +// ` -const insertJSONSQL = "" + - "INSERT INTO federationsender_queue_json (json_body)" + - " VALUES ($1)" - -const deleteJSONSQL = "" + - "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)" - -const selectJSONSQL = "" + - "SELECT json_nid, json_body FROM federationsender_queue_json" + - " WHERE json_nid IN ($1)" - -type queueJSONStatements struct { - db *sql.DB - insertJSONStmt *sql.Stmt - //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic - //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic +type QueueJSONCosmos struct { + JSONNID int64 `json:"json_nid"` + JSONBody []byte `json:"json_body"` } -func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { +type QueueJSONCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + QueueJSON QueueJSONCosmos `json:"mx_federationsender_queue_json"` +} + +// const insertJSONSQL = "" + +// "INSERT INTO federationsender_queue_json (json_body)" + +// " VALUES ($1)" + +// "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)" +const deleteJSONSQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_federationsender_queue_json.json_nid) " + +// "SELECT json_nid, json_body FROM federationsender_queue_json" + +// " WHERE json_nid IN ($1)" +const selectJSONSQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_federationsender_queue_json.json_nid) " + +type queueJSONStatements struct { + db *Database + // insertJSONStmt *sql.Stmt + //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic + //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic + tableName string +} + +func queryQueueJSON(s *queueJSONStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueueJSONCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []QueueJSONCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func deleteQueueJSON(s *queueJSONStatements, ctx context.Context, dbData QueueJSONCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewCosmosDBQueueJSONTable(db *Database) (s *queueJSONStatements, err error) { s = &queueJSONStatements{ db: db, } - _, err = db.Exec(queueJSONSchema) - if err != nil { - return - } - if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil { - return - } + s.tableName = "queue_jsons" return } func (s *queueJSONStatements) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, ) (lastid int64, err error) { - stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) - res, err := stmt.ExecContext(ctx, json) + + // "INSERT INTO federationsender_queue_json (json_body)" + + // " VALUES ($1)" + + // json_nid INTEGER PRIMARY KEY AUTOINCREMENT, + idSeq, err := GetNextQueueJSONNID(s, ctx) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // json_nid INTEGER PRIMARY KEY AUTOINCREMENT, + docId := fmt.Sprintf("%d", idSeq) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + //Convert to byte + jsonData := []byte(json) + + data := QueueJSONCosmos{ + JSONNID: idSeq, + JSONBody: jsonData, + } + + dbData := &QueueJSONCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + QueueJSON: data, + } + + // stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + // res, err := stmt.ExecContext(ctx, json) + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + if err != nil { return 0, fmt.Errorf("stmt.QueryContext: %w", err) } - lastid, err = res.LastInsertId() - if err != nil { - return 0, fmt.Errorf("res.LastInsertId: %w", err) - } + lastid = idSeq return } func (s *queueJSONStatements) DeleteQueueJSON( ctx context.Context, txn *sql.Tx, nids []int64, ) error { - deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) - deleteStmt, err := txn.Prepare(deleteSQL) + + // "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": nids, + } + + // deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + // deleteStmt, err := txn.Prepare(deleteSQL) + // stmt := sqlutil.TxStmt(txn, deleteStmt) + rows, err := queryQueueJSON(s, ctx, deleteJSONSQL, params) + if err != nil { - return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + return err } - iNIDs := make([]interface{}, len(nids)) - for k, v := range nids { - iNIDs[k] = v - } + // iNIDs := make([]interface{}, len(nids)) + // for k, v := range nids { + // iNIDs[k] = v + // } - stmt := sqlutil.TxStmt(txn, deleteStmt) - _, err = stmt.ExecContext(ctx, iNIDs...) + for _, item := range rows { + err = deleteQueueJSON(s, ctx, item) + } return err } func (s *queueJSONStatements) SelectQueueJSON( ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ) (map[int64][]byte, error) { - selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) - selectStmt, err := txn.Prepare(selectSQL) + + // "SELECT json_nid, json_body FROM federationsender_queue_json" + + // " WHERE json_nid IN ($1)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": jsonNIDs, + } + + // selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + // selectStmt, err := txn.Prepare(selectSQL) + rows, err := queryQueueJSON(s, ctx, selectJSONSQL, params) + if err != nil { - return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) + return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) } iNIDs := make([]interface{}, len(jsonNIDs)) @@ -118,18 +226,11 @@ func (s *queueJSONStatements) SelectQueueJSON( } blobs := map[int64][]byte{} - stmt := sqlutil.TxStmt(txn, selectStmt) - rows, err := stmt.QueryContext(ctx, iNIDs...) - if err != nil { - return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) - } - defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") - for rows.Next() { + for _, item := range rows { var nid int64 var blob []byte - if err = rows.Scan(&nid, &blob); err != nil { - return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) - } + nid = item.QueueJSON.JSONNID + blob = item.QueueJSON.JSONBody blobs[nid] = blob } return blobs, err diff --git a/federationsender/storage/cosmosdb/queue_json_table_json_nid_seq.go b/federationsender/storage/cosmosdb/queue_json_table_json_nid_seq.go new file mode 100644 index 000000000..3edef960f --- /dev/null +++ b/federationsender/storage/cosmosdb/queue_json_table_json_nid_seq.go @@ -0,0 +1,12 @@ +package cosmosdb + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" +) + +func GetNextQueueJSONNID(s *queueJSONStatements, ctx context.Context) (int64, error) { + const docId = "json_nid_seq" + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) +} diff --git a/federationsender/storage/cosmosdb/queue_pdus_table.go b/federationsender/storage/cosmosdb/queue_pdus_table.go index 8ca0a1fde..e7ddb6a32 100644 --- a/federationsender/storage/cosmosdb/queue_pdus_table.go +++ b/federationsender/storage/cosmosdb/queue_pdus_table.go @@ -19,96 +19,188 @@ import ( "context" "database/sql" "fmt" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) -const queuePDUsSchema = ` -CREATE TABLE IF NOT EXISTS federationsender_queue_pdus ( - -- The transaction ID that was generated before persisting the event. - transaction_id TEXT NOT NULL, - -- The domain part of the user ID the m.room.member event is for. - server_name TEXT NOT NULL, - -- The JSON NID from the federationsender_queue_pdus_json table. - json_nid BIGINT NOT NULL -); +// const queuePDUsSchema = ` +// CREATE TABLE IF NOT EXISTS federationsender_queue_pdus ( +// -- The transaction ID that was generated before persisting the event. +// transaction_id TEXT NOT NULL, +// -- The domain part of the user ID the m.room.member event is for. +// server_name TEXT NOT NULL, +// -- The JSON NID from the federationsender_queue_pdus_json table. +// json_nid BIGINT NOT NULL +// ); -CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx - ON federationsender_queue_pdus (json_nid, server_name); -` +// CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx +// ON federationsender_queue_pdus (json_nid, server_name); +// ` -const insertQueuePDUSQL = "" + - "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" + - " VALUES ($1, $2, $3)" - -const deleteQueuePDUsSQL = "" + - "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)" - -const selectQueueNextTransactionIDSQL = "" + - "SELECT transaction_id FROM federationsender_queue_pdus" + - " WHERE server_name = $1" + - " ORDER BY transaction_id ASC" + - " LIMIT 1" - -const selectQueuePDUsSQL = "" + - "SELECT json_nid FROM federationsender_queue_pdus" + - " WHERE server_name = $1" + - " LIMIT $2" - -const selectQueuePDUsReferenceJSONCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE json_nid = $1" - -const selectQueuePDUsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE server_name = $1" - -const selectQueuePDUsServerNamesSQL = "" + - "SELECT DISTINCT server_name FROM federationsender_queue_pdus" - -type queuePDUsStatements struct { - db *sql.DB - insertQueuePDUStmt *sql.Stmt - selectQueueNextTransactionIDStmt *sql.Stmt - selectQueuePDUsStmt *sql.Stmt - selectQueueReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt - selectQueueServerNamesStmt *sql.Stmt - // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic +type QueuePDUCosmos struct { + TransactionID string `json:"transaction_id"` + ServerName string `json:"server_name"` + JSONNID int64 `json:"json_nid"` } -func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { +type QueuePDUCosmosNumber struct { + Number int64 `json:"number"` +} + +type QueuePDUCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + QueuePDU QueuePDUCosmos `json:"mx_federationsender_queue_pdu"` +} + +// const insertQueuePDUSQL = "" + +// "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" + +// " VALUES ($1, $2, $3)" + +// "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)" +const deleteQueuePDUsSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_federationsender_queue_pdu.server_name = @x2 " + + "and ARRAY_CONTAINS(@x3, c.mx_federationsender_queue_pdu.json_nid) " + +// "SELECT transaction_id FROM federationsender_queue_pdus" + +// " WHERE server_name = $1" + +// " ORDER BY transaction_id ASC" + +// " LIMIT 1" +const selectQueueNextTransactionIDSQL = "" + + "select top 1 * from c where c._cn = @x1 " + + "and c.mx_federationsender_queue_pdu.server_name = @x2 " + + "order by c.mx_federationsender_queue_pdu.transaction_id asc " + +// "SELECT json_nid FROM federationsender_queue_pdus" + +// " WHERE server_name = $1" + +// " LIMIT $2" +const selectQueuePDUsSQL = "" + + "select top @x3 * from c where c._cn = @x1 " + + "and c.mx_federationsender_queue_pdu.server_name = @x2 " + +// "SELECT COUNT(*) FROM federationsender_queue_pdus" + +// " WHERE json_nid = $1" +const selectQueuePDUsReferenceJSONCountSQL = "" + + "select count(c._ts) as number from c where c._cn = @x1 " + + "and c.mx_federationsender_queue_pdu.json_nid = @x2 " + +// "SELECT COUNT(*) FROM federationsender_queue_pdus" + +// " WHERE server_name = $1" +const selectQueuePDUsCountSQL = "" + + "select count(c._ts) as number from c where c._cn = @x1 " + + "and c.mx_federationsender_queue_pdu.server_name = @x2 " + +// "SELECT DISTINCT server_name FROM federationsender_queue_pdus" +const selectQueuePDUsServerNamesSQL = "" + + "select distinct c.mx_federationsender_queue_pdu.server_name from c where c._cn = @x1 " + +type queuePDUsStatements struct { + db *Database + // insertQueuePDUStmt *sql.Stmt + selectQueueNextTransactionIDStmt string + selectQueuePDUsStmt string + selectQueueReferenceJSONCountStmt string + selectQueuePDUsCountStmt string + selectQueueServerNamesStmt string + // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic + tableName string +} + +func queryQueuePDU(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []QueuePDUCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryQueuePDUDistinct(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmos, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []QueuePDUCosmos + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryQueuePDUNumber(s *queuePDUsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]QueuePDUCosmosNumber, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []QueuePDUCosmosNumber + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func deleteQueuePDU(s *queuePDUsStatements, ctx context.Context, dbData QueuePDUCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewCosmosDBQueuePDUsTable(db *Database) (s *queuePDUsStatements, err error) { s = &queuePDUsStatements{ db: db, } - _, err = db.Exec(queuePDUsSchema) - if err != nil { - return - } - if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil { - return - } - //if s.deleteQueuePDUsStmt, err = db.Prepare(deleteQueuePDUsSQL); err != nil { - // return - //} - if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { - return - } - if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil { - return - } - if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { - return - } - if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { - return - } + s.selectQueueNextTransactionIDStmt = selectQueueNextTransactionIDSQL + s.selectQueuePDUsStmt = selectQueuePDUsSQL + s.selectQueueReferenceJSONCountStmt = selectQueuePDUsReferenceJSONCountSQL + s.selectQueuePDUsCountStmt = selectQueuePDUsCountSQL + s.selectQueueServerNamesStmt = selectQueuePDUsServerNamesSQL + s.tableName = "queue_pdus" return } @@ -119,13 +211,47 @@ func (s *queuePDUsStatements) InsertQueuePDU( serverName gomatrixserverlib.ServerName, nid int64, ) error { - stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) - _, err := stmt.ExecContext( + + // "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" + + // " VALUES ($1, $2, $3)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx + // ON federationsender_queue_pdus (json_nid, server_name); + docId := fmt.Sprintf("%d_%s", nid, serverName) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := QueuePDUCosmos{ + JSONNID: nid, + ServerName: string(serverName), + TransactionID: string(transactionID), + } + + dbData := &QueuePDUCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + QueuePDU: data, + } + + // stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) + // _, err := stmt.ExecContext( + // ctx, + // transactionID, // the transaction ID that we initially attempted + // serverName, // destination server name + // nid, // JSON blob NID + // ) + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - transactionID, // the transaction ID that we initially attempted - serverName, // destination server name - nid, // JSON blob NID - ) + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } @@ -134,20 +260,31 @@ func (s *queuePDUsStatements) DeleteQueuePDUs( serverName gomatrixserverlib.ServerName, jsonNIDs []int64, ) error { - deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) - deleteStmt, err := txn.Prepare(deleteSQL) + + // "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": serverName, + "@x3": jsonNIDs, + } + + // deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + // deleteStmt, err := txn.Prepare(deleteSQL) + rows, err := queryQueuePDU(s, ctx, deleteQueuePDUsSQL, params) + if err != nil { - return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + return err } - params := make([]interface{}, len(jsonNIDs)+1) - params[0] = serverName - for k, v := range jsonNIDs { - params[k+1] = v + for _, item := range rows { + // stmt := sqlutil.TxStmt(txn, deleteStmt) + err = deleteQueuePDU(s, ctx, item) + if err != nil { + return err + } } - - stmt := sqlutil.TxStmt(txn, deleteStmt) - _, err = stmt.ExecContext(ctx, params...) return err } @@ -155,11 +292,30 @@ func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (gomatrixserverlib.TransactionID, error) { var transactionID gomatrixserverlib.TransactionID - stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) - if err == sql.ErrNoRows { + + // "SELECT transaction_id FROM federationsender_queue_pdus" + + // " WHERE server_name = $1" + + // " ORDER BY transaction_id ASC" + + // " LIMIT 1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": serverName, + } + + // stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) + rows, err := queryQueuePDU(s, ctx, s.selectQueueNextTransactionIDStmt, params) + + if err != nil { + return "", err + } + + if len(rows) == 0 { return "", nil } + // err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) + transactionID = gomatrixserverlib.TransactionID(rows[0].QueuePDU.TransactionID) return transactionID, err } @@ -167,11 +323,28 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( ctx context.Context, txn *sql.Tx, jsonNID int64, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt) - err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) - if err == sql.ErrNoRows { + + // "SELECT COUNT(*) FROM federationsender_queue_pdus" + + // " WHERE json_nid = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": jsonNID, + } + + // stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt) + rows, err := queryQueuePDUNumber(s, ctx, s.selectQueueReferenceJSONCountStmt, params) + + if err != nil { + return -1, err + } + + if len(rows) == 0 { return -1, nil } + // err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + count = rows[0].Number return count, err } @@ -179,14 +352,31 @@ func (s *queuePDUsStatements) SelectQueuePDUCount( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { + + // "SELECT COUNT(*) FROM federationsender_queue_pdus" + + // " WHERE server_name = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": serverName, + } + + // stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) + rows, err := queryQueuePDUNumber(s, ctx, s.selectQueuePDUsCountStmt, params) + + if err != nil { + return 0, err + } + + if len(rows) == 0 { // It's acceptable for there to be no rows referencing a given // JSON NID but it's not an error condition. Just return as if // there's a zero count. return 0, nil } + // err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + count = rows[0].Number return count, err } @@ -195,41 +385,58 @@ func (s *queuePDUsStatements) SelectQueuePDUs( serverName gomatrixserverlib.ServerName, limit int, ) ([]int64, error) { - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt) - rows, err := stmt.QueryContext(ctx, serverName, limit) + + // "SELECT json_nid FROM federationsender_queue_pdus" + + // " WHERE server_name = $1" + + // " LIMIT $2" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": serverName, + "@x3": limit, + } + + // stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt) + // rows, err := stmt.QueryContext(ctx, serverName, limit) + rows, err := queryQueuePDU(s, ctx, s.selectQueuePDUsStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") var result []int64 - for rows.Next() { + for _, item := range rows { var nid int64 - if err = rows.Scan(&nid); err != nil { - return nil, err - } + nid = item.QueuePDU.JSONNID result = append(result, nid) } - return result, rows.Err() + return result, nil } func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { - stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) - rows, err := stmt.QueryContext(ctx) + + // "SELECT DISTINCT server_name FROM federationsender_queue_pdus" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + // stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) + // rows, err := stmt.QueryContext(ctx) + rows, err := queryQueuePDUDistinct(s, ctx, s.selectQueueServerNamesStmt, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") var result []gomatrixserverlib.ServerName - for rows.Next() { + for _, item := range rows { var serverName gomatrixserverlib.ServerName - if err = rows.Scan(&serverName); err != nil { - return nil, err - } + serverName = gomatrixserverlib.ServerName(item.ServerName) result = append(result, serverName) } - return result, rows.Err() + return result, nil } diff --git a/federationsender/storage/cosmosdb/storage.go b/federationsender/storage/cosmosdb/storage.go index da429046b..6e9192d99 100644 --- a/federationsender/storage/cosmosdb/storage.go +++ b/federationsender/storage/cosmosdb/storage.go @@ -16,68 +16,74 @@ package cosmosdb import ( - "database/sql" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" _ "github.com/mattn/go-sqlite3" "github.com/matrix-org/dendrite/federationsender/storage/shared" - "github.com/matrix-org/dendrite/federationsender/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/internal/caching" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" ) // Database stores information needed by the federation sender type Database struct { shared.Database - sqlutil.PartitionOffsetStatements - db *sql.DB - writer sqlutil.Writer + cosmosdbutil.PartitionOffsetStatements + database cosmosdbutil.Database + writer cosmosdbutil.Writer + connection cosmosdbapi.CosmosConnection + databaseName string + cosmosConfig cosmosdbapi.CosmosConfig } // NewDatabase opens a new database func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationSenderCache) (*Database, error) { + conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) + configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) var d Database + d.connection = conn + d.cosmosConfig = configCosmos + d.databaseName = "federationsender" + d.database = cosmosdbutil.Database{ + Connection: conn, + CosmosConfig: configCosmos, + DatabaseName: d.databaseName, + } + var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err - } - d.writer = sqlutil.NewExclusiveWriter() - joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) + + d.writer = cosmosdbutil.NewExclusiveWriterFake() + joinedHosts, err := NewCosmosDBJoinedHostsTable(&d) if err != nil { return nil, err } - queuePDUs, err := NewSQLiteQueuePDUsTable(d.db) + queuePDUs, err := NewCosmosDBQueuePDUsTable(&d) if err != nil { return nil, err } - queueEDUs, err := NewSQLiteQueueEDUsTable(d.db) + queueEDUs, err := NewCosmosDBQueueEDUsTable(&d) if err != nil { return nil, err } - queueJSON, err := NewSQLiteQueueJSONTable(d.db) + queueJSON, err := NewCosmosDBQueueJSONTable(&d) if err != nil { return nil, err } - blacklist, err := NewSQLiteBlacklistTable(d.db) + blacklist, err := NewCosmosDBBlacklistTable(&d) if err != nil { return nil, err } - outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) + outboundPeeks, err := NewCosmosDBOutboundPeeksTable(&d) if err != nil { return nil, err } - inboundPeeks, err := NewSQLiteInboundPeeksTable(d.db) + inboundPeeks, err := NewCosmosDBInboundPeeksTable(&d) if err != nil { return nil, err } - m := sqlutil.NewMigrations() - deltas.LoadRemoveRoomsTable(m) - if err = m.RunDeltas(d.db, dbProperties); err != nil { - return nil, err - } d.Database = shared.Database{ - DB: d.db, + DB: nil, Cache: cache, Writer: d.writer, FederationSenderJoinedHosts: joinedHosts, @@ -88,7 +94,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationS FederationSenderOutboundPeeks: outboundPeeks, FederationSenderInboundPeeks: inboundPeeks, } - if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil { + if err = d.PartitionOffsetStatements.Prepare(&d.database, d.writer, "federationsender"); err != nil { return nil, err } return &d, nil