diff --git a/appservice/storage/cosmosdb/storage.go b/appservice/storage/cosmosdb/storage.go index bd5cdae93..0fb421a31 100644 --- a/appservice/storage/cosmosdb/storage.go +++ b/appservice/storage/cosmosdb/storage.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/cosmosdbutil" // Import SQLite database driver - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" _ "github.com/mattn/go-sqlite3" @@ -31,7 +31,8 @@ import ( // Database stores events intended to be later sent to application services type Database struct { - sqlutil.PartitionOffsetStatements + database cosmosdbutil.Database + cosmosdbutil.PartitionOffsetStatements events eventsStatements txnID txnStatements writer cosmosdbutil.Writer @@ -44,14 +45,23 @@ type Database struct { // NewDatabase opens a new database func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) - config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) + configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) result := &Database{ databaseName: "appservice", connection: conn, - cosmosConfig: config, + cosmosConfig: configCosmos, } + result.database = cosmosdbutil.Database{ + Connection: conn, + CosmosConfig: configCosmos, + DatabaseName: result.databaseName, + } + var err error result.writer = cosmosdbutil.NewExclusiveWriterFake() + if err = result.PartitionOffsetStatements.Prepare(&result.database, result.writer, "appservice"); err != nil { + return nil, err + } if err = result.prepare(); err != nil { return nil, err } diff --git a/dendrite-config-cosmosdb.yaml b/dendrite-config-cosmosdb.yaml index 03d4983f7..c25f2ee9d 100644 --- a/dendrite-config-cosmosdb.yaml +++ b/dendrite-config-cosmosdb.yaml @@ -6,7 +6,7 @@ # # At a minimum, to get started, you will need to update the settings in the # "global" section for your deployment, and you will need to check that the -# database "connection_string" line in each component section is correct. +# database "connection_string" line in each component section is correct. # # Each component with a "database" section can accept the following formats # for "connection_string": @@ -23,13 +23,13 @@ # small number of users and likely will perform worse still with a higher volume # of users. # -# The "max_open_conns" and "max_idle_conns" settings configure the maximum +# The "max_open_conns" and "max_idle_conns" settings configure the maximum # number of open/idle database connections. The value 0 will use the database # engine default, and a negative value will use unlimited connections. The # "conn_max_lifetime" option controls the maximum length of time a database # connection can be idle in seconds - a negative value is unlimited. -# The version of the configuration file. +# The version of the configuration file. version: 1 # Global Matrix configuration. This configuration applies to all components. @@ -154,13 +154,13 @@ client_api: # Whether to require reCAPTCHA for registration. enable_registration_captcha: false - # Settings for ReCAPTCHA. + # Settings for ReCAPTCHA. recaptcha_public_key: "" recaptcha_private_key: "" recaptcha_bypass_secret: "" recaptcha_siteverify_api: "" - # TURN server information that this homeserver should send to clients. + # TURN server information that this homeserver should send to clients. turn: turn_user_lifetime: "" turn_uris: [] @@ -169,7 +169,7 @@ client_api: turn_password: "" # Settings for rate-limited endpoints. Rate limiting will kick in after the - # threshold number of "slots" have been taken by requests from a specific + # threshold number of "slots" have been taken by requests from a specific # host. Each "slot" will be released after the cooloff time in milliseconds. rate_limiting: enabled: true @@ -331,7 +331,7 @@ sync_api: external_api: listen: http://[::]:8073 database: - connection_string: file:syncapi.db + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 @@ -363,9 +363,9 @@ user_api: max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 - # The length of time that a token issued for a relying party from + # The length of time that a token issued for a relying party from # /_matrix/client/r0/user/{userId}/openid/request_token endpoint - # is considered to be valid in milliseconds. + # is considered to be valid in milliseconds. # The default lifetime is 3600000ms (60 minutes). # openid_token_lifetime_ms: 3600000 diff --git a/internal/cosmosdbapi/document.go b/internal/cosmosdbapi/document.go index 54f6499e2..6c6edafed 100644 --- a/internal/cosmosdbapi/document.go +++ b/internal/cosmosdbapi/document.go @@ -3,10 +3,23 @@ package cosmosdbapi import ( "context" "fmt" + "strings" ) +func removeSpecialChars(docId string) string { + // The following characters are restricted and cannot be used in the Id property: '/', '\', '?', '#' + invalidChars := [4]string{"/", "\\", "?", "#"} + replaceChar := "," + result := docId + for _, invalidChar := range invalidChars { + result = strings.ReplaceAll(result, invalidChar, replaceChar) + } + return result +} + func GetDocumentId(tenantName string, collectionName string, id string) string { - return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, id) + safeId := removeSpecialChars(id) + return fmt.Sprintf("%s,%s,%s", collectionName, tenantName, safeId) } func GetPartitionKey(tenantName string, collectionName string) string { diff --git a/internal/cosmosdbutil/partition_offset_table.go b/internal/cosmosdbutil/partition_offset_table.go new file mode 100644 index 000000000..32ec22b32 --- /dev/null +++ b/internal/cosmosdbutil/partition_offset_table.go @@ -0,0 +1,227 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdbutil + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/gomatrixserverlib" +) + +// // A PartitionOffset is the offset into a partition of the input log. +// type PartitionOffset struct { +// // The ID of the partition. +// Partition int32 +// // The offset into the partition. +// Offset int64 +// } + +// const partitionOffsetsSchema = ` +// -- The offsets that the server has processed up to. +// CREATE TABLE IF NOT EXISTS ${prefix}_partition_offsets ( +// -- The name of the topic. +// topic TEXT NOT NULL, +// -- The 32-bit partition ID +// partition INTEGER NOT NULL, +// -- The 64-bit offset. +// partition_offset BIGINT NOT NULL, +// UNIQUE (topic, partition) +// ); +// ` + +type PartitionOffsetCosmos struct { + Topic string `json:"topic"` + Partition int32 `json:"partition"` + PartitionOffset int64 `json:"partition_offset"` +} + +type PartitionOffsetCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + PartitionOffset PartitionOffsetCosmos `json:"mx_partition_offset"` +} + +// "SELECT partition, partition_offset FROM ${prefix}_partition_offsets WHERE topic = $1" +const selectPartitionOffsetsSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_partition_offset.topic = @x2 " + +// const upsertPartitionOffsetsSQL = "" + +// "INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset) VALUES ($1, $2, $3)" + +// " ON CONFLICT (topic, partition)" + +// " DO UPDATE SET partition_offset = $3" + +type Database struct { + Connection cosmosdbapi.CosmosConnection + DatabaseName string + CosmosConfig cosmosdbapi.CosmosConfig + ServerName gomatrixserverlib.ServerName +} + +// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table. +type PartitionOffsetStatements struct { + db *Database + writer Writer + selectPartitionOffsetsStmt string + // upsertPartitionOffsetStmt *sql.Stmt + prefix string + tableName string +} + +func queryPartitionOffset(s *PartitionOffsetStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PartitionOffsetCosmosData, error) { + var dbCollectionName = getCollectionName(*s) + var pk = cosmosdbapi.GetPartitionKey(s.db.CosmosConfig.ContainerName, dbCollectionName) + var response []PartitionOffsetCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.Connection).QueryDocuments( + ctx, + s.db.CosmosConfig.DatabaseName, + s.db.CosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +// Prepare converts the raw SQL statements into prepared statements. +// Takes a prefix to prepend to the table name used to store the partition offsets. +// This allows multiple components to share the same database schema. +func (s *PartitionOffsetStatements) Prepare(db *Database, writer Writer, prefix string) (err error) { + s.db = db + s.writer = writer + s.selectPartitionOffsetsStmt = selectPartitionOffsetsSQL + s.prefix = prefix + s.tableName = "partition_offsets" + return +} + +// PartitionOffsets implements PartitionStorer +func (s *PartitionOffsetStatements) PartitionOffsets( + ctx context.Context, topic string, +) ([]sqlutil.PartitionOffset, error) { + return s.selectPartitionOffsets(ctx, topic) +} + +// SetPartitionOffset implements PartitionStorer +func (s *PartitionOffsetStatements) SetPartitionOffset( + ctx context.Context, topic string, partition int32, offset int64, +) error { + return s.upsertPartitionOffset(ctx, topic, partition, offset) +} + +// selectPartitionOffsets returns all the partition offsets for the given topic. +func (s *PartitionOffsetStatements) selectPartitionOffsets( + ctx context.Context, topic string, +) (results []sqlutil.PartitionOffset, err error) { + + // "SELECT partition, partition_offset FROM ${prefix}_partition_offsets WHERE topic = $1" + + var dbCollectionName = getCollectionName(*s) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": topic, + } + + rows, err := queryPartitionOffset(s, ctx, s.selectPartitionOffsetsStmt, params) + // rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic) + if err != nil { + return nil, err + } + for _, item := range rows { + var offset sqlutil.PartitionOffset + // if err = rows.Scan(&offset.Partition, &offset.Offset); err != nil { + // return nil, err + // } + offset.Partition = item.PartitionOffset.Partition + offset.Offset = item.PartitionOffset.PartitionOffset + results = append(results, offset) + } + return results, nil +} + +// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil +func checkNamedErr(fn func() error, err *error) { + if e := fn(); e != nil && *err == nil { + *err = e + } +} + +// UpsertPartitionOffset updates or inserts the partition offset for the given topic. +func (s *PartitionOffsetStatements) upsertPartitionOffset( + ctx context.Context, topic string, partition int32, offset int64, +) error { + return s.writer.Do(nil, nil, func(txn *sql.Tx) error { + + // "INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset) VALUES ($1, $2, $3)" + + // " ON CONFLICT (topic, partition)" + + // " DO UPDATE SET partition_offset = $3" + + // stmt := TxStmt(txn, s.upsertPartitionOffsetStmt) + + dbCollectionName := getCollectionName(*s) + // UNIQUE (topic, partition) + docId := fmt.Sprintf("%s_%d", topic, partition) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.CosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.CosmosConfig.ContainerName, dbCollectionName) + + data := PartitionOffsetCosmos{ + Partition: partition, + PartitionOffset: offset, + Topic: topic, + } + + dbData := &PartitionOffsetCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + // nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + Timestamp: time.Now().Unix(), + PartitionOffset: data, + } + + // _, err := stmt.ExecContext(ctx, topic, partition, offset) + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.Connection).CreateDocument( + ctx, + s.db.CosmosConfig.DatabaseName, + s.db.CosmosConfig.ContainerName, + &dbData, + options) + + return err + }) +} + +func getCollectionName(s PartitionOffsetStatements) string { + // Include the Prefix + tableName := fmt.Sprintf("%s_%s", s.prefix, s.tableName) + return cosmosdbapi.GetCollectionName(s.db.DatabaseName, tableName) +} diff --git a/keyserver/storage/cosmosdb/one_time_keys_table.go b/keyserver/storage/cosmosdb/one_time_keys_table.go index 6eb33d470..b6a89d194 100644 --- a/keyserver/storage/cosmosdb/one_time_keys_table.go +++ b/keyserver/storage/cosmosdb/one_time_keys_table.go @@ -53,9 +53,9 @@ type OneTimeKeyCosmos struct { KeyJSON []byte `json:"key_json"` } -type OneTimeKeyAlgoCountCosmosData struct { +type OneTimeKeyAlgoNumberCosmosData struct { Algorithm string `json:"algorithm"` - Count int `json:"count"` + Number int `json:"number"` } type OneTimeKeyCosmosData struct { @@ -81,7 +81,7 @@ const selectKeysSQL = "" + // "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" const selectKeysCountSQL = "" + - "select c.mx_keyserver_one_time_key.algorithm as algorithm, count(c.mx_keyserver_one_time_key.key_id) as count " + + "select c.mx_keyserver_one_time_key.algorithm, count(c.mx_keyserver_one_time_key.key_id) as number " + "from c where c._cn = @x1 " + "and c.mx_keyserver_one_time_key.user_id = @x2 " + "and c.mx_keyserver_one_time_key.device_id = @x3 " + @@ -110,7 +110,9 @@ type oneTimeKeysStatements struct { func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyCosmosData, error) { var response []OneTimeKeyCosmosData - var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) var query = cosmosdbapi.GetQuery(qry, params) var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, @@ -127,18 +129,20 @@ func queryOneTimeKey(s *oneTimeKeysStatements, ctx context.Context, qry string, return response, nil } -func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoCountCosmosData, error) { - var response []OneTimeKeyAlgoCountCosmosData - var test interface{} +func queryOneTimeKeyAlgoCount(s *oneTimeKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OneTimeKeyAlgoNumberCosmosData, error) { + var response []OneTimeKeyAlgoNumberCosmosData - var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + // var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() var query = cosmosdbapi.GetQuery(qry, params) var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( ctx, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, query, - &test, + &response, optionsQry) // When there are no Rows we seem to get the generic Bad Req JSON error @@ -252,7 +256,7 @@ func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, de var algorithm string var count int algorithm = item.Algorithm - count = item.Count + count = item.Number counts.KeyCount[algorithm] = count } return counts, nil @@ -324,7 +328,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( var algorithm string var count int algorithm = item.Algorithm - count = item.Count + count = item.Number counts.KeyCount[algorithm] = count } diff --git a/setup/kafka/kafka.go b/setup/kafka/kafka.go index 936115a37..4e9d3ce62 100644 --- a/setup/kafka/kafka.go +++ b/setup/kafka/kafka.go @@ -49,6 +49,7 @@ func setupNaffka(cfg *config.Kafka) (sarama.Consumer, sarama.SyncProducer) { if cfg.Database.ConnectionString.IsCosmosDB() { //TODO: What do we do for Nafka // cfg.Database.ConnectionString = cosmosdbutil.GetConnectionString(&cfg.Database.ConnectionString) + cfg.Database.ConnectionString = "file:naffka.db" } naffkaDB, err := naffkaStorage.NewDatabase(string(cfg.Database.ConnectionString)) diff --git a/syncapi/storage/cosmosdb/account_data_table.go b/syncapi/storage/cosmosdb/account_data_table.go index 308d3bd1f..1dc10dcfc 100644 --- a/syncapi/storage/cosmosdb/account_data_table.go +++ b/syncapi/storage/cosmosdb/account_data_table.go @@ -18,63 +18,127 @@ package cosmosdb import ( "context" "database/sql" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) -const accountDataSchema = ` -CREATE TABLE IF NOT EXISTS syncapi_account_data_type ( - id INTEGER PRIMARY KEY, - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - UNIQUE (user_id, room_id, type) -); -` +// const accountDataSchema = ` +// CREATE TABLE IF NOT EXISTS syncapi_account_data_type ( +// id INTEGER PRIMARY KEY, +// user_id TEXT NOT NULL, +// room_id TEXT NOT NULL, +// type TEXT NOT NULL, +// UNIQUE (user_id, room_id, type) +// ); +// ` -const insertAccountDataSQL = "" + - "INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" + - " ON CONFLICT (user_id, room_id, type) DO UPDATE" + - " SET id = $5" - -const selectAccountDataInRangeSQL = "" + - "SELECT room_id, type FROM syncapi_account_data_type" + - " WHERE user_id = $1 AND id > $2 AND id <= $3" + - " ORDER BY id ASC" - -const selectMaxAccountDataIDSQL = "" + - "SELECT MAX(id) FROM syncapi_account_data_type" - -type accountDataStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - insertAccountDataStmt *sql.Stmt - selectMaxAccountDataIDStmt *sql.Stmt - selectAccountDataInRangeStmt *sql.Stmt +type AccountDataTypeCosmos struct { + ID int64 `json:"id"` + UserID string `json:"user_id"` + RoomID string `json:"room_id"` + DataType string `json:"type"` } -func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { +type AccountDataTypeNumberCosmosData struct { + Number int64 `json:"number"` +} + +type AccountDataTypeCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + AccountDataType AccountDataTypeCosmos `json:"mx_syncapi_account_data_type"` +} + +// const insertAccountDataSQL = "" + +// "INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" + +// " ON CONFLICT (user_id, room_id, type) DO UPDATE" + +// " SET id = $5" + +// "SELECT room_id, type FROM syncapi_account_data_type" + +// " WHERE user_id = $1 AND id > $2 AND id <= $3" + +// " ORDER BY id ASC" +const selectAccountDataInRangeSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_account_data_type.user_id = @x2 " + + "and c.mx_syncapi_account_data_type.id > @x3 " + + "and c.mx_syncapi_account_data_type.id < @x4 " + + "order by c.mx_syncapi_account_data_type.id " + +// "SELECT MAX(id) FROM syncapi_account_data_type" +const selectMaxAccountDataIDSQL = "" + + "select max(c.mx_syncapi_account_data_type.id) as number from c where c._cn = @x1 " + +type accountDataStatements struct { + db *SyncServerDatasource + streamIDStatements *streamIDStatements + insertAccountDataStmt *sql.Stmt + selectMaxAccountDataIDStmt string + selectAccountDataInRangeStmt string + tableName string +} + +func queryAccountDataType(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataTypeCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []AccountDataTypeCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryAccountDataTypeNumber(s *accountDataStatements, ctx context.Context, qry string, params map[string]interface{}) ([]AccountDataTypeNumberCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []AccountDataTypeNumberCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, cosmosdbutil.ErrNoRows + } + return response, nil +} + +func NewCosmosDBAccountDataTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ db: db, streamIDStatements: streamID, } - _, err := db.Exec(accountDataSchema) - if err != nil { - return nil, err - } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return nil, err - } - if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { - return nil, err - } - if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { - return nil, err - } + + s.selectMaxAccountDataIDStmt = selectMaxAccountDataIDSQL + s.selectAccountDataInRangeStmt = selectAccountDataInRangeSQL + s.tableName = "account_data_types" return s, nil } @@ -82,11 +146,46 @@ func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { + + // "INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" + + // " ON CONFLICT (user_id, room_id, type) DO UPDATE" + + // " SET id = $5" + pos, err = s.streamIDStatements.nextAccountDataID(ctx, txn) if err != nil { return } - _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType, pos) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (user_id, room_id, type) + docId := fmt.Sprintf("%s_%s_%s", userID, roomID, dataType) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := AccountDataTypeCosmos{ + ID: int64(pos), + UserID: userID, + RoomID: roomID, + DataType: dataType, + } + + dbData := &AccountDataTypeCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + AccountDataType: data, + } + + // _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType, pos) + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return } @@ -98,21 +197,32 @@ func (s *accountDataStatements) SelectAccountDataInRange( ) (data map[string][]string, err error) { data = make(map[string][]string) - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High()) + // "SELECT room_id, type FROM syncapi_account_data_type" + + // " WHERE user_id = $1 AND id > $2 AND id <= $3" + + // " ORDER BY id ASC" + + // rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High()) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": r.Low(), + "@x4": r.High(), + } + + rows, err := queryAccountDataType(s, ctx, s.selectAccountDataInRangeStmt, params) + if err != nil { return } - defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") var entries int - for rows.Next() { + for _, item := range rows { var dataType string var roomID string - - if err = rows.Scan(&roomID, &dataType); err != nil { - return - } + roomID = item.AccountDataType.RoomID + dataType = item.AccountDataType.DataType // check if we should add this by looking at the filter. // It would be nice if we could do this in SQL-land, but the mix of variadic @@ -147,8 +257,22 @@ func (s *accountDataStatements) SelectAccountDataInRange( func (s *accountDataStatements) SelectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { + + // "SELECT MAX(id) FROM syncapi_account_data_type" + var nullableID sql.NullInt64 - err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) + // err = sqlutil.TxStmt(txn, s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + rows, err := queryAccountDataTypeNumber(s, ctx, s.selectMaxAccountDataIDStmt, params) + + if err != cosmosdbutil.ErrNoRows && len(rows) == 1 { + nullableID.Int64 = rows[0].Number + } + if nullableID.Valid { id = nullableID.Int64 } diff --git a/syncapi/storage/cosmosdb/backwards_extremities_table.go b/syncapi/storage/cosmosdb/backwards_extremities_table.go index 5bc7d723b..df5ad830d 100644 --- a/syncapi/storage/cosmosdb/backwards_extremities_table.go +++ b/syncapi/storage/cosmosdb/backwards_extremities_table.go @@ -17,109 +17,238 @@ package cosmosdb import ( "context" "database/sql" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) -const backwardExtremitiesSchema = ` --- Stores output room events received from the roomserver. -CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( - -- The 'room_id' key for the event. - room_id TEXT NOT NULL, - -- The event ID for the last known event. This is the backwards extremity. - event_id TEXT NOT NULL, - -- The prev_events for the last known event. This is used to update extremities. - prev_event_id TEXT NOT NULL, - PRIMARY KEY(room_id, event_id, prev_event_id) -); -` +// const backwardExtremitiesSchema = ` +// -- Stores output room events received from the roomserver. +// CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( +// -- The 'room_id' key for the event. +// room_id TEXT NOT NULL, +// -- The event ID for the last known event. This is the backwards extremity. +// event_id TEXT NOT NULL, +// -- The prev_events for the last known event. This is used to update extremities. +// prev_event_id TEXT NOT NULL, +// PRIMARY KEY(room_id, event_id, prev_event_id) +// ); +// ` -const insertBackwardExtremitySQL = "" + - "INSERT INTO syncapi_backward_extremities (room_id, event_id, prev_event_id)" + - " VALUES ($1, $2, $3)" + - " ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING" - -const selectBackwardExtremitiesForRoomSQL = "" + - "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1" - -const deleteBackwardExtremitySQL = "" + - "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" - -const deleteBackwardExtremitiesForRoomSQL = "" + - "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" - -type backwardExtremitiesStatements struct { - db *sql.DB - insertBackwardExtremityStmt *sql.Stmt - selectBackwardExtremitiesForRoomStmt *sql.Stmt - deleteBackwardExtremityStmt *sql.Stmt - deleteBackwardExtremitiesForRoomStmt *sql.Stmt +type BackwardExtremityCosmos struct { + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + PrevEventID string `json:"prev_event_id"` } -func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { - s := &backwardExtremitiesStatements{ - db: db, - } - _, err := db.Exec(backwardExtremitiesSchema) +type BackwardExtremityCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + BackwardExtremity BackwardExtremityCosmos `json:"mx_syncapi_backward_extremity"` +} + +// const insertBackwardExtremitySQL = "" + +// "INSERT INTO syncapi_backward_extremities (room_id, event_id, prev_event_id)" + +// " VALUES ($1, $2, $3)" + +// " ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING" + +// "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1" +const selectBackwardExtremitiesForRoomSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_account_data_type.room_id = @x2 " + +// "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const deleteBackwardExtremitySQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_account_data_type.room_id = @x2 " + + "and c.mx_syncapi_account_data_type.prev_event_id = @x3" + +// "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" +const deleteBackwardExtremitiesForRoomSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_account_data_type.room_id = @x2 " + +type backwardExtremitiesStatements struct { + db *SyncServerDatasource + // insertBackwardExtremityStmt *sql.Stmt + selectBackwardExtremitiesForRoomStmt string + deleteBackwardExtremityStmt string + deleteBackwardExtremitiesForRoomStmt string + tableName string +} + +func queryBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, qry string, params map[string]interface{}) ([]BackwardExtremityCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []BackwardExtremityCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err + return response, nil +} + +func deleteBackwardExtremity(s *backwardExtremitiesStatements, ctx context.Context, dbData BackwardExtremityCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { - return nil, err + return err +} + +func NewCosmosDBBackwardsExtremitiesTable(db *SyncServerDatasource) (tables.BackwardsExtremities, error) { + s := &backwardExtremitiesStatements{ + db: db, } + s.selectBackwardExtremitiesForRoomStmt = selectBackwardExtremitiesForRoomSQL + s.deleteBackwardExtremityStmt = deleteBackwardExtremitySQL + s.deleteBackwardExtremitiesForRoomStmt = deleteBackwardExtremitiesForRoomSQL + s.tableName = "backward_extremities" return s, nil } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) - return err + + // "INSERT INTO syncapi_backward_extremities (room_id, event_id, prev_event_id)" + + // " VALUES ($1, $2, $3)" + + // " ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING" + + // _, err = sqlutil.TxStmt(txn, s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // PRIMARY KEY(room_id, event_id, prev_event_id) + docId := fmt.Sprintf("%s_%s_%s", roomID, eventID, prevEventID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := BackwardExtremityCosmos{ + EventID: eventID, + PrevEventID: prevEventID, + RoomID: roomID, + } + + dbData := &BackwardExtremityCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + BackwardExtremity: data, + } + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + + return } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( ctx context.Context, roomID string, ) (bwExtrems map[string][]string, err error) { - rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + + // "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1" + + // rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + rows, err := queryBackwardExtremity(s, ctx, s.selectBackwardExtremitiesForRoomStmt, params) + if err != nil { return } - defer internal.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed") bwExtrems = make(map[string][]string) - for rows.Next() { + for _, item := range rows { var eID string var prevEventID string - if err = rows.Scan(&eID, &prevEventID); err != nil { - return - } + eID = item.BackwardExtremity.EventID + prevEventID = item.BackwardExtremity.PrevEventID bwExtrems[eID] = append(bwExtrems[eID], prevEventID) } - return bwExtrems, rows.Err() + return bwExtrems, err } func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) - return err + + // "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" + + // _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": knownEventID, + } + + rows, err := queryBackwardExtremity(s, ctx, s.deleteBackwardExtremityStmt, params) + if err != nil { + return + } + + for _, item := range rows { + err = deleteBackwardExtremity(s, ctx, item) + } + return } func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( ctx context.Context, txn *sql.Tx, roomID string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) - return err + + // "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + + // _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + rows, err := queryBackwardExtremity(s, ctx, s.deleteBackwardExtremitiesForRoomStmt, params) + if err != nil { + return + } + + for _, item := range rows { + err = deleteBackwardExtremity(s, ctx, item) + } + return } diff --git a/syncapi/storage/cosmosdb/current_room_state_table.go b/syncapi/storage/cosmosdb/current_room_state_table.go index a3a5a4a4a..2cba5a078 100644 --- a/syncapi/storage/cosmosdb/current_room_state_table.go +++ b/syncapi/storage/cosmosdb/current_room_state_table.go @@ -20,108 +20,208 @@ import ( "database/sql" "encoding/json" "fmt" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) -const currentRoomStateSchema = ` --- Stores the current room state for every room. -CREATE TABLE IF NOT EXISTS syncapi_current_room_state ( - room_id TEXT NOT NULL, - event_id TEXT NOT NULL, - type TEXT NOT NULL, - sender TEXT NOT NULL, - contains_url BOOL NOT NULL DEFAULT false, - state_key TEXT NOT NULL, - headered_event_json TEXT NOT NULL, - membership TEXT, - added_at BIGINT, - UNIQUE (room_id, type, state_key) -); --- for event deletion -CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); --- for querying membership states of users --- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; --- for querying state by event IDs -CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); -` +// const currentRoomStateSchema = ` +// -- Stores the current room state for every room. +// CREATE TABLE IF NOT EXISTS syncapi_current_room_state ( +// room_id TEXT NOT NULL, +// event_id TEXT NOT NULL, +// type TEXT NOT NULL, +// sender TEXT NOT NULL, +// contains_url BOOL NOT NULL DEFAULT false, +// state_key TEXT NOT NULL, +// headered_event_json TEXT NOT NULL, +// membership TEXT, +// added_at BIGINT, +// UNIQUE (room_id, type, state_key) +// ); +// -- for event deletion +// CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); +// -- for querying membership states of users +// -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; +// -- for querying state by event IDs +// CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); +// ` -const upsertRoomStateSQL = "" + - "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + - " ON CONFLICT (room_id, type, state_key)" + - " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" +type CurrentRoomStateCosmos struct { + RoomID string `json:"room_id"` + EventID string `json:"event_id"` + Type string `json:"type"` + Sender string `json:"sender"` + ContainsUrl bool `json:"contains_url"` + StateKey string `json:"state_key"` + HeaderedEventJSON []byte `json:"headered_event_json"` + Membership string `json:"membership"` + AddedAt int64 `json:"added_at"` +} +type CurrentRoomStateCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + CurrentRoomState CurrentRoomStateCosmos `json:"mx_syncapi_current_room_state"` +} + +// const upsertRoomStateSQL = "" + +// "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" + +// " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + +// " ON CONFLICT (room_id, type, state_key)" + +// " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" + +// "DELETE FROM syncapi_current_room_state WHERE event_id = $1" const deleteRoomStateByEventIDSQL = "" + - "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_current_room_state.event_id = @x2 " +// TODO: Check the SQL is correct here +// "DELETE FROM syncapi_current_room_state WHERE event_id = $1" const DeleteRoomStateForRoomSQL = "" + - "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_current_room_state.room_id = @x2 " +// "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" const selectRoomIDsWithMembershipSQL = "" + - "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" + "select distinct c.mx_syncapi_current_room_state.room_id from c where c._cn = @x1 " + + "and c.mx_syncapi_current_room_state.type = \"m.room.member\" " + + "and c.mx_syncapi_current_room_state.state_key = @x2 " + + "and c.mx_syncapi_current_room_state.membership = @x3 " +// "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +// // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter const selectCurrentStateSQL = "" + - "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" - // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter + "select top @x3 * from c where c._cn = @x1 " + + "and c.mx_syncapi_current_room_state.room_id = @x2 " + // // WHEN, ORDER BY (and LIMIT) will be added by prepareWithFilter +// "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" const selectJoinedUsersSQL = "" + - "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_current_room_state.type = \"m.room.member\" " + + "and c.mx_syncapi_current_room_state.membership = \"join\" " -const selectStateEventSQL = "" + - "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" +// const selectStateEventSQL = "" + +// "SELECT headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3" +// "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + +// " FROM syncapi_current_room_state WHERE event_id IN ($1)" const selectEventsWithEventIDsSQL = "" + // TODO: The session_id and transaction_id blanks are here because otherwise // the rowsToStreamEvents expects there to be exactly six columns. We need to // figure out if these really need to be in the DB, and if so, we need a // better permanent fix for this. - neilalexander, 2 Jan 2020 - "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + - " FROM syncapi_current_room_state WHERE event_id IN ($1)" + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_syncapi_current_room_state.event_id) " type currentRoomStateStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - upsertRoomStateStmt *sql.Stmt - deleteRoomStateByEventIDStmt *sql.Stmt - DeleteRoomStateForRoomStmt *sql.Stmt - selectRoomIDsWithMembershipStmt *sql.Stmt - selectJoinedUsersStmt *sql.Stmt - selectStateEventStmt *sql.Stmt + db *SyncServerDatasource + streamIDStatements *streamIDStatements + // upsertRoomStateStmt *sql.Stmt + deleteRoomStateByEventIDStmt string + DeleteRoomStateForRoomStmt string + selectRoomIDsWithMembershipStmt string + selectJoinedUsersStmt string + // selectStateEventStmt *sql.Stmt + tableName string + jsonPropertyName string } -func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { +func queryCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CurrentRoomStateCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []CurrentRoomStateCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryCurrentRoomStateDistinct(s *currentRoomStateStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CurrentRoomStateCosmos, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []CurrentRoomStateCosmos + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func getEvent(s *currentRoomStateStatements, ctx context.Context, pk string, docId string) (*CurrentRoomStateCosmosData, error) { + response := CurrentRoomStateCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func deleteCurrentRoomState(s *currentRoomStateStatements, ctx context.Context, dbData CurrentRoomStateCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewCosmosDBCurrentRoomStateTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ db: db, streamIDStatements: streamID, } - _, err := db.Exec(currentRoomStateSchema) - if err != nil { - return nil, err - } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.DeleteRoomStateForRoomStmt, err = db.Prepare(DeleteRoomStateForRoomSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return nil, err - } - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } + s.deleteRoomStateByEventIDStmt = deleteRoomStateByEventIDSQL + s.DeleteRoomStateForRoomStmt = DeleteRoomStateForRoomSQL + s.selectRoomIDsWithMembershipStmt = selectRoomIDsWithMembershipSQL + s.selectJoinedUsersStmt = selectJoinedUsersSQL + s.tableName = "current_room_states" + s.jsonPropertyName = "mx_syncapi_current_room_state" return s, nil } @@ -129,19 +229,27 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t func (s *currentRoomStateStatements) SelectJoinedUsers( ctx context.Context, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) + + // "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" + + // rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + rows, err := queryCurrentRoomState(s, ctx, s.selectJoinedUsersStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsers: rows.close() failed") result := make(map[string][]string) - for rows.Next() { + for _, item := range rows { var roomID string var userID string - if err := rows.Scan(&roomID, &userID); err != nil { - return nil, err - } + roomID = item.CurrentRoomState.RoomID + userID = item.CurrentRoomState.StateKey //StateKey and Not UserID - See the SQL above users := result[roomID] users = append(users, userID) result[roomID] = users @@ -156,19 +264,28 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( userID string, membership string, // nolint: unparam ) ([]string, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) - rows, err := stmt.QueryContext(ctx, userID, membership) + + // "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" + + // stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithMembershipStmt) + // rows, err := stmt.QueryContext(ctx, userID, membership) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": membership, + } + + rows, err := queryCurrentRoomStateDistinct(s, ctx, s.selectRoomIDsWithMembershipStmt, params) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithMembership: rows.close() failed") var result []string - for rows.Next() { + for _, item := range rows { var roomID string - if err := rows.Scan(&roomID); err != nil { - return nil, err - } + roomID = item.RoomID result = append(result, roomID) } return result, nil @@ -180,41 +297,74 @@ func (s *currentRoomStateStatements) SelectCurrentState( stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { - stmt, params, err := prepareWithFilters( - s.db, txn, selectCurrentStateSQL, - []interface{}{ - roomID, - }, + + // "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + // // WHEN, ORDER BY and LIMIT will be added by prepareWithFilter + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": stateFilter.Limit, + } + + stmt, params := prepareWithFilters( + s.jsonPropertyName, selectCurrentStateSQL, params, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, excludeEventIDs, stateFilter.Limit, FilterOrderNone, ) - if err != nil { - return nil, fmt.Errorf("s.prepareWithFilters: %w", err) - } + rows, err := queryCurrentRoomState(s, ctx, stmt, params) - rows, err := stmt.QueryContext(ctx, params...) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectCurrentState: rows.close() failed") - return rowsToEvents(rows) + return rowsToEvents(&rows) } func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - _, err := stmt.ExecContext(ctx, eventID) + + // "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + // stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventID, + } + + rows, err := queryCurrentRoomState(s, ctx, s.deleteRoomStateByEventIDStmt, params) + + for _, item := range rows { + err = deleteCurrentRoomState(s, ctx, item) + } + return err } func (s *currentRoomStateStatements) DeleteRoomStateForRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt) - _, err := stmt.ExecContext(ctx, roomID) + + // TODO: Check the SQL is correct here + // "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + + // stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + rows, err := queryCurrentRoomState(s, ctx, s.DeleteRoomStateForRoomStmt, params) + + for _, item := range rows { + err = deleteCurrentRoomState(s, ctx, item) + } + return err } @@ -235,20 +385,73 @@ func (s *currentRoomStateStatements) UpsertRoomState( return err } + // "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" + + // " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + + // " ON CONFLICT (room_id, type, state_key)" + + // " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" + + // TODO: Not sure how we can enfore these extra unique indexes + // CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); + // -- for querying membership states of users + // -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; + // -- for querying state by event IDs + // CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); + // upsert state event - stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) - _, err = stmt.ExecContext( + // stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) + // _, err = stmt.ExecContext( + // ctx, + // event.RoomID(), + // event.EventID(), + // event.Type(), + // event.Sender(), + // containsURL, + // *event.StateKey(), + // headeredJSON, + // membership, + // addedAt, + // ) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // " ON CONFLICT (room_id, type, state_key)" + + docId := fmt.Sprintf("%s_%s_%s", event.RoomID(), event.Type(), *event.StateKey()) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + membershipData := "" + if membership != nil { + membershipData = *membership + } + + data := CurrentRoomStateCosmos{ + RoomID: event.RoomID(), + EventID: event.EventID(), + Type: event.Type(), + Sender: event.Sender(), + ContainsUrl: containsURL, + StateKey: *event.StateKey(), + HeaderedEventJSON: headeredJSON, + Membership: membershipData, + AddedAt: int64(addedAt), + } + + dbData := &CurrentRoomStateCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + CurrentRoomState: data, + } + + // _, err = sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType, pos) + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - event.RoomID(), - event.EventID(), - event.Type(), - event.Sender(), - containsURL, - *event.StateKey(), - headeredJSON, - membership, - addedAt, - ) + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return err } @@ -262,22 +465,33 @@ func minOfInts(a, b int) int { func (s *currentRoomStateStatements) SelectEventsWithEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { - iEventIDs := make([]interface{}, len(eventIDs)) - for k, v := range eventIDs { - iEventIDs[k] = v - } + // iEventIDs := make([]interface{}, len(eventIDs)) + // for k, v := range eventIDs { + // iEventIDs[k] = v + // } res := make([]types.StreamEvent, 0, len(eventIDs)) var start int for start < len(eventIDs) { n := minOfInts(len(eventIDs)-start, 999) - query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1) - rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) + // "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + + // " FROM syncapi_current_room_state WHERE event_id IN ($1)" + + // query := strings.Replace(selectEventsWithEventIDsSQL, "@x2", sql.QueryVariadic(n), 1) + + // rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventIDs, + } + + rows, err := queryCurrentRoomState(s, ctx, s.DeleteRoomStateForRoomStmt, params) + if err != nil { return nil, err } start = start + n - events, err := rowsToStreamEvents(rows) - internal.CloseAndLogIfError(ctx, rows, "selectEventsWithEventIDs: rows.close() failed") + events, err := rowsToStreamEventsFromCurrentRoomState(&rows) if err != nil { return nil, err } @@ -286,14 +500,58 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( return res, nil } -func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { - result := []*gomatrixserverlib.HeaderedEvent{} - for rows.Next() { - var eventID string - var eventBytes []byte - if err := rows.Scan(&eventID, &eventBytes); err != nil { +// Copied from output_room_events_table +func rowsToStreamEventsFromCurrentRoomState(rows *[]CurrentRoomStateCosmosData) ([]types.StreamEvent, error) { + var result []types.StreamEvent + for _, item := range *rows { + var ( + eventID string + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + // Not required for this call, see output_room_events_table + // sessionID *int64 + // txnID *string + // transactionID *api.TransactionID + ) + // if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { + // return nil, err + // } + // Taken from the SQL above + eventID = item.CurrentRoomState.EventID + streamPos = types.StreamPosition(item.CurrentRoomState.AddedAt) + + // TODO: Handle redacted events + var ev gomatrixserverlib.HeaderedEvent + if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { return nil, err } + + // Always null for this use-case + // if sessionID != nil && txnID != nil { + // transactionID = &api.TransactionID{ + // SessionID: *sessionID, + // TransactionID: *txnID, + // } + // } + + result = append(result, types.StreamEvent{ + HeaderedEvent: &ev, + StreamPosition: streamPos, + TransactionID: nil, + ExcludeFromSync: excludeFromSync, + }) + } + return result, nil +} + +func rowsToEvents(rows *[]CurrentRoomStateCosmosData) ([]*gomatrixserverlib.HeaderedEvent, error) { + result := []*gomatrixserverlib.HeaderedEvent{} + for _, item := range *rows { + var eventID string + var eventBytes []byte + eventID = item.CurrentRoomState.EventID + eventBytes = item.CurrentRoomState.HeaderedEventJSON // TODO: Handle redacted events var ev gomatrixserverlib.HeaderedEvent if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { @@ -307,15 +565,25 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { func (s *currentRoomStateStatements) SelectStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { - stmt := s.selectStateEventStmt + + // stmt := s.selectStateEventStmt var res []byte - err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) - if err == sql.ErrNoRows { + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + // " ON CONFLICT (room_id, type, state_key)" + + docId := fmt.Sprintf("%s_%s_%s", roomID, evType, stateKey) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + var response, err = getEvent(s, ctx, pk, cosmosDocId) + + // err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) + if err == cosmosdbutil.ErrNoRows { return nil, nil } if err != nil { return nil, err } + res = response.CurrentRoomState.HeaderedEventJSON var ev gomatrixserverlib.HeaderedEvent if err = json.Unmarshal(res, &ev); err != nil { return nil, err diff --git a/syncapi/storage/cosmosdb/deltas/20201211125500_sequences.go b/syncapi/storage/cosmosdb/deltas/20201211125500_sequences.go new file mode 100644 index 000000000..8e7ebff86 --- /dev/null +++ b/syncapi/storage/cosmosdb/deltas/20201211125500_sequences.go @@ -0,0 +1,59 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpFixSequences, DownFixSequences) + goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn) +} + +func LoadFixSequences(m *sqlutil.Migrations) { + m.AddMigration(UpFixSequences, DownFixSequences) +} + +func UpFixSequences(tx *sql.Tx) error { + _, err := tx.Exec(` + -- We need to delete all of the existing receipts because the indexes + -- will be wrong, and we'll get primary key violations if we try to + -- reuse existing stream IDs from a different sequence. + DELETE FROM syncapi_receipts; + UPDATE syncapi_stream_id SET stream_id=1 WHERE stream_name="receipt"; + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownFixSequences(tx *sql.Tx) error { + _, err := tx.Exec(` + -- We need to delete all of the existing receipts because the indexes + -- will be wrong, and we'll get primary key violations if we try to + -- reuse existing stream IDs from a different sequence. + DELETE FROM syncapi_receipts; + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/syncapi/storage/cosmosdb/deltas/20210112130000_sendtodevice_sentcolumn.go b/syncapi/storage/cosmosdb/deltas/20210112130000_sendtodevice_sentcolumn.go new file mode 100644 index 000000000..e0c514102 --- /dev/null +++ b/syncapi/storage/cosmosdb/deltas/20210112130000_sendtodevice_sentcolumn.go @@ -0,0 +1,67 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) { + m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn) +} + +func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error { + _, err := tx.Exec(` + CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content); + INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device; + DROP TABLE syncapi_send_to_device; + CREATE TABLE syncapi_send_to_device( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL + ); + INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup; + DROP TABLE syncapi_send_to_device_backup; + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error { + _, err := tx.Exec(` + CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content); + INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device; + DROP TABLE syncapi_send_to_device; + CREATE TABLE syncapi_send_to_device( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL, + sent_by_token TEXT + ); + INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup; + DROP TABLE syncapi_send_to_device_backup; + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} diff --git a/syncapi/storage/cosmosdb/filter_table.go b/syncapi/storage/cosmosdb/filter_table.go index 9447ddd82..f73d03587 100644 --- a/syncapi/storage/cosmosdb/filter_table.go +++ b/syncapi/storage/cosmosdb/filter_table.go @@ -16,80 +16,147 @@ package cosmosdb import ( "context" - "database/sql" "encoding/json" "fmt" + "time" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) -const filterSchema = ` --- Stores data about filters -CREATE TABLE IF NOT EXISTS syncapi_filter ( - -- The filter - filter TEXT NOT NULL, - -- The ID - id INTEGER PRIMARY KEY AUTOINCREMENT, - -- The localpart of the Matrix user ID associated to this filter - localpart TEXT NOT NULL, +// const filterSchema = ` +// -- Stores data about filters +// CREATE TABLE IF NOT EXISTS syncapi_filter ( +// -- The filter +// filter TEXT NOT NULL, +// -- The ID +// id INTEGER PRIMARY KEY AUTOINCREMENT, +// -- The localpart of the Matrix user ID associated to this filter +// localpart TEXT NOT NULL, - UNIQUE (id, localpart) -); +// UNIQUE (id, localpart) +// ); -CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart); -` +// CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart); +// ` -const selectFilterSQL = "" + - "SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2" - -const selectFilterIDByContentSQL = "" + - "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2" - -const insertFilterSQL = "" + - "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)" - -type filterStatements struct { - db *sql.DB - selectFilterStmt *sql.Stmt - selectFilterIDByContentStmt *sql.Stmt - insertFilterStmt *sql.Stmt +type FilterCosmos struct { + ID int64 `json:"id"` + Filter []byte `json:"filter"` + Localpart string `json:"localpart"` } -func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { - _, err := db.Exec(filterSchema) +type FilterCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Filter FilterCosmos `json:"mx_syncapi_filter"` +} + +// const selectFilterSQL = "" + +// "SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2" + +// "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2" +const selectFilterIDByContentSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_filter.localpart = @x2 " + + "and c.mx_syncapi_filter.filter = @x3 " + +// const insertFilterSQL = "" + +// "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)" + +type filterStatements struct { + db *SyncServerDatasource + // selectFilterStmt *sql.Stmt + selectFilterIDByContentStmt string + // insertFilterStmt *sql.Stmt + tableName string +} + +func queryFilter(s *filterStatements, ctx context.Context, qry string, params map[string]interface{}) ([]FilterCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []FilterCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } + + if len(response) == 0 { + return nil, cosmosdbutil.ErrNoRows + } + + return response, nil +} + +func getFilter(s *filterStatements, ctx context.Context, pk string, docId string) (*FilterCosmosData, error) { + response := FilterCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, nil + } + + return &response, err +} + +func NewCosmosDBFilterTable(db *SyncServerDatasource) (tables.Filter, error) { s := &filterStatements{ db: db, } - if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { - return nil, err - } - if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { - return nil, err - } - if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { - return nil, err - } + s.selectFilterIDByContentStmt = selectFilterIDByContentSQL + s.tableName = "filters" return s, nil } func (s *filterStatements) SelectFilter( ctx context.Context, localpart string, filterID string, ) (*gomatrixserverlib.Filter, error) { + + // "SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2" + // Retrieve filter from database (stored as canonical JSON) var filterData []byte - err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + // err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (id, localpart) + docId := fmt.Sprintf("%s_%s", localpart, filterID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response, err = getFilter(s, ctx, pk, cosmosDocId) + if err != nil { return nil, err } // Unmarshal JSON into Filter struct filter := gomatrixserverlib.DefaultFilter() - if err = json.Unmarshal(filterData, &filter); err != nil { - return nil, err + if response != nil { + filterData = response.Filter.Filter + if err = json.Unmarshal(filterData, &filter); err != nil { + return nil, err + } } return &filter, nil } @@ -97,6 +164,9 @@ func (s *filterStatements) SelectFilter( func (s *filterStatements) InsertFilter( ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { + + // "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)" + var existingFilterID string // Serialise json @@ -116,25 +186,73 @@ func (s *filterStatements) InsertFilter( // This can result in a race condition when two clients try to insert the // same filter and localpart at the same time, however this is not a // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) - if err != nil && err != sql.ErrNoRows { + // err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + // localpart, filterJSON).Scan(&existingFilterID) + + // TODO: See if we can avoid the search by Content []byte + // "SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": localpart, + "@x3": filterJSON, + } + + response, err := queryFilter(s, ctx, s.selectFilterIDByContentStmt, params) + + if err != nil && err != cosmosdbutil.ErrNoRows { return "", err } + + if response != nil { + existingFilterID = fmt.Sprintf("%d", response[0].Filter.ID) + } // If it does, return the existing ID if existingFilterID != "" { return existingFilterID, nil } // Otherwise insert the filter and return the new ID - res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) - if err != nil { - return "", err - } - rowid, err := res.LastInsertId() + // res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + + // id INTEGER PRIMARY KEY AUTOINCREMENT, + seqID, seqErr := GetNextFilterID(s, ctx) + if seqErr != nil { + return "", seqErr + } + + data := FilterCosmos{ + ID: seqID, + Localpart: localpart, + Filter: filterJSON, + } + + // UNIQUE (id, localpart) + docId := fmt.Sprintf("%s_%d", localpart, seqID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + var dbData = FilterCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Filter: data, + } + + var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + optionsCreate) + if err != nil { return "", err } + rowid := seqID filterID = fmt.Sprintf("%d", rowid) return } diff --git a/syncapi/storage/cosmosdb/filter_table_id_seq.go b/syncapi/storage/cosmosdb/filter_table_id_seq.go new file mode 100644 index 000000000..b3d674cc8 --- /dev/null +++ b/syncapi/storage/cosmosdb/filter_table_id_seq.go @@ -0,0 +1,12 @@ +package cosmosdb + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" +) + +func GetNextFilterID(s *filterStatements, ctx context.Context) (int64, error) { + const docId = "id_seq" + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) +} diff --git a/syncapi/storage/cosmosdb/filtering.go b/syncapi/storage/cosmosdb/filtering.go index 62d2434f6..6d5acda4b 100644 --- a/syncapi/storage/cosmosdb/filtering.go +++ b/syncapi/storage/cosmosdb/filtering.go @@ -1,10 +1,7 @@ package cosmosdb import ( - "database/sql" "fmt" - - "github.com/matrix-org/dendrite/internal/sqlutil" ) type FilterOrder int @@ -15,6 +12,10 @@ const ( FilterOrderDesc ) +func getParamName(offset int) string { + return fmt.Sprintf("@x%d", offset) +} + // prepareWithFilters returns a prepared statement with the // relevant filters included. It also includes an []interface{} // list of all the relevant parameters to pass straight to @@ -24,59 +25,54 @@ const ( // and it's easier just to have the caller extract the relevant // parts. func prepareWithFilters( - db *sql.DB, txn *sql.Tx, query string, params []interface{}, + collectionName string, query string, params map[string]interface{}, senders, notsenders, types, nottypes []string, excludeEventIDs []string, limit int, order FilterOrder, -) (*sql.Stmt, []interface{}, error) { +) (sql string, paramsResult map[string]interface{}) { offset := len(params) - if count := len(senders); count > 0 { - query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range senders { - params, offset = append(params, v), offset+1 - } + sql = query + paramsResult = params + // "and (@x4 = null OR ARRAY_CONTAINS(@x4, c.mx_syncapi_current_room_state.sender)) " + + if len(senders) > 0 { + offset++ + paramName := getParamName(offset) + sql += fmt.Sprintf("and ARRAY_CONTAINS(%s, c.%s.sender) ", paramName, collectionName) + paramsResult[paramName] = senders } - if count := len(notsenders); count > 0 { - query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range notsenders { - params, offset = append(params, v), offset+1 - } + // "and (@x5 = null OR NOT ARRAY_CONTAINS(@x5, c.mx_syncapi_current_room_state.sender)) " + + if len(notsenders) > 0 { + offset++ + paramName := getParamName(offset) + sql += fmt.Sprintf("and NOT ARRAY_CONTAINS(%s, c.%s.sender) ", paramName, collectionName) + paramsResult[getParamName(offset)] = notsenders } - if count := len(types); count > 0 { - query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range types { - params, offset = append(params, v), offset+1 - } + // "and (@x6 = null OR ARRAY_CONTAINS(@x6, c.mx_syncapi_current_room_state.type)) " + + if len(types) > 0 { + offset++ + paramName := getParamName(offset) + sql += fmt.Sprintf("and ARRAY_CONTAINS(%s, c.%s.type) ", paramName, collectionName) + paramsResult[paramName] = types } - if count := len(nottypes); count > 0 { - query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range nottypes { - params, offset = append(params, v), offset+1 - } + // "and (@x7 = null OR NOT ARRAY_CONTAINS(@x7, c.mx_syncapi_current_room_state.type)) " + + if len(nottypes) > 0 { + offset++ + paramName := getParamName(offset) + sql += fmt.Sprintf("and NOT ARRAY_CONTAINS(%s, c.%s.type) ", paramName, collectionName) + paramsResult[getParamName(offset)] = nottypes } - if count := len(excludeEventIDs); count > 0 { - query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset) - for _, v := range excludeEventIDs { - params, offset = append(params, v), offset+1 - } + // "and (NOT ARRAY_CONTAINS(@x9, c.mx_syncapi_current_room_state.event_id)) " + if len(excludeEventIDs) > 0 { + offset++ + paramName := getParamName(offset) + sql += fmt.Sprintf("and NOT ARRAY_CONTAINS(%s, c.%s.event_id) ", paramName, collectionName) + paramsResult[getParamName(offset)] = excludeEventIDs } switch order { case FilterOrderAsc: - query += " ORDER BY id ASC" + sql += fmt.Sprintf("order by c.%s.event_id asc ", collectionName) case FilterOrderDesc: - query += " ORDER BY id DESC" + sql += fmt.Sprintf("order by c.%s.event_id desc ", collectionName) } - query += fmt.Sprintf(" LIMIT $%d", offset+1) - params = append(params, limit) - - var stmt *sql.Stmt - var err error - if txn != nil { - stmt, err = txn.Prepare(query) - } else { - stmt, err = db.Prepare(query) - } - if err != nil { - return nil, nil, fmt.Errorf("s.db.Prepare: %w", err) - } - return stmt, params, nil + // query += fmt.Sprintf(" LIMIT $%d", offset+1) + return } diff --git a/syncapi/storage/cosmosdb/invites_table.go b/syncapi/storage/cosmosdb/invites_table.go index ea5d0bd85..6b173d590 100644 --- a/syncapi/storage/cosmosdb/invites_table.go +++ b/syncapi/storage/cosmosdb/invites_table.go @@ -19,80 +19,179 @@ import ( "context" "database/sql" "encoding/json" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) -const inviteEventsSchema = ` -CREATE TABLE IF NOT EXISTS syncapi_invite_events ( - id INTEGER PRIMARY KEY, - event_id TEXT NOT NULL, - room_id TEXT NOT NULL, - target_user_id TEXT NOT NULL, - headered_event_json TEXT NOT NULL, - deleted BOOL NOT NULL -); +// const inviteEventsSchema = ` +// CREATE TABLE IF NOT EXISTS syncapi_invite_events ( +// id INTEGER PRIMARY KEY, +// event_id TEXT NOT NULL, +// room_id TEXT NOT NULL, +// target_user_id TEXT NOT NULL, +// headered_event_json TEXT NOT NULL, +// deleted BOOL NOT NULL +// ); -CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id); -CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id); -` +// CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id); +// CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id); +// ` -const insertInviteEventSQL = "" + - "INSERT INTO syncapi_invite_events" + - " (id, room_id, event_id, target_user_id, headered_event_json, deleted)" + - " VALUES ($1, $2, $3, $4, $5, false)" - -const deleteInviteEventSQL = "" + - "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2" - -const selectInviteEventsInRangeSQL = "" + - "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + - " WHERE target_user_id = $1 AND id > $2 AND id <= $3" + - " ORDER BY id DESC" - -const selectMaxInviteIDSQL = "" + - "SELECT MAX(id) FROM syncapi_invite_events" - -type inviteEventsStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - insertInviteEventStmt *sql.Stmt - selectInviteEventsInRangeStmt *sql.Stmt - deleteInviteEventStmt *sql.Stmt - selectMaxInviteIDStmt *sql.Stmt +type InviteEventCosmos struct { + ID int64 `json:"id"` + EventID string `json:"event_id"` + RoomID string `json:"room_id"` + TargetUserID string `json:"target_user_id"` + HeaderedEventJSON []byte `json:"headered_event_json"` + Deleted bool `json:"deleted"` } -func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { +type InviteEventCosmosMaxNumber struct { + Max int64 `json:"number"` +} + +type InviteEventCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + InviteEvent InviteEventCosmos `json:"mx_syncapi_invite_event"` +} + +// const insertInviteEventSQL = "" + +// "INSERT INTO syncapi_invite_events" + +// " (id, room_id, event_id, target_user_id, headered_event_json, deleted)" + +// " VALUES ($1, $2, $3, $4, $5, false)" + +// "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2" +const deleteInviteEventSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_invite_event.event_id = @x2 " + +// "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + +// " WHERE target_user_id = $1 AND id > $2 AND id <= $3" + +// " ORDER BY id DESC" +const selectInviteEventsInRangeSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_invite_event.target_user_id = @x2 " + + "and c.mx_syncapi_invite_event.id > @x3 " + + "and c.mx_syncapi_invite_event.id <= @x4 " + + "order by c.mx_syncapi_invite_event.id desc " + +// "SELECT MAX(id) FROM syncapi_invite_events" +const selectMaxInviteIDSQL = "" + + "select max(c.mx_syncapi_invite_event.id) from c where c._cn = @x1 " + +type inviteEventsStatements struct { + db *SyncServerDatasource + streamIDStatements *streamIDStatements + // insertInviteEventStmt *sql.Stmt + selectInviteEventsInRangeStmt string + deleteInviteEventStmt string + selectMaxInviteIDStmt string + tableName string +} + +func queryInviteEvent(s *inviteEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteEventCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []InviteEventCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryInviteEventMaxNumber(s *inviteEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]InviteEventCosmosMaxNumber, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []InviteEventCosmosMaxNumber + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, nil + } + + return response, nil +} + +func getInviteEvent(s *inviteEventsStatements, ctx context.Context, pk string, docId string) (*InviteEventCosmosData, error) { + response := InviteEventCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func setInviteEvent(s *inviteEventsStatements, ctx context.Context, invite InviteEventCosmosData) (*InviteEventCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(invite.Pk, invite.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + invite.Id, + &invite, + optionsReplace) + return &invite, ex +} + +func NewCosmosDBInvitesTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ db: db, streamIDStatements: streamID, } - _, err := db.Exec(inviteEventsSchema) - if err != nil { - return nil, err - } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } + s.selectInviteEventsInRangeStmt = selectInviteEventsInRangeSQL + s.deleteInviteEventStmt = deleteInviteEventSQL + s.selectMaxInviteIDStmt = selectMaxInviteIDSQL + s.tableName = "invite_events" return s, nil } func (s *inviteEventsStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { + + // "INSERT INTO syncapi_invite_events" + + // " (id, room_id, event_id, target_user_id, headered_event_json, deleted)" + + // " VALUES ($1, $2, $3, $4, $5, false)" + streamPos, err = s.streamIDStatements.nextInviteID(ctx, txn) if err != nil { return @@ -104,15 +203,45 @@ func (s *inviteEventsStatements) InsertInviteEvent( return } - stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) - _, err = stmt.ExecContext( + // stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + // _, err = stmt.ExecContext( + // ctx, + // streamPos, + // inviteEvent.RoomID(), + // inviteEvent.EventID(), + // *inviteEvent.StateKey(), + // headeredJSON, + // ) + data := InviteEventCosmos{ + ID: int64(streamPos), + RoomID: inviteEvent.RoomID(), + EventID: inviteEvent.EventID(), + TargetUserID: *inviteEvent.StateKey(), + HeaderedEventJSON: headeredJSON, + } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + // id INTEGER PRIMARY KEY, + docId := fmt.Sprintf("%d", streamPos) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + + var dbData = InviteEventCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + InviteEvent: data, + } + + var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - streamPos, - inviteEvent.RoomID(), - inviteEvent.EventID(), - *inviteEvent.StateKey(), - headeredJSON, - ) + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + optionsCreate) + return } @@ -123,8 +252,23 @@ func (s *inviteEventsStatements) DeleteInviteEvent( if err != nil { return streamPos, err } - stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt) - _, err = stmt.ExecContext(ctx, streamPos, inviteEventID) + + // "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2" + + // stmt := sqlutil.TxStmt(txn, s.deleteInviteEventStmt) + // _, err = stmt.ExecContext(ctx, streamPos, inviteEventID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": inviteEventID, + } + response, err := queryInviteEvent(s, ctx, s.deleteInviteEventStmt, params) + + for _, item := range response { + item.InviteEvent.Deleted = true + item.InviteEvent.ID = int64(streamPos) + setInviteEvent(s, ctx, item) + } return streamPos, err } @@ -133,23 +277,39 @@ func (s *inviteEventsStatements) DeleteInviteEvent( func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, ) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { - stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) - rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) + + // "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + + // " WHERE target_user_id = $1 AND id > $2 AND id <= $3" + + // " ORDER BY id DESC" + + // stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) + // rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": targetUserID, + "@x3": r.Low(), + "@x4": r.High(), + } + rows, err := queryInviteEvent(s, ctx, s.selectInviteEventsInRangeStmt, params) + if err != nil { return nil, nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") result := map[string]*gomatrixserverlib.HeaderedEvent{} retired := map[string]*gomatrixserverlib.HeaderedEvent{} - for rows.Next() { + for _, item := range rows { var ( roomID string eventJSON []byte deleted bool ) - if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil { - return nil, nil, err - } + roomID = item.InviteEvent.RoomID + eventJSON = item.InviteEvent.HeaderedEventJSON + deleted = item.InviteEvent.Deleted + // if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil { + // return nil, nil, err + // } // if we have seen this room before, it has a higher stream position and hence takes priority // because the query is ORDER BY id DESC so drop them @@ -176,8 +336,21 @@ func (s *inviteEventsStatements) SelectMaxInviteID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt) - err = stmt.QueryRowContext(ctx).Scan(&nullableID) + + // "SELECT MAX(id) FROM syncapi_invite_events" + + // stmt := sqlutil.TxStmt(txn, s.selectMaxInviteIDStmt) + // err = stmt.QueryRowContext(ctx).Scan(&nullableID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + response, err := queryInviteEventMaxNumber(s, ctx, s.selectMaxInviteIDStmt, params) + + if response != nil { + nullableID.Int64 = response[0].Max + } + if nullableID.Valid { id = nullableID.Int64 } diff --git a/syncapi/storage/cosmosdb/memberships_table.go b/syncapi/storage/cosmosdb/memberships_table.go index 9b660509b..104c3c365 100644 --- a/syncapi/storage/cosmosdb/memberships_table.go +++ b/syncapi/storage/cosmosdb/memberships_table.go @@ -18,9 +18,10 @@ import ( "context" "database/sql" "fmt" - "strings" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -32,53 +33,92 @@ import ( // a room, either by choice or otherwise. This is important for // building history visibility. -const membershipsSchema = ` -CREATE TABLE IF NOT EXISTS syncapi_memberships ( - -- The 'room_id' key for the state event. - room_id TEXT NOT NULL, - -- The state event ID - user_id TEXT NOT NULL, - -- The status of the membership - membership TEXT NOT NULL, - -- The event ID that last changed the membership - event_id TEXT NOT NULL, - -- The stream position of the change - stream_pos BIGINT NOT NULL, - -- The topological position of the change in the room - topological_pos BIGINT NOT NULL, - -- Unique index - UNIQUE (room_id, user_id, membership) -); -` +// const membershipsSchema = ` +// CREATE TABLE IF NOT EXISTS syncapi_memberships ( +// -- The 'room_id' key for the state event. +// room_id TEXT NOT NULL, +// -- The state event ID +// user_id TEXT NOT NULL, +// -- The status of the membership +// membership TEXT NOT NULL, +// -- The event ID that last changed the membership +// event_id TEXT NOT NULL, +// -- The stream position of the change +// stream_pos BIGINT NOT NULL, +// -- The topological position of the change in the room +// topological_pos BIGINT NOT NULL, +// -- Unique index +// UNIQUE (room_id, user_id, membership) +// ); +// ` -const upsertMembershipSQL = "" + - "INSERT INTO syncapi_memberships (room_id, user_id, membership, event_id, stream_pos, topological_pos)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT (room_id, user_id, membership)" + - " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" - -const selectMembershipSQL = "" + - "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + - " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" + - " ORDER BY stream_pos DESC" + - " LIMIT 1" - -type membershipsStatements struct { - db *sql.DB - upsertMembershipStmt *sql.Stmt +type MembershipCosmos struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + Membership string `json:"membership"` + EventID string `json:"event_id"` + StreamPos int64 `json:"stream_pos"` + TopologicalPos int64 `json:"topological_pos"` } -func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { - s := &membershipsStatements{ - db: db, - } - _, err := db.Exec(membershipsSchema) +type MembershipCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Membership MembershipCosmos `json:"mx_syncapi_membership"` +} + +// const upsertMembershipSQL = "" + +// "INSERT INTO syncapi_memberships (room_id, user_id, membership, event_id, stream_pos, topological_pos)" + +// " VALUES ($1, $2, $3, $4, $5, $6)" + +// " ON CONFLICT (room_id, user_id, membership)" + +// " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" + +// "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + +// " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" + +// " ORDER BY stream_pos DESC" + +// " LIMIT 1" +const selectMembershipSQL = "" + + "select top 1 * from c where c._cn = @x1 " + + "and c.mx_syncapi_membership.room_id = @x2 " + + "and c.mx_syncapi_membership.user_id = @x3 " + + "and ARRAY_CONTAINS(@x4, c.mx_syncapi_membership.membership) " + + "order by c.mx_syncapi_membership.stream_pos desc " + +type membershipsStatements struct { + db *SyncServerDatasource + // upsertMembershipStmt *sql.Stmt + tableName string +} + +func queryMembership(s *membershipsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]MembershipCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []MembershipCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { - return nil, err + return response, nil +} + +func NewCosmosDBMembershipsTable(db *SyncServerDatasource) (tables.Memberships, error) { + s := &membershipsStatements{ + db: db, } + s.tableName = "memberships" return s, nil } @@ -90,30 +130,86 @@ func (s *membershipsStatements) UpsertMembership( if err != nil { return fmt.Errorf("event.Membership: %w", err) } - _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( + + // "INSERT INTO syncapi_memberships (room_id, user_id, membership, event_id, stream_pos, topological_pos)" + + // " VALUES ($1, $2, $3, $4, $5, $6)" + + // " ON CONFLICT (room_id, user_id, membership)" + + // " DO UPDATE SET event_id = $4, stream_pos = $5, topological_pos = $6" + + // _, err = sqlutil.TxStmt(txn, s.upsertMembershipStmt).ExecContext( + // ctx, + // event.RoomID(), + // *event.StateKey(), + // membership, + // event.EventID(), + // streamPos, + // topologicalPos, + // ) + + data := MembershipCosmos{ + RoomID: event.RoomID(), + UserID: *event.StateKey(), + Membership: membership, + EventID: event.EventID(), + StreamPos: int64(streamPos), + TopologicalPos: int64(topologicalPos), + } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + // UNIQUE (room_id, user_id, membership) + docId := fmt.Sprintf("%s_%s_%s", event.RoomID(), *event.StateKey(), membership) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + + var dbData = MembershipCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Membership: data, + } + + var optionsCreate = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - event.RoomID(), - *event.StateKey(), - membership, - event.EventID(), - streamPos, - topologicalPos, - ) + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + optionsCreate) + return err } func (s *membershipsStatements) SelectMembership( ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string, ) (eventID string, streamPos, topologyPos types.StreamPosition, err error) { - params := []interface{}{roomID, userID} - for _, membership := range memberships { - params = append(params, membership) + // params := []interface{}{roomID, userID} + // for _, membership := range memberships { + // params = append(params, membership) + // } + + // "SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" + + // " WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" + + // " ORDER BY stream_pos DESC" + + // " LIMIT 1" + + // err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": userID, + "@x4": memberships, } - orig := strings.Replace(selectMembershipSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) - stmt, err := s.db.Prepare(orig) - if err != nil { + // orig := strings.Replace(selectMembershipSQL, "@x4", cosmosdbutil.QueryVariadicOffset(len(memberships), 2), 1) + rows, err := queryMembership(s, ctx, selectMembershipSQL, params) + + if err != nil || len(rows) == 0 { return "", 0, 0, err } - err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) + // err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) + eventID = rows[0].Membership.EventID + streamPos = types.StreamPosition(rows[0].Membership.StreamPos) + topologyPos = types.StreamPosition(rows[0].Membership.TopologicalPos) return } diff --git a/syncapi/storage/cosmosdb/output_room_events_table.go b/syncapi/storage/cosmosdb/output_room_events_table.go index 7ce485d8f..2bd294bba 100644 --- a/syncapi/storage/cosmosdb/output_room_events_table.go +++ b/syncapi/storage/cosmosdb/output_room_events_table.go @@ -21,109 +21,222 @@ import ( "encoding/json" "fmt" "sort" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) -const outputRoomEventsSchema = ` --- Stores output room events received from the roomserver. -CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - event_id TEXT NOT NULL UNIQUE, - room_id TEXT NOT NULL, - headered_event_json TEXT NOT NULL, - type TEXT NOT NULL, - sender TEXT NOT NULL, - contains_url BOOL NOT NULL, - add_state_ids TEXT, -- JSON encoded string array - remove_state_ids TEXT, -- JSON encoded string array - session_id BIGINT, - transaction_id TEXT, - exclude_from_sync BOOL NOT NULL DEFAULT FALSE -); -` +// const outputRoomEventsSchema = ` +// -- Stores output room events received from the roomserver. +// CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( +// id INTEGER PRIMARY KEY AUTOINCREMENT, +// event_id TEXT NOT NULL UNIQUE, +// room_id TEXT NOT NULL, +// headered_event_json TEXT NOT NULL, +// type TEXT NOT NULL, +// sender TEXT NOT NULL, +// contains_url BOOL NOT NULL, +// add_state_ids TEXT, -- JSON encoded string array +// remove_state_ids TEXT, -- JSON encoded string array +// session_id BIGINT, +// transaction_id TEXT, +// exclude_from_sync BOOL NOT NULL DEFAULT FALSE +// ); +// ` -const insertEventSQL = "" + - "INSERT INTO syncapi_output_room_events (" + - "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + - ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + - "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" - -const selectEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" - -const selectRecentEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3" - // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters - -const selectRecentEventsForSyncSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" - // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters - -const selectEarlyEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3" - // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters - -const selectMaxEventIDSQL = "" + - "SELECT MAX(id) FROM syncapi_output_room_events" - -const updateEventJSONSQL = "" + - "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" - -const selectStateInRangeSQL = "" + - "SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + - " FROM syncapi_output_room_events" + - " WHERE (id > $1 AND id <= $2)" + - " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" - // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters - -const deleteEventsForRoomSQL = "" + - "DELETE FROM syncapi_output_room_events WHERE room_id = $1" - -type outputRoomEventsStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - updateEventJSONStmt *sql.Stmt - deleteEventsForRoomStmt *sql.Stmt +type OutputRoomEventCosmos struct { + ID int64 `json:"id"` + EventID string `json:"event_id"` + RoomID string `json:"room_id"` + HeaderedEventJSON []byte `json:"headered_event_json"` + Type string `json:"type"` + Sender string `json:"sender"` + ContainsUrl bool `json:"contains_url"` + AddStateIDs string `json:"add_state_ids"` + RemoveStateIDs string `json:"remove_state_ids"` + SessionID int64 `json:"session_id"` + TransactionID string `json:"transaction_id"` + ExcludeFromSync bool `json:"exclude_from_sync"` } -func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { +type OutputRoomEventCosmosMaxNumber struct { + Max int64 `json:"number"` +} + +type OutputRoomEventCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + OutputRoomEvent OutputRoomEventCosmos `json:"mx_syncapi_output_room_event"` +} + +// const insertEventSQL = "" + +// "INSERT INTO syncapi_output_room_events (" + +// "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + +// ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + +// "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" + +// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" +const selectEventsSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event.event_id = @x2 " + +// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + +// " WHERE room_id = $1 AND id > $2 AND id <= $3" +// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters +const selectRecentEventsSQL = "" + + "select top @x5 * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event.room_id = @x2 " + + "and c.mx_syncapi_output_room_event.id > @x3 " + + "and c.mx_syncapi_output_room_event.id <= @x4 " + +// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + +// " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters +const selectRecentEventsForSyncSQL = "" + + "select top @x5 * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event.room_id = @x2 " + + "and c.mx_syncapi_output_room_event.id > @x3 " + + "and c.mx_syncapi_output_room_event.id <= @x4 " + + "and c.mx_syncapi_output_room_event.exclude_from_sync = false " + +// "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + +// " WHERE room_id = $1 AND id > $2 AND id <= $3" +// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters +const selectEarlyEventsSQL = "" + + "select top @x5 * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event.room_id = @x2 " + + "and c.mx_syncapi_output_room_event.id > @x3 " + + "and c.mx_syncapi_output_room_event.id <= @x4 " + +// "SELECT MAX(id) FROM syncapi_output_room_events" +const selectMaxEventIDSQL = "" + + "select max(c.mx_syncapi_output_room_event.id) as number from c where c._cn = @x1 " + +// "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" +const updateEventJSONSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event.event_id = @x2 " + +// "SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + +// " FROM syncapi_output_room_events" + +// " WHERE (id > $1 AND id <= $2)" + +// " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" +// // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters +const selectStateInRangeSQL = "" + + "select top @x4 * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event.id > @x2 " + + "and c.mx_syncapi_output_room_event.id <= @x3 " + + "and (c.mx_syncapi_output_room_event.add_state_ids != null or c.mx_syncapi_output_room_event.remove_state_ids != null) " + + // "DELETE FROM syncapi_output_room_events WHERE room_id = $1" +const deleteEventsForRoomSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event.room_id = @x2 " + +type outputRoomEventsStatements struct { + db *SyncServerDatasource + streamIDStatements *streamIDStatements + // insertEventStmt *sql.Stmt + selectEventsStmt string + selectMaxEventIDStmt string + updateEventJSONStmt string + deleteEventsForRoomStmt string + tableName string + jsonPropertyName string +} + +func queryOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []OutputRoomEventCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryOutputRoomEventNumber(s *outputRoomEventsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventCosmosMaxNumber, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []OutputRoomEventCosmosMaxNumber + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, cosmosdbutil.ErrNoRows + } + return response, nil +} + +func setOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, outputRoomEvent OutputRoomEventCosmosData) (*OutputRoomEventCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(outputRoomEvent.Pk, outputRoomEvent.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + outputRoomEvent.Id, + &outputRoomEvent, + optionsReplace) + return &outputRoomEvent, ex +} + +func deleteOutputRoomEvent(s *outputRoomEventsStatements, ctx context.Context, dbData OutputRoomEventCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewCosmosDBEventsTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Events, error) { s := &outputRoomEventsStatements{ db: db, streamIDStatements: streamID, } - _, err := db.Exec(outputRoomEventsSchema) - if err != nil { - return nil, err - } - if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { - return nil, err - } - if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil { - return nil, err - } - if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil { - return nil, err - } - if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil { - return nil, err - } - if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil { - return nil, err - } + s.selectEventsStmt = selectEventsSQL + s.selectMaxEventIDStmt = selectMaxEventIDSQL + s.updateEventJSONStmt = updateEventJSONSQL + s.deleteEventsForRoomStmt = deleteEventsForRoomSQL + s.tableName = "output_room_events" + s.jsonPropertyName = "mx_syncapi_output_room_event" return s, nil } @@ -132,7 +245,27 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event if err != nil { return err } - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + + // "UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": event.EventID(), + } + + // _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + rows, err := queryOutputRoomEvent(s, ctx, s.deleteEventsForRoomStmt, params) + if err != nil { + return err + } + + for _, item := range rows { + item.OutputRoomEvent.HeaderedEventJSON = headeredJSON + _, err = setOutputRoomEvent(s, ctx, item) + } + + return err return err } @@ -143,24 +276,31 @@ func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { - stmt, params, err := prepareWithFilters( - s.db, txn, selectStateInRangeSQL, - []interface{}{ - r.Low(), r.High(), - }, + // "SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + + // " FROM syncapi_output_room_events" + + // " WHERE (id > $1 AND id <= $2)" + + // " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" + // // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": r.Low(), + "@x3": r.High(), + "@x4": stateFilter.Limit, + } + query, params := prepareWithFilters( + s.jsonPropertyName, selectStateInRangeSQL, params, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, nil, stateFilter.Limit, FilterOrderAsc, ) - if err != nil { - return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err) - } - rows, err := stmt.QueryContext(ctx, params...) + // rows, err := stmt.QueryContext(ctx, params...) + rows, err := queryOutputRoomEvent(s, ctx, query, params) if err != nil { return nil, nil, err } - defer rows.Close() // nolint: errcheck // Fetch all the state change events for all rooms between the two positions then loop each event and: // - Keep a cache of the event by ID (99% of state change events are for the event itself) // - For each room ID, build up an array of event IDs which represents cumulative adds/removes @@ -171,7 +311,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( // RoomID => A set (map[string]bool) of state event IDs which are between the two positions stateNeeded := make(map[string]map[string]bool) - for rows.Next() { + for _, item := range rows { var ( streamPos types.StreamPosition eventBytes []byte @@ -179,10 +319,15 @@ func (s *outputRoomEventsStatements) SelectStateInRange( addIDsJSON string delIDsJSON string ) - if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON); err != nil { - return nil, nil, err - } - + // SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids + // if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDsJSON, &delIDsJSON); err != nil { + // return nil, nil, err + // } + streamPos = types.StreamPosition(item.OutputRoomEvent.ID) + eventBytes = item.OutputRoomEvent.HeaderedEventJSON + excludeFromSync = item.OutputRoomEvent.ExcludeFromSync + addIDsJSON = item.OutputRoomEvent.AddStateIDs + delIDsJSON = item.OutputRoomEvent.RemoveStateIDs addIDs, delIDs, err := unmarshalStateIDs(addIDsJSON, delIDsJSON) if err != nil { return nil, nil, err @@ -233,8 +378,20 @@ func (s *outputRoomEventsStatements) SelectMaxEventID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) - err = stmt.QueryRowContext(ctx).Scan(&nullableID) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + // stmt := sqlutil.TxStmt(txn, s.selectMaxEventIDStmt) + + rows, err := queryOutputRoomEventNumber(s, ctx, s.selectMaxEventIDStmt, params) + // err = stmt.QueryRowContext(ctx).Scan(&nullableID) + + if rows != nil { + nullableID.Int64 = rows[0].Max + } + if nullableID.Valid { id = nullableID.Int64 } @@ -248,6 +405,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool, ) (types.StreamPosition, error) { + var txnID *string var sessionID *int64 if transactionID != nil { @@ -283,27 +441,74 @@ func (s *outputRoomEventsStatements) InsertEvent( return 0, fmt.Errorf("json.Marshal(removeState): %w", err) } + // id INTEGER PRIMARY KEY AUTOINCREMENT, streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn) if err != nil { return 0, err } - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - _, err = insertStmt.ExecContext( + // "INSERT INTO syncapi_output_room_events (" + + // "id, room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + + // ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + + // "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" + + // insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + // _, err = insertStmt.ExecContext( + // ctx, + // streamPos, + // event.RoomID(), + // event.EventID(), + // headeredJSON, + // event.Type(), + // event.Sender(), + // containsURL, + // string(addStateJSON), + // string(removeStateJSON), + // sessionID, + // txnID, + // excludeFromSync, + // excludeFromSync, + // ) + + data := OutputRoomEventCosmos{ + ID: int64(streamPos), + RoomID: event.RoomID(), + EventID: event.EventID(), + HeaderedEventJSON: headeredJSON, + Type: event.Type(), + Sender: event.Sender(), + ContainsUrl: containsURL, + AddStateIDs: string(addStateJSON), + RemoveStateIDs: string(removeStateJSON), + ExcludeFromSync: excludeFromSync, + } + + if transactionID != nil { + data.SessionID = *sessionID + data.TransactionID = *txnID + } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + // id INTEGER PRIMARY KEY, + docId := fmt.Sprintf("%d", streamPos) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + + var dbData = OutputRoomEventCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + OutputRoomEvent: data, + } + + var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, - streamPos, - event.RoomID(), - event.EventID(), - headeredJSON, - event.Type(), - event.Sender(), - containsURL, - string(addStateJSON), - string(removeStateJSON), - sessionID, - txnID, - excludeFromSync, - excludeFromSync, - ) + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + optionsCreate) + return streamPos, err } @@ -314,30 +519,39 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( ) ([]types.StreamEvent, bool, error) { var query string if onlySyncEvents { + // "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + // " WHERE room_id = $1 AND id > $2 AND id <= $3" + // // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters query = selectRecentEventsForSyncSQL } else { + // "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + // " WHERE room_id = $1 AND id > $2 AND id <= $3" + query = selectRecentEventsSQL } - stmt, params, err := prepareWithFilters( - s.db, txn, query, - []interface{}{ - roomID, r.Low(), r.High(), - }, + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": r.Low(), + "@x4": r.High(), + "@x5": eventFilter.Limit + 1, + } + + query, params = prepareWithFilters( + s.jsonPropertyName, query, params, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, nil, eventFilter.Limit+1, FilterOrderDesc, ) - if err != nil { - return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) - } - rows, err := stmt.QueryContext(ctx, params...) + // rows, err := stmt.QueryContext(ctx, params...) + rows, err := queryOutputRoomEvent(s, ctx, query, params) + if err != nil { return nil, false, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) + events, err := rowsToStreamEvents(&rows) if err != nil { return nil, false, err } @@ -367,24 +581,31 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, ) ([]types.StreamEvent, error) { - stmt, params, err := prepareWithFilters( - s.db, txn, selectEarlyEventsSQL, - []interface{}{ - roomID, r.Low(), r.High(), - }, + // "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + // " WHERE room_id = $1 AND id > $2 AND id <= $3" + // // WHEN, ORDER BY (and not LIMIT) are appended by prepareWithFilters + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": r.Low(), + "@x4": r.High(), + "@x5": eventFilter.Limit, + } + stmt, params := prepareWithFilters( + s.jsonPropertyName, selectEarlyEventsSQL, params, eventFilter.Senders, eventFilter.NotSenders, eventFilter.Types, eventFilter.NotTypes, nil, eventFilter.Limit, FilterOrderAsc, ) - if err != nil { - return nil, fmt.Errorf("s.prepareWithFilters: %w", err) - } - rows, err := stmt.QueryContext(ctx, params...) + + // rows, err := stmt.QueryContext(ctx, params...) + rows, err := queryOutputRoomEvent(s, ctx, stmt, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectEarlyEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) + events, err := rowsToStreamEvents(&rows) if err != nil { return nil, err } @@ -402,17 +623,27 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( func (s *outputRoomEventsStatements) SelectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { + // "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" + var returnEvents []types.StreamEvent - stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) + + // stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) + for _, eventID := range eventIDs { - rows, err := stmt.QueryContext(ctx, eventID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventID, + } + + // rows, err := stmt.QueryContext(ctx, eventID) + rows, err := queryOutputRoomEvent(s, ctx, s.selectEventsStmt, params) if err != nil { return nil, err } - if streamEvents, err := rowsToStreamEvents(rows); err == nil { + if streamEvents, err := rowsToStreamEvents(&rows); err == nil { returnEvents = append(returnEvents, streamEvents...) } - internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") } return returnEvents, nil } @@ -420,13 +651,30 @@ func (s *outputRoomEventsStatements) SelectEvents( func (s *outputRoomEventsStatements) DeleteEventsForRoom( ctx context.Context, txn *sql.Tx, roomID string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID) + // "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + // _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID) + rows, err := queryOutputRoomEvent(s, ctx, s.deleteEventsForRoomStmt, params) + if err != nil { + return err + } + + for _, item := range rows { + err = deleteOutputRoomEvent(s, ctx, item) + } + return err } -func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { +func rowsToStreamEvents(rows *[]OutputRoomEventCosmosData) ([]types.StreamEvent, error) { var result []types.StreamEvent - for rows.Next() { + for _, item := range *rows { var ( eventID string streamPos types.StreamPosition @@ -436,9 +684,17 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { txnID *string transactionID *api.TransactionID ) - if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { - return nil, err - } + // SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id + // if err := rows.Scan(&eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { + // return nil, err + // } + eventID = item.OutputRoomEvent.EventID + streamPos = types.StreamPosition(item.OutputRoomEvent.ID) + eventBytes = item.OutputRoomEvent.HeaderedEventJSON + sessionID = &item.OutputRoomEvent.SessionID + excludeFromSync = item.OutputRoomEvent.ExcludeFromSync + txnID = &item.OutputRoomEvent.TransactionID + // TODO: Handle redacted events var ev gomatrixserverlib.HeaderedEvent if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { diff --git a/syncapi/storage/cosmosdb/output_room_events_topology_table.go b/syncapi/storage/cosmosdb/output_room_events_topology_table.go index 1a52b76b8..571de39f1 100644 --- a/syncapi/storage/cosmosdb/output_room_events_topology_table.go +++ b/syncapi/storage/cosmosdb/output_room_events_topology_table.go @@ -17,93 +17,167 @@ package cosmosdb import ( "context" "database/sql" + "fmt" + "time" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) -const outputRoomEventsTopologySchema = ` --- Stores output room events received from the roomserver. -CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( - event_id TEXT PRIMARY KEY, - topological_position BIGINT NOT NULL, - stream_position BIGINT NOT NULL, - room_id TEXT NOT NULL, +// const outputRoomEventsTopologySchema = ` +// -- Stores output room events received from the roomserver. +// CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( +// event_id TEXT PRIMARY KEY, +// topological_position BIGINT NOT NULL, +// stream_position BIGINT NOT NULL, +// room_id TEXT NOT NULL, - UNIQUE(topological_position, room_id, stream_position) -); --- The topological order will be used in events selection and ordering --- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id); -` +// UNIQUE(topological_position, room_id, stream_position) +// ); +// -- The topological order will be used in events selection and ordering +// -- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, stream_position, room_id); +// ` -const insertEventInTopologySQL = "" + - "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" + - " VALUES ($1, $2, $3, $4)" + - " ON CONFLICT DO NOTHING" - -const selectEventIDsInRangeASCSQL = "" + - "SELECT event_id FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 AND (" + - "(topological_position > $2 AND topological_position < $3) OR" + - "(topological_position = $4 AND stream_position <= $5)" + - ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" - -const selectEventIDsInRangeDESCSQL = "" + - "SELECT event_id FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 AND (" + - "(topological_position > $2 AND topological_position < $3) OR" + - "(topological_position = $4 AND stream_position <= $5)" + - ") ORDER BY topological_position DESC, stream_position DESC LIMIT $6" - -const selectPositionInTopologySQL = "" + - "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + - " WHERE event_id = $1" - -const selectMaxPositionInTopologySQL = "" + - "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 ORDER BY stream_position DESC" - -const deleteTopologyForRoomSQL = "" + - "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" - -type outputRoomEventsTopologyStatements struct { - db *sql.DB - insertEventInTopologyStmt *sql.Stmt - selectEventIDsInRangeASCStmt *sql.Stmt - selectEventIDsInRangeDESCStmt *sql.Stmt - selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt - deleteTopologyForRoomStmt *sql.Stmt +type OutputRoomEventTopologyCosmos struct { + EventID string `json:"event_id"` + TopologicalPosition int64 `json:"topological_position"` + StreamPosition int64 `json:"stream_position"` + RoomID string `json:"room_id"` } -func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { - s := &outputRoomEventsTopologyStatements{ - db: db, - } - _, err := db.Exec(outputRoomEventsTopologySchema) +type OutputRoomEventTopologyCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + OutputRoomEventTopology OutputRoomEventTopologyCosmos `json:"mx_syncapi_output_room_event_topology"` +} + +// const insertEventInTopologySQL = "" + +// "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" + +// " VALUES ($1, $2, $3, $4)" + +// " ON CONFLICT DO NOTHING" + +// "SELECT event_id FROM syncapi_output_room_events_topology" + +// " WHERE room_id = $1 AND (" + +// "(topological_position > $2 AND topological_position < $3) OR" + +// "(topological_position = $4 AND stream_position <= $5)" + +// ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" +const selectEventIDsInRangeASCSQL = "" + + "select top @x7 * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event_topology.room_id = @x2 " + + "and ( " + + "(c.mx_syncapi_output_room_event_topology.topological_position > @x3 and c.mx_syncapi_output_room_event_topology.topological_position < @x4) " + + "OR " + + "(c.mx_syncapi_output_room_event_topology.topological_position = @x5 and c.mx_syncapi_output_room_event_topology.stream_position < @x6) " + + ") " + + "order by c.mx_syncapi_output_room_event_topology.topological_position asc " + // ", c.mx_syncapi_output_room_event_topology.stream_position asc " + +// "SELECT event_id FROM syncapi_output_room_events_topology" + +// " WHERE room_id = $1 AND (" + +// "(topological_position > $2 AND topological_position < $3) OR" + +// "(topological_position = $4 AND stream_position <= $5)" + +// ") ORDER BY topological_position DESC, stream_position DESC LIMIT $6" +const selectEventIDsInRangeDESCSQL = "" + + "select top @x7 * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event_topology.room_id = @x2 " + + "and ( " + + "(c.mx_syncapi_output_room_event_topology.topological_position > @x3 and c.mx_syncapi_output_room_event_topology.topological_position < @x4) " + + "OR " + + "(c.mx_syncapi_output_room_event_topology.topological_position = @x5 and c.mx_syncapi_output_room_event_topology.stream_position < @x6) " + + ") " + + "order by c.mx_syncapi_output_room_event_topology.topological_position desc " + // ", c.mx_syncapi_output_room_event_topology.stream_position desc " + +// "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + +// " WHERE event_id = $1" +const selectPositionInTopologySQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event_topology.event_id = @x2 " + +// "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + +// " WHERE room_id = $1 ORDER BY stream_position DESC" + +// "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + +// " WHERE topological_position=(" + +// "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + +// ") ORDER BY stream_position DESC LIMIT 1" +const selectMaxPositionInTopologySQL = "" + + "select top 1 * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event_topology.topological_position = " + + "( " + + "select max(c.mx_syncapi_output_room_event_topology.topological_position) from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event_topology.room_id = @x2" + + ") " + + "order by c.mx_syncapi_output_room_event_topology.stream_position desc " + +// "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" +const deleteTopologyForRoomSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_output_room_event_topology.room_id = @x2 " + +type outputRoomEventsTopologyStatements struct { + db *SyncServerDatasource + // insertEventInTopologyStmt *sql.Stmt + selectEventIDsInRangeASCStmt string + selectEventIDsInRangeDESCStmt string + selectPositionInTopologyStmt string + selectMaxPositionInTopologyStmt string + deleteTopologyForRoomStmt string + tableName string +} + +func queryOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]OutputRoomEventTopologyCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []OutputRoomEventTopologyCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err + return response, nil +} + +func deleteOutputRoomEventTopology(s *outputRoomEventsTopologyStatements, ctx context.Context, dbData OutputRoomEventTopologyCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } - if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { - return nil, err + return err +} + +func NewCosmosDBTopologyTable(db *SyncServerDatasource) (tables.Topology, error) { + s := &outputRoomEventsTopologyStatements{ + db: db, } + + s.selectEventIDsInRangeASCStmt = selectEventIDsInRangeASCSQL + s.selectEventIDsInRangeDESCStmt = selectEventIDsInRangeDESCSQL + s.selectPositionInTopologyStmt = selectPositionInTopologySQL + s.selectMaxPositionInTopologyStmt = selectMaxPositionInTopologySQL + s.deleteTopologyForRoomStmt = deleteTopologyForRoomSQL + s.tableName = "output_room_events_topology" return s, nil } @@ -112,9 +186,44 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (types.StreamPosition, error) { - _, err := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).ExecContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), pos, - ) + + // "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id, stream_position)" + + // " VALUES ($1, $2, $3, $4)" + + // " ON CONFLICT DO NOTHING" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE(topological_position, room_id, stream_position) + docId := fmt.Sprintf("%d_%s_%d", event.Depth(), event.RoomID(), pos) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := OutputRoomEventTopologyCosmos{ + EventID: event.EventID(), + TopologicalPosition: event.Depth(), + RoomID: event.RoomID(), + StreamPosition: int64(pos), + } + + dbData := &OutputRoomEventTopologyCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + OutputRoomEventTopology: data, + } + + // _, err := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt).ExecContext( + // ctx, event.EventID(), event.Depth(), event.RoomID(), pos, + // ) + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return types.StreamPosition(event.Depth()), err } @@ -125,15 +234,38 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( ) (eventIDs []string, err error) { // Decide on the selection's order according to whether chronological order // is requested or not. - var stmt *sql.Stmt + var stmt string if chronologicalOrder { - stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt) + // "SELECT event_id FROM syncapi_output_room_events_topology" + + // " WHERE room_id = $1 AND (" + + // "(topological_position > $2 AND topological_position < $3) OR" + + // "(topological_position = $4 AND stream_position <= $5)" + + // ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" + stmt = s.selectEventIDsInRangeASCStmt } else { - stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) + // "SELECT event_id FROM syncapi_output_room_events_topology" + + // " WHERE room_id = $1 AND (" + + // "(topological_position > $2 AND topological_position < $3) OR" + + // "(topological_position = $4 AND stream_position <= $5)" + + // ") ORDER BY topological_position DESC, stream_position DESC LIMIT $6" + stmt = s.selectEventIDsInRangeDESCStmt } // Query the event IDs. - rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": minDepth, + "@x4": maxDepth, + "@x5": maxDepth, + "@x6": maxStreamPos, + "@x7": limit, + } + + rows, err := queryOutputRoomEventTopology(s, ctx, stmt, params) + // rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit) + if err == sql.ErrNoRows { // If no event matched the request, return an empty slice. return []string{}, nil @@ -143,10 +275,11 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( // Return the IDs. var eventID string - for rows.Next() { - if err = rows.Scan(&eventID); err != nil { - return - } + for _, item := range rows { + // if err = rows.Scan(&eventID); err != nil { + // return + // } + eventID = item.OutputRoomEventTopology.EventID eventIDs = append(eventIDs, eventID) } @@ -158,22 +291,89 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( ctx context.Context, txn *sql.Tx, eventID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) + + // "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + + // " WHERE event_id = $1" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventID, + } + + rows, err := queryOutputRoomEventTopology(s, ctx, s.selectPositionInTopologyStmt, params) + // stmt := sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt) + + if err != nil { + return + } + + if len(rows) == 0 { + return + } + + // err = stmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) + pos = types.StreamPosition(rows[0].OutputRoomEventTopology.TopologicalPosition) + spos = types.StreamPosition(rows[0].OutputRoomEventTopology.StreamPosition) + return } func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( ctx context.Context, txn *sql.Tx, roomID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) + + // "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + + // " WHERE topological_position=(" + + // "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + + // ") ORDER BY stream_position DESC LIMIT 1" + + // stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + rows, err := queryOutputRoomEventTopology(s, ctx, s.selectMaxPositionInTopologyStmt, params) + // err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) + + if err != nil { + return + } + + if len(rows) == 0 { + return + } + + pos = types.StreamPosition(rows[0].OutputRoomEventTopology.TopologicalPosition) + spos = types.StreamPosition(rows[0].OutputRoomEventTopology.StreamPosition) return } func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( ctx context.Context, txn *sql.Tx, roomID string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) + + // "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + } + + rows, err := queryOutputRoomEventTopology(s, ctx, s.deleteTopologyForRoomStmt, params) + // _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) + + if err != nil { + return + } + + for _, item := range rows { + err = deleteOutputRoomEventTopology(s, ctx, item) + if err != nil { + return + } + } return err } diff --git a/syncapi/storage/cosmosdb/peeks_table.go b/syncapi/storage/cosmosdb/peeks_table.go index e6d9b8a3c..95c661696 100644 --- a/syncapi/storage/cosmosdb/peeks_table.go +++ b/syncapi/storage/cosmosdb/peeks_table.go @@ -17,91 +17,175 @@ package cosmosdb import ( "context" "database/sql" + "fmt" "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) -const peeksSchema = ` -CREATE TABLE IF NOT EXISTS syncapi_peeks ( - id INTEGER, - room_id TEXT NOT NULL, - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - deleted BOOL NOT NULL DEFAULT false, - -- When the peek was created in UNIX epoch ms. - creation_ts INTEGER NOT NULL, - UNIQUE(room_id, user_id, device_id) -); +// const peeksSchema = ` +// CREATE TABLE IF NOT EXISTS syncapi_peeks ( +// id INTEGER, +// room_id TEXT NOT NULL, +// user_id TEXT NOT NULL, +// device_id TEXT NOT NULL, +// deleted BOOL NOT NULL DEFAULT false, +// -- When the peek was created in UNIX epoch ms. +// creation_ts INTEGER NOT NULL, +// UNIQUE(room_id, user_id, device_id) +// ); -CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id); -CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id); -` +// CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id); +// CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id); +// ` -const insertPeekSQL = "" + - "INSERT OR REPLACE INTO syncapi_peeks" + - " (id, room_id, user_id, device_id, creation_ts, deleted)" + - " VALUES ($1, $2, $3, $4, $5, false)" +type PeekCosmos struct { + ID int64 `json:"id"` + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + Deleted bool `json:"deleted"` + // Use the CosmosDB.Timestamp for this one + // creation_ts int64 `json:"creation_ts"` +} +type PeekCosmosMaxNumber struct { + Max int64 `json:"number"` +} + +type PeekCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Peek PeekCosmos `json:"mx_syncapi_peek"` +} + +// const insertPeekSQL = "" + +// "INSERT OR REPLACE INTO syncapi_peeks" + +// " (id, room_id, user_id, device_id, creation_ts, deleted)" + +// " VALUES ($1, $2, $3, $4, $5, false)" + +// "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4" const deletePeekSQL = "" + - "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4" + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_peek.room_id = @x2 " + + "and c.mx_syncapi_peek.user_id = @x3 " + + "and c.mx_syncapi_peek.device_id = @x4 " +// "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3" const deletePeeksSQL = "" + - "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3" + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_peek.room_id = @x2 " + + "and c.mx_syncapi_peek.user_id = @x3 " // we care about all the peeks which were created in this range, deleted in this range, // or were created before this range but haven't been deleted yet. // BEWARE: sqlite chokes on out of order substitution strings. + +// "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))" const selectPeeksInRangeSQL = "" + - "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))" + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_peek.user_id = @x2 " + + "and c.mx_syncapi_peek.device_id = @x3 " + + "and ( " + + "(c.mx_syncapi_peek.id <= @x4 and c.mx_syncapi_peek.deleted = false)" + + "or " + + "(c.mx_syncapi_peek.id > @x4 and c.mx_syncapi_peek.id <= @x5)" + + ") " +// "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false" const selectPeekingDevicesSQL = "" + - "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false" + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_peek.deleted = false " +// "SELECT MAX(id) FROM syncapi_peeks" const selectMaxPeekIDSQL = "" + - "SELECT MAX(id) FROM syncapi_peeks" + "select max(c.mx_syncapi_peek.id) from c where c._cn = @x1 " type peekStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - insertPeekStmt *sql.Stmt - deletePeekStmt *sql.Stmt - deletePeeksStmt *sql.Stmt - selectPeeksInRangeStmt *sql.Stmt - selectPeekingDevicesStmt *sql.Stmt - selectMaxPeekIDStmt *sql.Stmt + db *SyncServerDatasource + streamIDStatements *streamIDStatements + // insertPeekStmt *sql.Stmt + deletePeekStmt string + deletePeeksStmt string + selectPeeksInRangeStmt string + selectPeekingDevicesStmt string + selectMaxPeekIDStmt string + tableName string } -func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { - _, err := db.Exec(peeksSchema) +func queryPeek(s *peekStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PeekCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []PeekCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } + return response, nil +} + +func queryPeekMaxNumber(s *peekStatements, ctx context.Context, qry string, params map[string]interface{}) ([]PeekCosmosMaxNumber, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []PeekCosmosMaxNumber + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, nil + } + return response, nil +} + +func setPeek(s *peekStatements, ctx context.Context, peek PeekCosmosData) (*PeekCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(peek.Pk, peek.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + peek.Id, + &peek, + optionsReplace) + return &peek, ex +} + +func NewCosmosDBPeeksTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Peeks, error) { s := &peekStatements{ db: db, streamIDStatements: streamID, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } + + s.deletePeekStmt = deletePeekSQL + s.deletePeeksStmt = deletePeeksSQL + s.selectPeeksInRangeStmt = selectPeeksInRangeSQL + s.selectPeekingDevicesStmt = selectPeekingDevicesSQL + s.selectMaxPeekIDStmt = selectMaxPeekIDSQL + s.tableName = "peeks" return s, nil } @@ -112,39 +196,120 @@ func (s *peekStatements) InsertPeek( if err != nil { return } - nowMilli := time.Now().UnixNano() / int64(time.Millisecond) - _, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID, nowMilli) + + // "INSERT OR REPLACE INTO syncapi_peeks" + + // " (id, room_id, user_id, device_id, creation_ts, deleted)" + + // " VALUES ($1, $2, $3, $4, $5, false)" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE(room_id, user_id, device_id) + docId := fmt.Sprintf("%d_%s_%d", roomID, userID, deviceID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := PeekCosmos{ + ID: int64(streamPos), + RoomID: roomID, + UserID: userID, + DeviceID: deviceID, + } + + dbData := &PeekCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + // nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + Timestamp: time.Now().Unix(), + Peek: data, + } + + // _, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID, nowMilli) + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + return } func (s *peekStatements) DeletePeek( ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string, ) (streamPos types.StreamPosition, err error) { + + // "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": userID, + "@x4": deviceID, + } + + rows, err := queryPeek(s, ctx, s.deletePeekStmt, params) + // _, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID) + + numAffected := len(rows) + if numAffected == 0 { + return 0, cosmosdbutil.ErrNoRows + } + + // Only create a new ID if there are rows to mark as deleted. This is handled in an SQL TX for DBs streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn) if err != nil { - return + return 0, err + } + + for _, item := range rows { + item.Peek.Deleted = true + item.Peek.ID = int64(streamPos) + _, err = setPeek(s, ctx, item) + if err != nil { + return + } } - _, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID) return } func (s *peekStatements) DeletePeeks( ctx context.Context, txn *sql.Tx, roomID, userID string, ) (types.StreamPosition, error) { + // "UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": roomID, + "@x3": userID, + } + + rows, err := queryPeek(s, ctx, s.deletePeekStmt, params) + // result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID) + if err != nil { + return 0, err + } + numAffected := len(rows) + if numAffected == 0 { + return 0, cosmosdbutil.ErrNoRows + } + + // Only create a new ID if there are rows to mark as deleted. This is handled in an SQL TX for DBs streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn) if err != nil { return 0, err } - result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID) - if err != nil { - return 0, err - } - numAffected, err := result.RowsAffected() - if err != nil { - return 0, err - } - if numAffected == 0 { - return 0, sql.ErrNoRows + + for _, item := range rows { + item.Peek.Deleted = true + item.Peek.ID = int64(streamPos) + _, err = setPeek(s, ctx, item) + if err != nil { + return 0, err + } } return streamPos, nil } @@ -152,40 +317,65 @@ func (s *peekStatements) DeletePeeks( func (s *peekStatements) SelectPeeksInRange( ctx context.Context, txn *sql.Tx, userID, deviceID string, r types.Range, ) (peeks []types.Peek, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High()) + // "SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": deviceID, + "@x4": r.Low(), + "@x5": r.High(), + } + + rows, err := queryPeek(s, ctx, s.selectPeeksInRangeStmt, params) + // rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High()) if err != nil { return } - defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeksInRange: rows.close() failed") - for rows.Next() { + for _, item := range rows { peek := types.Peek{} var id types.StreamPosition - if err = rows.Scan(&id, &peek.RoomID, &peek.Deleted); err != nil { - return - } + // if err = rows.Scan(&id, &peek.RoomID, &peek.Deleted); err != nil { + // return + // } + id = types.StreamPosition(item.Peek.ID) + peek.RoomID = item.Peek.RoomID + peek.Deleted = item.Peek.Deleted peek.New = (id > r.Low() && id <= r.High()) && !peek.Deleted peeks = append(peeks, peek) } - return peeks, rows.Err() + return peeks, nil } func (s *peekStatements) SelectPeekingDevices( ctx context.Context, ) (peekingDevices map[string][]types.PeekingDevice, err error) { - rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) + + // "SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + rows, err := queryPeek(s, ctx, s.selectPeekingDevicesStmt, params) + // rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectPeekingDevices: rows.close() failed") result := make(map[string][]types.PeekingDevice) - for rows.Next() { + for _, item := range rows { var roomID, userID, deviceID string - if err := rows.Scan(&roomID, &userID, &deviceID); err != nil { - return nil, err - } + // if err := rows.Scan(&roomID, &userID, &deviceID); err != nil { + // return nil, err + // } + roomID = item.Peek.RoomID + userID = item.Peek.UserID + deviceID = item.Peek.DeviceID devices := result[roomID] devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID}) result[roomID] = devices @@ -196,9 +386,22 @@ func (s *peekStatements) SelectPeekingDevices( func (s *peekStatements) SelectMaxPeekID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { + // "SELECT MAX(id) FROM syncapi_peeks" + + // stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt) var nullableID sql.NullInt64 - stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt) - err = stmt.QueryRowContext(ctx).Scan(&nullableID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + rows, err := queryPeekMaxNumber(s, ctx, s.selectMaxPeekIDStmt, params) + // err = stmt.QueryRowContext(ctx).Scan(&nullableID) + + if rows != nil { + nullableID.Int64 = rows[0].Max + } + if nullableID.Valid { id = nullableID.Int64 } diff --git a/syncapi/storage/cosmosdb/receipt_table.go b/syncapi/storage/cosmosdb/receipt_table.go index de3983c5b..5fb63f2dc 100644 --- a/syncapi/storage/cosmosdb/receipt_table.go +++ b/syncapi/storage/cosmosdb/receipt_table.go @@ -18,72 +18,129 @@ import ( "context" "database/sql" "fmt" - "strings" + "time" "github.com/matrix-org/dendrite/eduserver/api" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) -const receiptsSchema = ` --- Stores data about receipts -CREATE TABLE IF NOT EXISTS syncapi_receipts ( - -- The ID - id BIGINT, - room_id TEXT NOT NULL, - receipt_type TEXT NOT NULL, - user_id TEXT NOT NULL, - event_id TEXT NOT NULL, - receipt_ts BIGINT NOT NULL, - CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) -); -CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id_idx ON syncapi_receipts(room_id); -` +// const receiptsSchema = ` +// -- Stores data about receipts +// CREATE TABLE IF NOT EXISTS syncapi_receipts ( +// -- The ID +// id BIGINT, +// room_id TEXT NOT NULL, +// receipt_type TEXT NOT NULL, +// user_id TEXT NOT NULL, +// event_id TEXT NOT NULL, +// receipt_ts BIGINT NOT NULL, +// CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) +// ); +// CREATE INDEX IF NOT EXISTS syncapi_receipts_room_id_idx ON syncapi_receipts(room_id); +// ` -const upsertReceipt = "" + - "INSERT INTO syncapi_receipts" + - " (id, room_id, receipt_type, user_id, event_id, receipt_ts)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT (room_id, receipt_type, user_id)" + - " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" - -const selectRoomReceipts = "" + - "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" + - " FROM syncapi_receipts" + - " WHERE id > $1 and room_id in ($2)" - -const selectMaxReceiptIDSQL = "" + - "SELECT MAX(id) FROM syncapi_receipts" - -type receiptStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - upsertReceipt *sql.Stmt - selectRoomReceipts *sql.Stmt - selectMaxReceiptID *sql.Stmt +type ReceiptCosmos struct { + ID int64 `json:"id"` + RoomID string `json:"room_id"` + ReceiptType string `json:"receipt_type"` + UserID string `json:"user_id"` + EventID string `json:"event_id"` + ReceiptTS int64 `json:"receipt_ts"` } -func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { - _, err := db.Exec(receiptsSchema) +type ReceiptCosmosMaxNumber struct { + Max int64 `json:"number"` +} + +type ReceiptCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + Receipt ReceiptCosmos `json:"mx_syncapi_receipt"` +} + +// const upsertReceipt = "" + +// "INSERT INTO syncapi_receipts" + +// " (id, room_id, receipt_type, user_id, event_id, receipt_ts)" + +// " VALUES ($1, $2, $3, $4, $5, $6)" + +// " ON CONFLICT (room_id, receipt_type, user_id)" + +// " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" + +// "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" + +// " FROM syncapi_receipts" + +// " WHERE id > $1 and room_id in ($2)" +const selectRoomReceipts = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_receipt.id > @x2 " + + "and ARRAY_CONTAINS(@x3, c.mx_syncapi_receipt.room_id)" + +// "SELECT MAX(id) FROM syncapi_receipts" +const selectMaxReceiptIDSQL = "" + + "select max(c.mx_syncapi_receipt.id) as number from c where c._cn = @x1 " + +type receiptStatements struct { + db *SyncServerDatasource + streamIDStatements *streamIDStatements + // upsertReceipt *sql.Stmt + // selectRoomReceipts *sql.Stmt + selectMaxReceiptID string + tableName string +} + +func queryReceipt(s *receiptStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ReceiptCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []ReceiptCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } + return response, nil +} + +func queryReceiptNumber(s *receiptStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ReceiptCosmosMaxNumber, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []ReceiptCosmosMaxNumber + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, nil + } + return response, nil +} + +func NewCosmosDBReceiptsTable(db *SyncServerDatasource, streamID *streamIDStatements) (tables.Receipts, error) { r := &receiptStatements{ db: db, streamIDStatements: streamID, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } + r.selectMaxReceiptID = selectMaxReceiptIDSQL + r.tableName = "receipts" return r, nil } @@ -93,47 +150,115 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room if err != nil { return } - stmt := sqlutil.TxStmt(txn, r.upsertReceipt) - _, err = stmt.ExecContext(ctx, pos, roomId, receiptType, userId, eventId, timestamp, pos, eventId, timestamp) + + // "INSERT INTO syncapi_receipts" + + // " (id, room_id, receipt_type, user_id, event_id, receipt_ts)" + + // " VALUES ($1, $2, $3, $4, $5, $6)" + + // " ON CONFLICT (room_id, receipt_type, user_id)" + + // " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" + + data := ReceiptCosmos{ + ID: int64(pos), + RoomID: roomId, + ReceiptType: receiptType, + UserID: userId, + EventID: eventId, + ReceiptTS: int64(timestamp), + } + + var dbCollectionName = cosmosdbapi.GetCollectionName(r.db.databaseName, r.tableName) + var pk = cosmosdbapi.GetPartitionKey(r.db.cosmosConfig.ContainerName, dbCollectionName) + // CONSTRAINT syncapi_receipts_unique UNIQUE (room_id, receipt_type, user_id) + docId := fmt.Sprintf("%s_%s_%s", roomId, receiptType, userId) + cosmosDocId := cosmosdbapi.GetDocumentId(r.db.cosmosConfig.ContainerName, dbCollectionName, docId) + + var dbData = ReceiptCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + Receipt: data, + } + + var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(r.db.connection).CreateDocument( + ctx, + r.db.cosmosConfig.DatabaseName, + r.db.cosmosConfig.ContainerName, + dbData, + optionsCreate) + + // _, err = stmt.ExecContext(ctx, pos, roomId, receiptType, userId, eventId, timestamp, pos, eventId, timestamp) return } // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) { - selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) + // "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" + + // " FROM syncapi_receipts" + + // " WHERE id > $1 and room_id in ($2)" + + // selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) lastPos := streamPos - params := make([]interface{}, len(roomIDs)+1) - params[0] = streamPos - for k, v := range roomIDs { - params[k+1] = v + // params := make([]interface{}, len(roomIDs)+1) + // params[0] = streamPos + // for k, v := range roomIDs { + // params[k+1] = v + + var dbCollectionName = cosmosdbapi.GetCollectionName(r.db.databaseName, r.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": streamPos, + "@x3": roomIDs, } - rows, err := r.db.QueryContext(ctx, selectSQL, params...) + + rows, err := queryReceipt(r, ctx, selectRoomReceipts, params) + // rows, err := r.db.QueryContext(ctx, selectSQL, params...) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) } - defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") + var res []api.OutputReceiptEvent - for rows.Next() { + for _, item := range rows { r := api.OutputReceiptEvent{} var id types.StreamPosition - err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) - if err != nil { - return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) - } + // err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) + // if err != nil { + // return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) + // } + id = types.StreamPosition(item.Receipt.ID) + r.RoomID = item.Receipt.RoomID + r.Type = item.Receipt.ReceiptType + r.UserID = item.Receipt.UserID + r.EventID = item.Receipt.EventID + r.Timestamp = gomatrixserverlib.Timestamp(item.Receipt.ReceiptTS) res = append(res, r) if id > lastPos { lastPos = id } } - return lastPos, res, rows.Err() + return lastPos, res, nil } func (s *receiptStatements) SelectMaxReceiptID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID) - err = stmt.QueryRowContext(ctx).Scan(&nullableID) + + // "SELECT MAX(id) FROM syncapi_receipts" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + rows, err := queryReceiptNumber(s, ctx, s.selectMaxReceiptID, params) + // stmt := sqlutil.TxStmt(txn, s.selectMaxReceiptID) + + if rows != nil { + nullableID.Int64 = rows[0].Max + } + // err = stmt.QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 } diff --git a/syncapi/storage/cosmosdb/send_to_device_table.go b/syncapi/storage/cosmosdb/send_to_device_table.go index 3a985c8d4..8cc3f24b8 100644 --- a/syncapi/storage/cosmosdb/send_to_device_table.go +++ b/syncapi/storage/cosmosdb/send_to_device_table.go @@ -18,108 +18,223 @@ import ( "context" "database/sql" "encoding/json" + "fmt" + "time" - "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/sirupsen/logrus" ) -const sendToDeviceSchema = ` --- Stores send-to-device messages. -CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( - -- The ID that uniquely identifies this message. - id INTEGER PRIMARY KEY AUTOINCREMENT, - -- The user ID to send the message to. - user_id TEXT NOT NULL, - -- The device ID to send the message to. - device_id TEXT NOT NULL, - -- The event content JSON. - content TEXT NOT NULL -); -` +// const sendToDeviceSchema = ` +// -- Stores send-to-device messages. +// CREATE TABLE IF NOT EXISTS syncapi_send_to_device ( +// -- The ID that uniquely identifies this message. +// id INTEGER PRIMARY KEY AUTOINCREMENT, +// -- The user ID to send the message to. +// user_id TEXT NOT NULL, +// -- The device ID to send the message to. +// device_id TEXT NOT NULL, +// -- The event content JSON. +// content TEXT NOT NULL +// ); +// ` -const insertSendToDeviceMessageSQL = ` - INSERT INTO syncapi_send_to_device (user_id, device_id, content) - VALUES ($1, $2, $3) -` +type SendToDeviceCosmos struct { + ID int64 `json:"id"` + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + Content string `json:"content"` +} -const selectSendToDeviceMessagesSQL = ` - SELECT id, user_id, device_id, content - FROM syncapi_send_to_device - WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4 - ORDER BY id DESC -` +type SendToDeviceCosmosMaxNumber struct { + Max int64 `json:"number"` +} + +type SendToDeviceCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + SendToDevice SendToDeviceCosmos `json:"mx_syncapi_send_to_device"` +} + +// const insertSendToDeviceMessageSQL = ` +// INSERT INTO syncapi_send_to_device (user_id, device_id, content) +// VALUES ($1, $2, $3) +// ` + +// SELECT id, user_id, device_id, content +// FROM syncapi_send_to_device +// WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4 +// ORDER BY id DESC +const selectSendToDeviceMessagesSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_send_to_device.user_id = @x2 " + + "and c.mx_syncapi_send_to_device.device_id = @x3 " + + "and c.mx_syncapi_send_to_device.id > @x4 " + + "and c.mx_syncapi_send_to_device.id <= @x5 " + + "order by c.mx_syncapi_send_to_device.id desc " const deleteSendToDeviceMessagesSQL = ` DELETE FROM syncapi_send_to_device WHERE user_id = $1 AND device_id = $2 AND id < $3 ` +// "SELECT MAX(id) FROM syncapi_send_to_device" const selectMaxSendToDeviceIDSQL = "" + - "SELECT MAX(id) FROM syncapi_send_to_device" + "select max(c.mx_syncapi_send_to_device.id) as number from c where c._cn = @x1 " type sendToDeviceStatements struct { - db *sql.DB - insertSendToDeviceMessageStmt *sql.Stmt - selectSendToDeviceMessagesStmt *sql.Stmt + db *SyncServerDatasource + // insertSendToDeviceMessageStmt *sql.Stmt + selectSendToDeviceMessagesStmt string deleteSendToDeviceMessagesStmt *sql.Stmt - selectMaxSendToDeviceIDStmt *sql.Stmt + selectMaxSendToDeviceIDStmt string + tableName string } -func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { - s := &sendToDeviceStatements{ - db: db, - } - _, err := db.Exec(sendToDeviceSchema) +func querySendToDevice(s *sendToDeviceStatements, ctx context.Context, qry string, params map[string]interface{}) ([]SendToDeviceCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []SendToDeviceCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + if err != nil { return nil, err } - if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { - return nil, err + return response, nil +} + +func querySendToDeviceNumber(s *sendToDeviceStatements, ctx context.Context, qry string, params map[string]interface{}) ([]SendToDeviceCosmosMaxNumber, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []SendToDeviceCosmosMaxNumber + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, nil } - if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { - return nil, err - } - if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { - return nil, err - } - if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil { - return nil, err + return response, nil +} + +func NewCosmosDBSendToDeviceTable(db *SyncServerDatasource) (tables.SendToDevice, error) { + s := &sendToDeviceStatements{ + db: db, } + // if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { + // return nil, err + // } + s.selectSendToDeviceMessagesStmt = selectSendToDeviceMessagesSQL + // if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { + // return nil, err + // } + s.selectMaxSendToDeviceIDStmt = selectMaxSendToDeviceIDSQL + s.tableName = "send_to_device" return s, nil } func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ) (pos types.StreamPosition, err error) { - var result sql.Result - result, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) - if p, err := result.LastInsertId(); err != nil { + + // id INTEGER PRIMARY KEY AUTOINCREMENT, + id, err := GetNextSendToDeviceID(s, ctx) + if err != nil { return 0, err - } else { - pos = types.StreamPosition(p) } + + pos = types.StreamPosition(id) + + // INSERT INTO syncapi_send_to_device (user_id, device_id, content) + // VALUES ($1, $2, $3) + + data := SendToDeviceCosmos{ + ID: int64(pos), + UserID: userID, + DeviceID: deviceID, + Content: content, + } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + // NO CONSTRAINT + docId := fmt.Sprintf("%d", pos) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + + var dbData = SendToDeviceCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + SendToDevice: data, + } + + var optionsCreate = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + optionsCreate) + return } func (s *sendToDeviceStatements) SelectSendToDeviceMessages( ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition, ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to) + // SELECT id, user_id, device_id, content + // FROM syncapi_send_to_device + // WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4 + // ORDER BY id DESC + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": deviceID, + "@x4": from, + "@x5": to, + } + + rows, err := querySendToDevice(s, ctx, s.selectSendToDeviceMessagesStmt, params) if err != nil { return } - defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") - for rows.Next() { + for _, item := range rows { var id types.StreamPosition var userID, deviceID, content string - if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil { - logrus.WithError(err).Errorf("Failed to retrieve send-to-device message") - return - } + // if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil { + // logrus.WithError(err).Errorf("Failed to retrieve send-to-device message") + // return + // } + id = types.StreamPosition(item.SendToDevice.ID) + userID = item.SendToDevice.UserID + deviceID = item.SendToDevice.DeviceID + content = item.SendToDevice.Content if id > lastPos { lastPos = id } @@ -128,8 +243,8 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( UserID: userID, DeviceID: deviceID, } - if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { - logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message") + if jsonErr := json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { + logrus.WithError(jsonErr).Errorf("Failed to unmarshal send-to-device message") continue } events = append(events, event) @@ -137,7 +252,7 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( if lastPos == 0 { lastPos = to } - return lastPos, events, rows.Err() + return lastPos, events, err } func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( @@ -151,8 +266,21 @@ func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt) - err = stmt.QueryRowContext(ctx).Scan(&nullableID) + // "SELECT MAX(id) FROM syncapi_send_to_device" + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + rows, err := querySendToDeviceNumber(s, ctx, s.selectMaxSendToDeviceIDStmt, params) + // stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt) + // err = stmt.QueryRowContext(ctx).Scan(&nullableID) + + if rows != nil { + nullableID.Int64 = rows[0].Max + } + if nullableID.Valid { id = nullableID.Int64 } diff --git a/syncapi/storage/cosmosdb/send_to_device_table_seq.go b/syncapi/storage/cosmosdb/send_to_device_table_seq.go new file mode 100644 index 000000000..05c4f89ef --- /dev/null +++ b/syncapi/storage/cosmosdb/send_to_device_table_seq.go @@ -0,0 +1,12 @@ +package cosmosdb + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" +) + +func GetNextSendToDeviceID(s *sendToDeviceStatements, ctx context.Context) (int64, error) { + const docId = "sendtodevice_seq" + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) +} diff --git a/syncapi/storage/cosmosdb/stream_id_table.go b/syncapi/storage/cosmosdb/stream_id_table.go index a599a9e65..9b0205ae2 100644 --- a/syncapi/storage/cosmosdb/stream_id_table.go +++ b/syncapi/storage/cosmosdb/stream_id_table.go @@ -4,91 +4,108 @@ import ( "context" "database/sql" - "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + "github.com/matrix-org/dendrite/syncapi/types" ) -const streamIDTableSchema = ` --- Global stream ID counter, used by other tables. -CREATE TABLE IF NOT EXISTS syncapi_stream_id ( - stream_name TEXT NOT NULL PRIMARY KEY, - stream_id INT DEFAULT 0, +// const streamIDTableSchema = ` +// -- Global stream ID counter, used by other tables. +// CREATE TABLE IF NOT EXISTS syncapi_stream_id ( +// stream_name TEXT NOT NULL PRIMARY KEY, +// stream_id INT DEFAULT 0, - UNIQUE(stream_name) -); -INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0) - ON CONFLICT DO NOTHING; -INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0) - ON CONFLICT DO NOTHING; -INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("accountdata", 0) - ON CONFLICT DO NOTHING; -INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0) - ON CONFLICT DO NOTHING; -` +// UNIQUE(stream_name) +// ); +// INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0) +// ON CONFLICT DO NOTHING; +// INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0) +// ON CONFLICT DO NOTHING; +// INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("accountdata", 0) +// ON CONFLICT DO NOTHING; +// INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0) +// ON CONFLICT DO NOTHING; +// ` -const increaseStreamIDStmt = "" + - "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +// const increaseStreamIDStmt = "" + +// "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" -const selectStreamIDStmt = "" + - "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1" +// const selectStreamIDStmt = "" + +// "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1" type streamIDStatements struct { - db *sql.DB - increaseStreamIDStmt *sql.Stmt - selectStreamIDStmt *sql.Stmt + db *SyncServerDatasource + // increaseStreamIDStmt *sql.Stmt + // selectStreamIDStmt *sql.Stmt + tableName string } -func (s *streamIDStatements) prepare(db *sql.DB) (err error) { +func (s *streamIDStatements) prepare(db *SyncServerDatasource) (err error) { s.db = db - _, err = db.Exec(streamIDTableSchema) - if err != nil { - return - } - if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil { - return - } - if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil { - return - } + s.tableName = "stream_id" return } func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { - increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { - return + const docId = "global_seq" + result, err := cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) + // increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + // selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) + // if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { + // return + // } + // err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) + if err != nil { + return -1, err } - err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) + pos = types.StreamPosition(result) return } func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { - increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil { - return + const docId = "receipt_seq" + result, err := cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) + // increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + // selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) + // if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil { + // return + // } + // err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos) + if err != nil { + return -1, err } - err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos) + pos = types.StreamPosition(result) return } func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { - increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil { - return + const docId = "invite_seq" + result, err := cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) + // increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + // selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) + // if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil { + // return + // } + // err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos) + if err != nil { + return -1, err } - err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos) + pos = types.StreamPosition(result) return } func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { - increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil { - return + const docId = "accountdata_seq" + result, err := cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) + // increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) + // selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) + // if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil { + // return + // } + // err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) + if err != nil { + return -1, err } - err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) + pos = types.StreamPosition(result) return } diff --git a/syncapi/storage/cosmosdb/syncserver.go b/syncapi/storage/cosmosdb/syncserver.go index 7bf1a1387..1d30d9725 100644 --- a/syncapi/storage/cosmosdb/syncserver.go +++ b/syncapi/storage/cosmosdb/syncserver.go @@ -16,101 +16,104 @@ package cosmosdb import ( - "database/sql" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" // Import the sqlite3 package - _ "github.com/mattn/go-sqlite3" + // _ "github.com/mattn/go-sqlite3" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/shared" - "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" ) // SyncServerDatasource represents a sync server datasource which manages // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { shared.Database - db *sql.DB - writer sqlutil.Writer - sqlutil.PartitionOffsetStatements - streamID streamIDStatements + // db *sql.DB + writer cosmosdbutil.Writer + database cosmosdbutil.Database + cosmosdbutil.PartitionOffsetStatements + streamID streamIDStatements + connection cosmosdbapi.CosmosConnection + databaseName string + cosmosConfig cosmosdbapi.CosmosConfig } // NewDatabase creates a new sync server database // nolint: gocyclo func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, error) { + conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) + configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) var d SyncServerDatasource - var err error - if d.db, err = sqlutil.Open(dbProperties); err != nil { + d.writer = cosmosdbutil.NewExclusiveWriterFake() + if err := d.prepare(dbProperties); err != nil { return nil, err } - d.writer = sqlutil.NewExclusiveWriter() - if err = d.prepare(dbProperties); err != nil { - return nil, err + d.connection = conn + d.cosmosConfig = configCosmos + d.databaseName = "syncapi" + d.database = cosmosdbutil.Database{ + Connection: conn, + CosmosConfig: configCosmos, + DatabaseName: d.databaseName, } return &d, nil } func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { - if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { + if err = d.PartitionOffsetStatements.Prepare(&d.database, d.writer, "syncapi"); err != nil { return err } - if err = d.streamID.prepare(d.db); err != nil { + if err = d.streamID.prepare(d); err != nil { return err } - accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) + accountData, err := NewCosmosDBAccountDataTable(d, &d.streamID) if err != nil { return err } - events, err := NewSqliteEventsTable(d.db, &d.streamID) + events, err := NewCosmosDBEventsTable(d, &d.streamID) if err != nil { return err } - roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID) + roomState, err := NewCosmosDBCurrentRoomStateTable(d, &d.streamID) if err != nil { return err } - invites, err := NewSqliteInvitesTable(d.db, &d.streamID) + invites, err := NewCosmosDBInvitesTable(d, &d.streamID) if err != nil { return err } - peeks, err := NewSqlitePeeksTable(d.db, &d.streamID) + peeks, err := NewCosmosDBPeeksTable(d, &d.streamID) if err != nil { return err } - topology, err := NewSqliteTopologyTable(d.db) + topology, err := NewCosmosDBTopologyTable(d) if err != nil { return err } - bwExtrem, err := NewSqliteBackwardsExtremitiesTable(d.db) + bwExtrem, err := NewCosmosDBBackwardsExtremitiesTable(d) if err != nil { return err } - sendToDevice, err := NewSqliteSendToDeviceTable(d.db) + sendToDevice, err := NewCosmosDBSendToDeviceTable(d) if err != nil { return err } - filter, err := NewSqliteFilterTable(d.db) + filter, err := NewCosmosDBFilterTable(d) if err != nil { return err } - receipts, err := NewSqliteReceiptsTable(d.db, &d.streamID) + receipts, err := NewCosmosDBReceiptsTable(d, &d.streamID) if err != nil { return err } - memberships, err := NewSqliteMembershipsTable(d.db) + memberships, err := NewCosmosDBMembershipsTable(d) if err != nil { return err } - m := sqlutil.NewMigrations() - deltas.LoadFixSequences(m) - deltas.LoadRemoveSendToDeviceSentColumn(m) - if err = m.RunDeltas(d.db, dbProperties); err != nil { - return err - } d.Database = shared.Database{ - DB: d.db, + DB: nil, Writer: d.writer, Invites: invites, Peeks: peeks, diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index b8271877b..bf1d1f300 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -674,12 +674,17 @@ func (d *Database) GetStateDeltas( // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block. // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block. // - Get all CURRENTLY joined rooms, and add them to 'joined' block. - txn, err := d.readOnlySnapshot(ctx) - if err != nil { - return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err) + + // HACK: CosmosDB - Allow for DB nil + var txn *sql.Tx + succeeded := true + if d.DB != nil { + txn, err := d.readOnlySnapshot(ctx) + if err != nil { + return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err) + } + defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) } - var succeeded bool - defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) var deltas []types.StateDelta diff --git a/userapi/storage/accounts/cosmosdb/storage.go b/userapi/storage/accounts/cosmosdb/storage.go index 0f344945e..c3124fd98 100644 --- a/userapi/storage/accounts/cosmosdb/storage.go +++ b/userapi/storage/accounts/cosmosdb/storage.go @@ -37,8 +37,9 @@ import ( // Database represents an account database type Database struct { - sqlutil.PartitionOffsetStatements - writer sqlutil.Writer + database cosmosdbutil.Database + cosmosdbutil.PartitionOffsetStatements + writer cosmosdbutil.Writer accounts accountsStatements profiles profilesStatements accountDatas accountDataStatements @@ -56,18 +57,23 @@ type Database struct { // NewDatabase creates a new accounts and profiles database func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) - config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) + configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) d := &Database{ serverName: serverName, databaseName: "userapi", connection: conn, - cosmosConfig: config, + cosmosConfig: configCosmos, // db: db, writer: sqlutil.NewExclusiveWriter(), // bcryptCost: bcryptCost, // openIDTokenLifetimeMS: openIDTokenLifetimeMS, } + d.database = cosmosdbutil.Database{ + Connection: conn, + CosmosConfig: configCosmos, + DatabaseName: d.databaseName, + } // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns @@ -80,10 +86,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver // return nil, err // } - // partitions := sqlutil.PartitionOffsetStatements{} - // if err = partitions.Prepare(db, d.writer, "account"); err != nil { - // return nil, err - // } + if err := d.PartitionOffsetStatements.Prepare(&d.database, d.writer, "account"); err != nil { + return nil, err + } var err error if err = d.accounts.prepare(d, serverName); err != nil { return nil, err diff --git a/userapi/storage/devices/cosmosdb/devices_table.go b/userapi/storage/devices/cosmosdb/devices_table.go index ae1062140..1581064ae 100644 --- a/userapi/storage/devices/cosmosdb/devices_table.go +++ b/userapi/storage/devices/cosmosdb/devices_table.go @@ -160,8 +160,8 @@ func getDevice(s *devicesStatements, ctx context.Context, pk string, docId strin return &response, err } -func setDevice(s *devicesStatements, ctx context.Context, pk string, device DeviceCosmosData) (*DeviceCosmosData, error) { - var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(pk, device.ETag) +func setDevice(s *devicesStatements, ctx context.Context, device DeviceCosmosData) (*DeviceCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(device.Pk, device.ETag) var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( ctx, s.db.cosmosConfig.DatabaseName, @@ -345,7 +345,7 @@ func (s *devicesStatements) updateDeviceName( response.Device.DisplayName = *displayName - var _, exReplace = setDevice(s, ctx, pk, *response) + var _, exReplace = setDevice(s, ctx, *response) if exReplace != nil { return exReplace } @@ -460,8 +460,9 @@ func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, localpart, } response.Device.LastSeenTS = lastSeenTs + response.Device.LastSeenIP = ipAddr - var _, exReplace = setDevice(s, ctx, pk, *response) + var _, exReplace = setDevice(s, ctx, *response) if exReplace != nil { return exReplace }