From fd7f25479b0c25dc1777d60b7cfc30942f213e12 Mon Sep 17 00:00:00 2001 From: alexfca <75228224+alexfca@users.noreply.github.com> Date: Fri, 10 Sep 2021 16:04:17 +1000 Subject: [PATCH] Upgrade Dendrite 0.5.0 support for CosmosDB (#15) * - Add CosmosDB back - Add missing methods to blacklist_table.go - Add missing methods to device_keys_table.go - Add missing methods to events_table.go - Add missing methods to membership_table.go - Update state_block_table.go (due to reafctor SQL) - Update state_snapshot_table.go (due to reafctor SQL) - Add new key_backup_table.go - Add new key_backup_version_table.go - Code compiles but has runtime errors * Message sending + receiving working Rooms and DMs working - Add CrossSigningKeys table - Add CrossSigningSigs table - Refactor DeviceKeys yable - Fix OneTimeKeys - Update the KeyServer storage.go to use a PartitionStorer instead of a specific SQL PartitionOffsetStatements - Fix small issues from the previous commit - Implement DeleteSendToDeviceMessages Co-authored-by: alexf@example.com --- .../storage/cosmosdb/blacklist_table.go | 59 ++- go.mod | 1 + go.sum | 4 + .../cosmosdb/cross_signing_keys_table.go | 179 ++++++++ .../cosmosdb/cross_signing_sigs_table.go | 241 ++++++++++ .../storage/cosmosdb/device_keys_table.go | 51 ++- .../storage/cosmosdb/one_time_keys_table.go | 8 +- keyserver/storage/cosmosdb/storage.go | 42 +- keyserver/storage/postgres/storage.go | 4 +- keyserver/storage/shared/storage.go | 3 +- keyserver/storage/sqlite3/storage.go | 4 +- roomserver/storage/cosmosdb/events_table.go | 86 ++++ .../storage/cosmosdb/membership_table.go | 116 +++++ roomserver/storage/cosmosdb/state_blob_seq.go | 13 + .../storage/cosmosdb/state_block_table.go | 382 +++++++--------- .../storage/cosmosdb/state_snapshot_table.go | 51 ++- .../storage/cosmosdb/send_to_device_table.go | 57 ++- .../accounts/cosmosdb/key_backup_table.go | 414 ++++++++++++++++++ .../cosmosdb/key_backup_version_table.go | 377 ++++++++++++++++ .../key_backup_version_table_id_seq.go | 12 + userapi/storage/accounts/cosmosdb/storage.go | 156 +++++++ 21 files changed, 1985 insertions(+), 275 deletions(-) create mode 100644 keyserver/storage/cosmosdb/cross_signing_keys_table.go create mode 100644 keyserver/storage/cosmosdb/cross_signing_sigs_table.go create mode 100644 roomserver/storage/cosmosdb/state_blob_seq.go create mode 100644 userapi/storage/accounts/cosmosdb/key_backup_table.go create mode 100644 userapi/storage/accounts/cosmosdb/key_backup_version_table.go create mode 100644 userapi/storage/accounts/cosmosdb/key_backup_version_table_id_seq.go diff --git a/federationsender/storage/cosmosdb/blacklist_table.go b/federationsender/storage/cosmosdb/blacklist_table.go index 532b53b8a..d5b08756c 100644 --- a/federationsender/storage/cosmosdb/blacklist_table.go +++ b/federationsender/storage/cosmosdb/blacklist_table.go @@ -57,12 +57,17 @@ type BlacklistCosmosData struct { // const deleteBlacklistSQL = "" + // "DELETE FROM federationsender_blacklist WHERE server_name = $1" +// "DELETE FROM federationsender_blacklist" +const deleteAllBlacklistSQL = "" + + "select * from c where c._cn = @x1 " + type blacklistStatements struct { db *Database // insertBlacklistStmt *sql.Stmt // selectBlacklistStmt *sql.Stmt // deleteBlacklistStmt *sql.Stmt - tableName string + deleteAllBlacklistStmt string + tableName string } func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*BlacklistCosmosData, error) { @@ -82,6 +87,27 @@ func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId return &response, err } +func queryBlacklist(s *blacklistStatements, ctx context.Context, qry string, params map[string]interface{}) ([]BlacklistCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + var response []BlacklistCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData BlacklistCosmosData) error { var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( @@ -101,6 +127,7 @@ func NewCosmosDBBlacklistTable(db *Database) (s *blacklistStatements, err error) s = &blacklistStatements{ db: db, } + s.deleteAllBlacklistStmt = deleteAllBlacklistSQL s.tableName = "blacklists" return } @@ -189,8 +216,36 @@ func (s *blacklistStatements) DeleteBlacklist( pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) // _, err := stmt.ExecContext(ctx, serverName) res, err := getBlacklist(s, ctx, pk, cosmosDocId) - if(res != nil) { + if res != nil { _ = deleteBlacklist(s, ctx, *res) } return err } + +func (s *blacklistStatements) DeleteAllBlacklist( + ctx context.Context, txn *sql.Tx, +) error { + // "DELETE FROM federationsender_blacklist" + + // stmt := sqlutil.TxStmt(txn, s.deleteAllBlacklistStmt) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + } + + // rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID) + rows, err := queryBlacklist(s, ctx, s.deleteAllBlacklistStmt, params) + + if err != nil { + return err + } + // _, err := stmt.ExecContext(ctx) + for _, item := range rows { + // stmt := sqlutil.TxStmt(txn, deleteStmt) + err = deleteBlacklist(s, ctx, item) + if err != nil { + return err + } + } + return err +} diff --git a/go.mod b/go.mod index 7f883bea0..f451a61a1 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/tidwall/sjson v1.1.7 github.com/uber/jaeger-client-go v2.29.1+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible + github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d // indirect github.com/yggdrasil-network/yggdrasil-go v0.4.1-0.20210715083903-52309d094c00 go.uber.org/atomic v1.9.0 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 diff --git a/go.sum b/go.sum index 65cf14657..9bf0b99c4 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,7 @@ github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAU github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/albertorestifo/dijkstra v0.0.0-20160910063646-aba76f725f72/go.mod h1:o+JdB7VetTHjLhU0N57x18B9voDBQe0paApdEAEoEfw= +github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -432,6 +433,7 @@ github.com/godbus/dbus v0.0.0-20180201030542-885f9cc04c9c/go.mod h1:/YcGZj5zSblf github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gofrs/uuid v3.1.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/googleapis v1.2.0/go.mod h1:Njal3psf3qN6dwBtQfUmBZh2ybovJ0tlu3o/AC7HYjU= github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -1405,6 +1407,8 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d h1:MZRYOouO0snrQyBAf4Wljc3qqaispjzMOhFRQgWfKMo= +github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d/go.mod h1:ldPlejlc7ZyiP0QQWGwL9CoZLvEjhD9yzpz0ct7+sXo= github.com/vishvananda/netlink v0.0.0-20181108222139-023a6dafdcdf/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netlink v1.1.1-0.20201029203352-d40f9887b852/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= diff --git a/keyserver/storage/cosmosdb/cross_signing_keys_table.go b/keyserver/storage/cosmosdb/cross_signing_keys_table.go new file mode 100644 index 000000000..fd32e7019 --- /dev/null +++ b/keyserver/storage/cosmosdb/cross_signing_keys_table.go @@ -0,0 +1,179 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// var crossSigningKeysSchema = ` +// CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( +// user_id TEXT NOT NULL, +// key_type INTEGER NOT NULL, +// key_data TEXT NOT NULL, +// PRIMARY KEY (user_id, key_type) +// ); +// ` + +type CrossSigningKeysCosmos struct { + UserID string `json:"user_id"` + KeyType int64 `json:"key_type"` + KeyData []byte `json:"key_data"` +} + +type CrossSigningKeysCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Tn string `json:"_sid"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + CrossSigningKeys CrossSigningKeysCosmos `json:"mx_keyserver_cross_signing_keys"` +} + +// "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + +// " WHERE user_id = $1" +const selectCrossSigningKeysForUserSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_cross_signing_keys.user_id = @x2 " + +// const upsertCrossSigningKeysForUserSQL = "" + +// "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + +// " VALUES($1, $2, $3)" + +type crossSigningKeysStatements struct { + db *Database + selectCrossSigningKeysForUserStmt string + // upsertCrossSigningKeysForUserStmt *sql.Stmt + tableName string +} + +func queryCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CrossSigningKeysCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + var response []CrossSigningKeysCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func NewSqliteCrossSigningKeysTable(db *Database) (tables.CrossSigningKeys, error) { + s := &crossSigningKeysStatements{ + db: db, + } + s.selectCrossSigningKeysForUserStmt = selectCrossSigningKeysForUserSQL + // s.upsertCrossSigningKeysForUserStmt = upsertCrossSigningKeysForUserSQL + s.tableName = "cross_signing_keys" + return s, nil +} + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, +) (r types.CrossSigningKeyMap, err error) { + // "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + // " WHERE user_id = $1" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + } + rows, err := queryCrossSigningKeys(s, ctx, s.selectCrossSigningKeysForUserStmt, params) + // rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) + if err != nil { + return nil, err + } + // defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") + r = types.CrossSigningKeyMap{} + // for rows.Next() { + for _, item := range rows { + var keyTypeInt int16 + var keyData gomatrixserverlib.Base64Bytes + // if err := rows.Scan(&keyTypeInt, &keyData); err != nil { + // return nil, err + // } + keyData = item.CrossSigningKeys.KeyData + keyTypeInt = int16(item.CrossSigningKeys.KeyType) + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = keyData + } + return +} + +func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, +) error { + // "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + + // " VALUES($1, $2, $3)" + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return fmt.Errorf("unknown key purpose %q", keyType) + } + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + // PRIMARY KEY (user_id, key_type) + docId := fmt.Sprintf("%s_%s", userID, keyType) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + + data := CrossSigningKeysCosmos{ + UserID: userID, + KeyType: int64(keyTypeInt), + KeyData: keyData, + } + + dbData := CrossSigningKeysCosmosData{ + Id: cosmosDocId, + Tn: s.db.cosmosConfig.TenantName, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + CrossSigningKeys: data, + } + + // if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { + // return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) + // } + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + options) + + return err +} diff --git a/keyserver/storage/cosmosdb/cross_signing_sigs_table.go b/keyserver/storage/cosmosdb/cross_signing_sigs_table.go new file mode 100644 index 000000000..2ef8a0518 --- /dev/null +++ b/keyserver/storage/cosmosdb/cross_signing_sigs_table.go @@ -0,0 +1,241 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/keyserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// var crossSigningSigsSchema = ` +// CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( +// origin_user_id TEXT NOT NULL, +// origin_key_id TEXT NOT NULL, +// target_user_id TEXT NOT NULL, +// target_key_id TEXT NOT NULL, +// signature TEXT NOT NULL, +// PRIMARY KEY (origin_user_id, target_user_id, target_key_id) +// ); +// ` + +type CrossSigningSigsCosmos struct { + OriginUserId string `json:"origin_user_id"` + OriginKeyId string `json:"origin_key_id"` + TargetUserId string `json:"target_user_id"` + TargetKeyId string `json:"target_key_id"` + Signature []byte `json:"signature"` +} + +type CrossSigningSigsCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Tn string `json:"_sid"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + CrossSigningSigs CrossSigningSigsCosmos `json:"mx_keyserver_cross_signing_sigs"` +} + +// "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + +// " WHERE target_user_id = $1 AND target_key_id = $2" +const selectCrossSigningSigsForTargetSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_cross_signing_sigs.target_user_id = @x2 " + + "and c.mx_keyserver_cross_signing_sigs.target_key_id = @x3 " + +// const upsertCrossSigningSigsForTargetSQL = "" + +// "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + +// " VALUES($1, $2, $3, $4, $5)" + +// "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" +const deleteCrossSigningSigsForTargetSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_cross_signing_sigs.target_user_id = @x2 " + + "and c.mx_keyserver_cross_signing_sigs.target_key_id = @x3 " + +type crossSigningSigsStatements struct { + db *Database + selectCrossSigningSigsForTargetStmt string + // upsertCrossSigningSigsForTargetStmt *sql.Stmt + deleteCrossSigningSigsForTargetStmt string + tableName string +} + +func queryCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CrossSigningSigsCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + var response []CrossSigningSigsCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func deleteCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, dbData CrossSigningSigsCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + +func NewSqliteCrossSigningSigsTable(db *Database) (tables.CrossSigningSigs, error) { + s := &crossSigningSigsStatements{ + db: db, + } + // _, err := db.Exec(crossSigningSigsSchema) + // if err != nil { + // return nil, err + // } + s.selectCrossSigningSigsForTargetStmt = selectCrossSigningSigsForTargetSQL + // s.upsertCrossSigningSigsForTargetStmt = upsertCrossSigningSigsForTargetSQL + s.deleteCrossSigningSigsForTargetStmt = deleteCrossSigningSigsForTargetSQL + s.tableName = "cross_signing_sigs" + return s, nil +} + +func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) (r types.CrossSigningSigMap, err error) { + // "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + + // " WHERE target_user_id = $1 AND target_key_id = $2" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": targetUserID, + "@x3": targetKeyID, + } + rows, err := queryCrossSigningSigs(s, ctx, s.selectCrossSigningSigsForTargetStmt, params) + // rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) + if err != nil { + return nil, err + } + // defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") + r = types.CrossSigningSigMap{} + // for rows.Next() { + for _, item := range rows { + var userID string + var keyID gomatrixserverlib.KeyID + var signature gomatrixserverlib.Base64Bytes + // if err := rows.Scan(&userID, &keyID, &signature); err != nil { + // return nil, err + // } + userID = item.CrossSigningSigs.OriginUserId + keyID = gomatrixserverlib.KeyID(item.CrossSigningSigs.OriginKeyId) + signature = gomatrixserverlib.Base64Bytes(item.CrossSigningSigs.Signature) + if _, ok := r[userID]; !ok { + r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + r[userID][keyID] = signature + } + return +} + +func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + // "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + + // " VALUES($1, $2, $3, $4, $5)" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + // PRIMARY KEY (origin_user_id, target_user_id, target_key_id) + docId := fmt.Sprintf("%s_%s_%s", originUserID, targetUserID, targetKeyID) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + + data := CrossSigningSigsCosmos{ + TargetUserId: targetUserID, + TargetKeyId: string(targetKeyID), + OriginUserId: originUserID, + OriginKeyId: string(originKeyID), + Signature: signature, + } + + dbData := CrossSigningSigsCosmosData{ + Id: cosmosDocId, + Tn: s.db.cosmosConfig.TenantName, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + CrossSigningSigs: data, + } + + // if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + // return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) + // } + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData, + options) + return err +} + +func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) error { + // "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": targetUserID, + "@x3": targetKeyID, + } + rows, err := queryCrossSigningSigs(s, ctx, s.selectCrossSigningSigsForTargetStmt, params) + // if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { + // return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + // } + if err != nil { + return err + } + + for _, item := range rows { + errItem := deleteCrossSigningSigs(s, ctx, item) + if errItem != nil { + return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + } + } + return nil +} diff --git a/keyserver/storage/cosmosdb/device_keys_table.go b/keyserver/storage/cosmosdb/device_keys_table.go index 903d83b8d..f96bd77f4 100644 --- a/keyserver/storage/cosmosdb/device_keys_table.go +++ b/keyserver/storage/cosmosdb/device_keys_table.go @@ -94,7 +94,13 @@ const selectAllDeviceKeysSQL = "" + "select * from c where c._cn = @x1 " + "and c.mx_keyserver_device_key.user_id = @x2 " -// const deleteAllDeviceKeysSQL = "" + +// "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" +const deleteDeviceKeysSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_keyserver_device_key.user_id = @x2 " + + "and c.mx_keyserver_device_key.device_id = @x3 " + + // const deleteAllDeviceKeysSQL = "" + // "DELETE FROM keyserver_device_keys WHERE user_id=$1" func queryDeviceKey(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosData, error) { @@ -192,6 +198,7 @@ type deviceKeysStatements struct { // selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt string selectMaxStreamForUserStmt string + deleteDeviceKeysStmt string // deleteAllDeviceKeysStmt *sql.Stmt tableName string } @@ -202,6 +209,7 @@ func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) { } s.selectBatchDeviceKeysStmt = selectBatchDeviceKeysSQL s.selectMaxStreamForUserStmt = selectMaxStreamForUserSQL + s.deleteDeviceKeysStmt = deleteDeviceKeysSQL s.tableName = "device_keys" return s, nil } @@ -221,6 +229,30 @@ func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData De return err } +func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + // "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + // _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": deviceID, + } + response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params) + + if err != nil { + return err + } + + for _, item := range response { + errItem := deleteDeviceKeyCore(s, ctx, item) + if errItem != nil { + return errItem + } + } + return nil +} + func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { // "DELETE FROM keyserver_device_keys WHERE user_id=$1" @@ -268,20 +300,25 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID var result []api.DeviceMessage for _, item := range response { - var dk api.DeviceMessage - dk.UserID = userID + dk := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{}, + } + dk.Type = api.TypeDeviceKeyUpdate + dk.UserID = item.DeviceKey.UserID // var keyJSON string var streamID int // var displayName sql.NullString // if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { // return nil, err // } - streamID = item.DeviceKey.StreamID - + dk.DeviceID = item.DeviceKey.DeviceID dk.KeyJSON = item.DeviceKey.KeyJSON + streamID = item.DeviceKey.StreamID + displayName := item.DeviceKey.DisplayName dk.StreamID = streamID - if len(item.DeviceKey.DisplayName) > 0 { - dk.DisplayName = item.DeviceKey.DisplayName + if len(displayName) > 0 { + dk.DisplayName = displayName } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { diff --git a/keyserver/storage/cosmosdb/one_time_keys_table.go b/keyserver/storage/cosmosdb/one_time_keys_table.go index 78ad5c347..b2a4ccb21 100644 --- a/keyserver/storage/cosmosdb/one_time_keys_table.go +++ b/keyserver/storage/cosmosdb/one_time_keys_table.go @@ -341,7 +341,7 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, ) (map[string]json.RawMessage, error) { var keyID string - var keyJSON string + // var keyJSON string // "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" @@ -360,14 +360,16 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( } return nil, err } + keyID = response[0].OneTimeKey.KeyID + keyJSONBytes := response[0].OneTimeKey.KeyJSON err = deleteOneTimeKeyCore(s, ctx, response[0]) if err != nil { return nil, err } - if keyJSON == "" { + if keyID == "" { return nil, nil } return map[string]json.RawMessage{ - algorithm + ":" + keyID: json.RawMessage(keyJSON), + algorithm + ":" + keyID: keyJSONBytes, }, err } diff --git a/keyserver/storage/cosmosdb/storage.go b/keyserver/storage/cosmosdb/storage.go index 004d8d7d6..78c0d8632 100644 --- a/keyserver/storage/cosmosdb/storage.go +++ b/keyserver/storage/cosmosdb/storage.go @@ -24,7 +24,7 @@ import ( // A Database is used to store room events and stream offsets. type Database struct { - shared.Database + database cosmosdbutil.Database connection cosmosdbapi.CosmosConnection databaseName string cosmosConfig cosmosdbapi.CosmosConfig @@ -33,38 +33,62 @@ type Database struct { func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) { conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) - config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) - d := &Database{ + configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) + result := &Database{ databaseName: "keyserver", connection: conn, - cosmosConfig: config, + cosmosConfig: configCosmos, + } + + result.database = cosmosdbutil.Database{ + Connection: conn, + CosmosConfig: configCosmos, + DatabaseName: result.databaseName, } // db, err := sqlutil.Open(dbProperties) // if err != nil { // return nil, err // } - otk, err := NewCosmosDBOneTimeKeysTable(d) + otk, err := NewCosmosDBOneTimeKeysTable(result) if err != nil { return nil, err } - dk, err := NewCosmosDBDeviceKeysTable(d) + dk, err := NewCosmosDBDeviceKeysTable(result) if err != nil { return nil, err } - kc, err := NewCosmosDBKeyChangesTable(d) + kc, err := NewCosmosDBKeyChangesTable(result) if err != nil { return nil, err } - sdl, err := NewCosmosDBStaleDeviceListsTable(d) + sdl, err := NewCosmosDBStaleDeviceListsTable(result) if err != nil { return nil, err } + csk, err := NewSqliteCrossSigningKeysTable(result) + if err != nil { + return nil, err + } + css, err := NewSqliteCrossSigningSigsTable(result) + if err != nil { + return nil, err + } + + writer := cosmosdbutil.NewExclusiveWriterFake() + storer := cosmosdbutil.PartitionOffsetStatements{} + if err = storer.Prepare(&result.database, writer, "keyserver"); err != nil { + return nil, err + } + return &shared.Database{ - Writer: cosmosdbutil.NewExclusiveWriterFake(), + Writer: writer, OneTimeKeysTable: otk, DeviceKeysTable: dk, KeyChangesTable: kc, StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + PartitionStorer: &storer, }, nil } diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index 52f3a7f6b..68a10431a 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -61,8 +61,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) CrossSigningKeysTable: csk, CrossSigningSigsTable: css, } - if err = d.PartitionOffsetStatements.Prepare(db, d.Writer, "keyserver"); err != nil { + storer := sqlutil.PartitionOffsetStatements{} + if err = storer.Prepare(db, d.Writer, "keyserver"); err != nil { return nil, err } + d.PartitionStorer = &storer return d, nil } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5bd8be368..94caf8f70 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" @@ -28,6 +29,7 @@ import ( ) type Database struct { + internal.PartitionStorer DB *sql.DB Writer sqlutil.Writer OneTimeKeysTable tables.OneTimeKeys @@ -36,7 +38,6 @@ type Database struct { StaleDeviceListsTable tables.StaleDeviceLists CrossSigningKeysTable tables.CrossSigningKeys CrossSigningSigsTable tables.CrossSigningSigs - sqlutil.PartitionOffsetStatements } func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index ee1746cd6..a5eb694f6 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -59,8 +59,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) CrossSigningKeysTable: csk, CrossSigningSigsTable: css, } - if err = d.PartitionOffsetStatements.Prepare(db, d.Writer, "keyserver"); err != nil { + storer := sqlutil.PartitionOffsetStatements{} + if err = storer.Prepare(db, d.Writer, "keyserver"); err != nil { return nil, err } + d.PartitionStorer = &storer return d, nil } diff --git a/roomserver/storage/cosmosdb/events_table.go b/roomserver/storage/cosmosdb/events_table.go index a5837030f..669d943ed 100644 --- a/roomserver/storage/cosmosdb/events_table.go +++ b/roomserver/storage/cosmosdb/events_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "fmt" + "sort" "time" "github.com/matrix-org/dendrite/internal/cosmosdbutil" @@ -98,6 +99,14 @@ const bulkSelectStateEventByIDSQL = "" + // ", c.mx_roomserver_event.event_state_key_nid " + "asc" +// "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + +// " WHERE event_nid IN ($1)" +// // Rest of query is built by BulkSelectStateEventByNID +const bulkSelectStateEventByNIDSQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_nid) " + // Rest of query is built by BulkSelectStateEventByNID + // "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + // " WHERE event_id IN ($1)" const bulkSelectStateAtEventByIDSQL = "" + @@ -491,6 +500,83 @@ func (s *eventStatements) BulkSelectStateEventByID( return results, err } +// bulkSelectStateEventByID lookups a list of state events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByNID( + ctx context.Context, eventNIDs []types.EventNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + // "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + // " WHERE event_nid IN ($1)" + // // Rest of query is built by BulkSelectStateEventByNID + tuples := stateKeyTupleSorter(stateKeyTuples) + sort.Sort(tuples) + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + // params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray)) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": eventNIDs, + } + // selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) + selectOrig := bulkSelectStateEventByNIDSQL + // for _, v := range eventNIDs { + // params = append(params, v) + // } + if len(eventTypeNIDArray) > 0 { + // selectOrig += " AND event_type_nid IN " + sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(params)) + selectOrig += " and ARRAY_CONTAINS(@x3, c.mx_roomserver_event.event_type_nid) " + // for _, v := range eventTypeNIDArray { + // params = append(params, v) + // } + params["@x3"] = eventTypeNIDArray + } + if len(eventStateKeyNIDArray) > 0 { + // selectOrig += " AND event_state_key_nid IN " + sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(params)) + selectOrig += " and ARRAY_CONTAINS(@x4, c.mx_roomserver_event.event_state_key_nid) " + // for _, v := range eventStateKeyNIDArray { + // params = append(params, v) + // } + params["@x4"] = eventStateKeyNIDArray + } + // selectOrig += " ORDER BY event_type_nid, event_state_key_nid ASC" + //No Composite Index so just order by the 1st one + selectOrig += " order by c.mx_roomserver_event.event_type_nid asc " + // selectStmt, err := s.db.Prepare(selectOrig) + // if err != nil { + // return nil, fmt.Errorf("s.db.Prepare: %w", err) + // } + // rows, err := selectStmt.QueryContext(ctx, params...) + rows, err := queryEvent(s, ctx, selectOrig, params) + + if err != nil { + return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) + } + // defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventNIDs)) + i := 0 + // for ; rows.Next(); i++ { + for _, item := range rows { + result := &results[i] + result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID) + result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID) + result.EventNID = types.EventNID(item.Event.EventNID) + // if err = rows.Scan( + // &result.EventTypeNID, + // &result.EventStateKeyNID, + // &result.EventNID, + // ); err != nil { + // return nil, err + // } + i++ + } + return results[:i], err +} + // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. diff --git a/roomserver/storage/cosmosdb/membership_table.go b/roomserver/storage/cosmosdb/membership_table.go index 24005fa50..04a964ad5 100644 --- a/roomserver/storage/cosmosdb/membership_table.go +++ b/roomserver/storage/cosmosdb/membership_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/internal/cosmosdbutil" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -160,6 +161,32 @@ var selectKnownUsersSQLDistinctRoom = "" + "and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " " + "and contains(c.mx_roomserver_membership.event_state_key, @x3) " +// selectLocalServerInRoomSQL is an optimised case for checking if we, the local server, +// are in the room by using the target_local column of the membership table. Normally when +// we want to know if a server is in a room, we have to unmarshal the entire room state which +// is expensive. The presence of a single row from this query suggests we're still in the +// room, no rows returned suggests we aren't. +// "SELECT room_nid FROM roomserver_membership WHERE target_local = 1 AND membership_nid = $1 AND room_nid = $2 LIMIT 1" +const selectLocalServerInRoomSQL = "" + + "select top 1 * from c where c._cn = @x1 " + + " and c.mx_roomserver_membership.target_local = 1" + + " and c.mx_roomserver_membership.membership_nid = @x2" + + " and c.mx_roomserver_membership.room_nid = @x3" + +// selectServerMembersInRoomSQL is an optimised case for checking for server members in a room. +// The JOIN is significantly leaner than the previous case of looking up event NIDs and reading the +// membership events from the database, as the JOIN query amounts to little more than two index +// scans which are very fast. The presence of a single row from this query suggests the server is +// in the room, no rows returned suggests they aren't. +// "SELECT room_nid FROM roomserver_membership" + +// " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + +// " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" +const selectServerInRoomSQL = "" + + "select top 1 * from c where c._cn = @x1 " + + " and c.mx_roomserver_membership.membership_nid = @x2" + + " and c.mx_roomserver_membership.room_nid = @x3" + + " and contains(c.mx_roomserver_membership.target_nid, @x4) " + type membershipStatements struct { db *Database // insertMembershipStmt *sql.Stmt @@ -172,6 +199,8 @@ type membershipStatements struct { selectRoomsWithMembershipStmt string // updateMembershipStmt *sql.Stmt // selectKnownUsersStmt string + selectLocalServerInRoomStmt string + selectServerInRoomStmt string // updateMembershipForgetRoomStmt *sql.Stmt tableName string } @@ -242,6 +271,8 @@ func NewCosmosDBMembershipTable(db *Database) (tables.Membership, error) { // {&s.updateMembershipStmt, updateMembershipSQL}, s.selectRoomsWithMembershipStmt = selectRoomsWithMembershipSQL // {&s.selectKnownUsersStmt, selectKnownUsersSQL}, + s.selectLocalServerInRoomStmt = selectLocalServerInRoomSQL + s.selectServerInRoomStmt = selectServerInRoomSQL // {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, // }.Prepare(db) @@ -495,6 +526,91 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, return result, nil } +func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { + // "SELECT room_nid FROM roomserver_membership WHERE target_local = 1 AND membership_nid = $1 AND room_nid = $2 LIMIT 1" + + var nid types.RoomNID + // err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + // + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": tables.MembershipStateJoin, + "@x3": roomNID, + } + response, err := queryMembership(s, ctx, s.selectLocalServerInRoomStmt, params) + if len(response) == 0 { + if err == cosmosdbutil.ErrNoRows { + return false, nil + } + return false, err + } + nid = types.RoomNID(response[0].Membership.RoomNID) + + found := nid > 0 + return found, nil +} + +func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { + var nid types.RoomNID + // "SELECT room_nid FROM roomserver_membership" + + // " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + // " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" + + //First get the JOIN table + // SELECT event_state_key_nid FROM roomserver_event_state_keys + // WHERE event_state_key LIKE '%:' || $3 LIMIT 1 + selectEventStateKeyNIDSQL := "" + + "select * from c where c._cn = @x1 " + + "and (endswith(c.mx_roomserver_event_state_keys.event_state_key, \":\") or c.mx_roomserver_event_state_keys.event_state_key = @x2) " + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": serverName, + } + + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + var eventStateKeys []EventStateKeysCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(selectEventStateKeyNIDSQL, params) // + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &eventStateKeys, + optionsQry) + + eventStateKeyNids := []int64{} + for _, item := range eventStateKeys { + eventStateKeyNids = append(eventStateKeyNids, item.EventStateKeys.EventStateKeyNID) + } + + //Now do the JOIN + // "SELECT room_nid FROM roomserver_membership" + + // " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + // " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" + + // err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + params = map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": tables.MembershipStateJoin, + "@x3": roomNID, + "@x4": eventStateKeyNids, + } + response, err := queryMembership(s, ctx, s.selectServerInRoomStmt, params) + if len(response) == 0 { + if err == cosmosdbutil.ErrNoRows { + return false, nil + } + return false, err + } + nid = types.RoomNID(response[0].Membership.RoomNID) + return roomNID == nid, nil +} + func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { // " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + diff --git a/roomserver/storage/cosmosdb/state_blob_seq.go b/roomserver/storage/cosmosdb/state_blob_seq.go new file mode 100644 index 000000000..5c363b2d8 --- /dev/null +++ b/roomserver/storage/cosmosdb/state_blob_seq.go @@ -0,0 +1,13 @@ +package cosmosdb + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" +) + +func GetNextStateBlockNID(s *stateBlockStatements, ctx context.Context) (int64, error) { + const docId = "stateblocknid_seq" + //1 insert start at 2 + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 2) +} diff --git a/roomserver/storage/cosmosdb/state_block_table.go b/roomserver/storage/cosmosdb/state_block_table.go index b70e93e7e..eb4e1e57d 100644 --- a/roomserver/storage/cosmosdb/state_block_table.go +++ b/roomserver/storage/cosmosdb/state_block_table.go @@ -18,12 +18,13 @@ package cosmosdb import ( "context" "database/sql" + "encoding/hex" "fmt" "sort" "time" "github.com/matrix-org/dendrite/internal/cosmosdbapi" - + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" @@ -31,19 +32,21 @@ import ( // const stateDataSchema = ` // CREATE TABLE IF NOT EXISTS roomserver_state_block ( -// state_block_nid INTEGER NOT NULL, -// event_type_nid INTEGER NOT NULL, -// event_state_key_nid INTEGER NOT NULL, -// event_nid INTEGER NOT NULL, -// UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) +// -- The state snapshot NID that identifies this snapshot. +// state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, +// -- The hash of the state block, which is used to enforce uniqueness. The hash is +// -- generated in Dendrite and passed through to the database, as a btree index over +// -- this column is cheap and fits within the maximum index size. +// state_block_hash BLOB UNIQUE, +// -- The event NIDs contained within the state block, encoded as JSON. +// event_nids TEXT NOT NULL DEFAULT '[]' // ); // ` type StateBlockCosmos struct { - StateBlockNID int64 `json:"state_block_nid"` - EventTypeNID int64 `json:"event_type_nid"` - EventStateKeyNID int64 `json:"event_state_key_nid"` - EventNID int64 `json:"event_nid"` + StateBlockNID int64 `json:"state_block_nid"` + StateBlockHash []byte `json:"state_block_hash"` + EventNIDs []int64 `json:"event_nids"` } type StateBlockCosmosMaxNID struct { @@ -60,63 +63,29 @@ type StateBlockCosmosData struct { StateBlock StateBlockCosmos `json:"mx_roomserver_state_block"` } -// const insertStateDataSQL = "" + -// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + -// " VALUES ($1, $2, $3, $4)" +// Insert a new state block. If we conflict on the hash column then +// we must perform an update so that the RETURNING statement returns the +// ID of the row that we conflicted with, so that we can then refer to +// the original block. +// const insertStateDataSQL = ` +// INSERT INTO roomserver_state_block (state_block_hash, event_nids) +// VALUES ($1, $2) +// ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 +// RETURNING state_block_nid +// ` -// SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block -const selectNextStateBlockNIDSQL = "" + - "select sub.maxinner != null ? sub.maxinner + 1 : 1 as maxstateblocknid " + - "from " + - "(select MAX(c.mx_roomserver_state_block.state_block_nid) maxinner from c where c._sid = @x1 and c._cn = @x2) as sub" - -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. -// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + -// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + -// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" +// "SELECT state_block_nid, event_nids" + +// " FROM roomserver_state_block WHERE state_block_nid IN ($1) ORDER BY state_block_nid ASC" const bulkSelectStateBlockEntriesSQL = "" + "select * from c where c._cn = @x1 " + "and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " + - "order by c.mx_roomserver_state_block.state_block_nid " + - // Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from - // ", c.mx_roomserver_state_block.event_type_nid " + - // ", c.mx_roomserver_state_block.event_state_key_nid " + - " asc" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + -// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + -// " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + -// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "select * from c where c._cn = @x1 " + - "and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " + - "and ARRAY_CONTAINS(@x3, c.mx_roomserver_state_block.event_type_nid) " + - "and ARRAY_CONTAINS(@x4, c.mx_roomserver_state_block.event_state_key_nid) " + - "order by c.mx_roomserver_state_block.state_block_nid " + - // Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from - // ", c.mx_roomserver_state_block.event_type_nid " + - // ", c.mx_roomserver_state_block.event_state_key_nid " + - "asc" + "order by c.mx_roomserver_state_block.state_block_nid " type stateBlockStatements struct { db *Database - // insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt string - bulkSelectStateBlockEntriesStmt string - bulkSelectFilteredStateBlockEntriesStmt string - tableName string + // insertStateDataStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt string + tableName string } func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StateBlockCosmosData, error) { @@ -140,39 +109,107 @@ func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, p return response, nil } +func getStateBlock(s *stateBlockStatements, ctx context.Context, pk string, docId string) (*StateBlockCosmosData, error) { + response := StateBlockCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, cosmosdbutil.ErrNoRows + } + + return &response, err +} + +func setStateBlock(s *stateBlockStatements, ctx context.Context, item StateBlockCosmosData) (*StateBlockCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(item.Pk, item.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + item.Id, + &item, + optionsReplace) + return &item, ex +} + func NewCosmosDBStateBlockTable(db *Database) (tables.StateBlock, error) { s := &stateBlockStatements{ db: db, } - // return s, shared.StatementList{ - // {&s.insertStateDataStmt, insertStateDataSQL}, - s.selectNextStateBlockNIDStmt = selectNextStateBlockNIDSQL + // s.insertStateDataStmt = insertStateDataSQL s.bulkSelectStateBlockEntriesStmt = bulkSelectStateBlockEntriesSQL - s.bulkSelectFilteredStateBlockEntriesStmt = bulkSelectFilteredStateBlockEntriesSQL - // }.Prepare(db) s.tableName = "state_block" return s, nil } -func inertStateBlockCore(s *stateBlockStatements, ctx context.Context, stateBlockNID types.StateBlockNID, entry types.StateEntry) error { +func (s *stateBlockStatements) BulkInsertStateData( + ctx context.Context, + txn *sql.Tx, + entries types.StateEntries, +) (id types.StateBlockNID, err error) { + // INSERT INTO roomserver_state_block (state_block_hash, event_nids) + // VALUES ($1, $2) + // ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 + // RETURNING state_block_nid - // "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - // " VALUES ($1, $2, $3, $4)" - data := StateBlockCosmos{ - EventNID: int64(entry.EventNID), - EventStateKeyNID: int64(entry.EventStateKeyNID), - EventTypeNID: int64(entry.EventTypeNID), - StateBlockNID: int64(stateBlockNID), + entries = entries[:util.SortAndUnique(entries)] + nids := types.EventNIDs{} // zero slice to not store 'null' in the DB + ids := []int64{} + for _, e := range entries { + nids = append(nids, e.EventNID) + ids = append(ids, int64(e.EventNID)) } + // js, err := json.Marshal(nids) + // if err != nil { + // return 0, fmt.Errorf("json.Marshal: %w", err) + // } + // err = s.insertStateDataStmt.QueryRowContext( + // ctx, nids.Hash(), js, + // ).Scan(&id) var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - // UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) - docId := fmt.Sprintf("%d_%d_%d", data.StateBlockNID, data.EventTypeNID, data.EventStateKeyNID) + // state_block_hash BLOB UNIQUE, + docId := hex.EncodeToString(nids.Hash()) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + //See if it exists + existing, err := getStateBlock(s, ctx, pk, cosmosDocId) + if err != nil { + if err != cosmosdbutil.ErrNoRows { + return 0, err + } + } + if existing != nil { + //if exists, just update and dont create a new seq + existing.StateBlock.EventNIDs = ids + _, err = setStateBlock(s, ctx, *existing) + if err != nil { + return 0, err + } + return types.StateBlockNID(existing.StateBlock.StateBlockNID), nil + } + + //Doesnt exist,create a new one + // state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + seq, err := GetNextStateBlockNID(s, ctx) + id = types.StateBlockNID(seq) + + data := StateBlockCosmos{ + StateBlockNID: seq, + StateBlockHash: nids.Hash(), + EventNIDs: ids, + } + var dbData = StateBlockCosmosData{ Id: cosmosDocId, Tn: s.db.cosmosConfig.TenantName, @@ -182,187 +219,72 @@ func inertStateBlockCore(s *stateBlockStatements, ctx context.Context, stateBloc StateBlock: data, } - var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) - _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( ctx, s.db.cosmosConfig.DatabaseName, s.db.cosmosConfig.ContainerName, &dbData, options) - return err - -} - -func getNextStateBlockNID(s *stateBlockStatements, ctx context.Context) (int64, error) { - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var stateBlockNext []StateBlockCosmosMaxNID - params := map[string]interface{}{ - "@x1": s.db.cosmosConfig.TenantName, - "@x2": dbCollectionName, - } - - var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() - var query = cosmosdbapi.GetQuery(s.selectNextStateBlockNIDStmt, params) - var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( - ctx, - s.db.cosmosConfig.DatabaseName, - s.db.cosmosConfig.ContainerName, - query, - &stateBlockNext, - optionsQry) - - if err != nil { - return 0, err - } - - return stateBlockNext[0].Max, nil -} - -func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - if len(entries) == 0 { - return 0, nil - } - - nextID, errNextID := getNextStateBlockNID(s, ctx) - if errNextID != nil { - return 0, errNextID - } - - stateBlockNID := types.StateBlockNID(nextID) - - for _, entry := range entries { - err := inertStateBlockCore(s, ctx, stateBlockNID, entry) - if err != nil { - return 0, err - } - } - return stateBlockNID, nil + return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { + ctx context.Context, stateBlockNIDs types.StateBlockNIDs, +) ([][]types.EventNID, error) { + // "SELECT state_block_nid, event_nids" + + // " FROM roomserver_state_block WHERE state_block_nid IN ($1) ORDER BY state_block_nid ASC" - // "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - // " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - // " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + intfs := make([]interface{}, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + intfs[i] = int64(stateBlockNIDs[i]) + } + // selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1) + // selectStmt, err := s.db.Prepare(selectOrig) + // if err != nil { + // return nil, err + // } var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var response []StateBlockCosmosData params := map[string]interface{}{ "@x1": dbCollectionName, "@x2": stateBlockNIDs, } - response, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params) + // rows, err := selectStmt.QueryContext(ctx, intfs...) + rows, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params) if err != nil { return nil, err } + // defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList + results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 - for _, item := range response { - entry := types.StateEntry{} - entry.EventTypeNID = types.EventTypeNID(item.StateBlock.EventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(item.StateBlock.EventStateKeyNID) - entry.EventNID = types.EventNID(item.StateBlock.EventNID) - - if current == nil || types.StateBlockNID(item.StateBlock.StateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(item.StateBlock.StateBlockNID) - i++ + // for ; rows.Next(); i++ { + for _, item := range rows { + // var stateBlockNID types.StateBlockNID + // var result json.RawMessage + // if err = rows.Scan(&stateBlockNID, &result); err != nil { + // return nil, err + // } + r := []types.EventNID{} + // if err = json.Unmarshal(result, &r); err != nil { + // return nil, fmt.Errorf("json.Unmarshal: %w", err) + // } + for _, eventNID := range item.StateBlock.EventNIDs { + r = append(r, types.EventNID(eventNID)) } - current.StateEntries = append(current.StateEntries, entry) + results[i] = r + i++ } + // if err = rows.Err(); err != nil { + // return nil, err + // } if i != len(stateBlockNIDs) { - return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateBlockNIDs)) } - return results, nil -} - -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - // sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) - // sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - // sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) - - // "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - // " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - // " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + - // " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - - var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) - var response []StateBlockCosmosData - params := map[string]interface{}{ - "@x1": dbCollectionName, - "@x2": stateBlockNIDs, - "@x3": eventTypeNIDArray, - "@x4": eventStateKeyNIDArray, - } - - response, err := queryStateBlock(s, ctx, s.bulkSelectFilteredStateBlockEntriesStmt, params) - - if err != nil { - return nil, err - } - - var results []types.StateEntryList - var current types.StateEntryList - for _, item := range response { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - stateBlockNID = item.StateBlock.StateBlockNID - eventTypeNID = item.StateBlock.EventTypeNID - eventStateKeyNID = item.StateBlock.EventStateKeyNID - eventNID = item.StateBlock.EventNID - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) - } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) - } - return results, nil + return results, err } type stateKeyTupleSorter []types.StateKeyTuple diff --git a/roomserver/storage/cosmosdb/state_snapshot_table.go b/roomserver/storage/cosmosdb/state_snapshot_table.go index 61ee0e8ba..a1f7819eb 100644 --- a/roomserver/storage/cosmosdb/state_snapshot_table.go +++ b/roomserver/storage/cosmosdb/state_snapshot_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/cosmosdbapi" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" ) // const stateSnapshotSchema = ` @@ -34,10 +35,24 @@ import ( // ); // ` +// CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( +// -- The state snapshot NID that identifies this snapshot. +// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, +// -- The hash of the state snapshot, which is used to enforce uniqueness. The hash is +// -- generated in Dendrite and passed through to the database, as a btree index over +// -- this column is cheap and fits within the maximum index size. +// state_snapshot_hash BLOB UNIQUE, +// -- The room NID that the snapshot belongs to. +// room_nid INTEGER NOT NULL, +// -- The state blocks contained within this snapshot, encoded as JSON. +// state_block_nids TEXT NOT NULL DEFAULT '[]' +// ); + type StateSnapshotCosmos struct { - StateSnapshotNID int64 `json:"state_snapshot_nid"` - RoomNID int64 `json:"room_nid"` - StateBlockNIDs []int64 `json:"state_block_nids"` + StateSnapshotNID int64 `json:"state_snapshot_nid"` + StateSnapshotHash []byte `json:"state_snapshot_hash"` + RoomNID int64 `json:"room_nid"` + StateBlockNIDs []int64 `json:"state_block_nids"` } type StateSnapshotCosmosData struct { @@ -51,8 +66,10 @@ type StateSnapshotCosmosData struct { } // const insertStateSQL = ` -// INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) -// VALUES ($1, $2);` +// INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids) +// VALUES ($1, $2, $3) +// ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2 +// RETURNING state_snapshot_nid // Bulk state data NID lookup. // Sorting by state_snapshot_nid means we can use binary search over the result @@ -101,20 +118,32 @@ func NewCosmosDBStateSnapshotTable(db *Database) (tables.StateSnapshot, error) { } func (s *stateSnapshotStatements) InsertState( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { - // INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) - // VALUES ($1, $2);` + // INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids) + // VALUES ($1, $2, $3) + // ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2 + // RETURNING state_snapshot_nid stateSnapshotNIDSeq, seqErr := GetNextStateSnapshotNID(s, ctx) if seqErr != nil { return 0, seqErr } + if stateBlockNIDs == nil { + stateBlockNIDs = []types.StateBlockNID{} // zero slice to not store 'null' in the DB + } + stateBlockNIDs = stateBlockNIDs[:util.SortAndUnique(stateBlockNIDs)] + // stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) + // if err != nil { + // return + // } + data := StateSnapshotCosmos{ - RoomNID: int64(roomNID), - StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs), - StateSnapshotNID: int64(stateSnapshotNIDSeq), + RoomNID: int64(roomNID), + StateSnapshotHash: stateBlockNIDs.Hash(), + StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs), + StateSnapshotNID: int64(stateSnapshotNIDSeq), } var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) diff --git a/syncapi/storage/cosmosdb/send_to_device_table.go b/syncapi/storage/cosmosdb/send_to_device_table.go index 933673e39..57fa5914f 100644 --- a/syncapi/storage/cosmosdb/send_to_device_table.go +++ b/syncapi/storage/cosmosdb/send_to_device_table.go @@ -22,7 +22,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal/cosmosdbapi" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/sirupsen/logrus" @@ -80,10 +79,13 @@ const selectSendToDeviceMessagesSQL = "" + "and c.mx_syncapi_send_to_device.id <= @x5 " + "order by c.mx_syncapi_send_to_device.id desc " -const deleteSendToDeviceMessagesSQL = ` - DELETE FROM syncapi_send_to_device - WHERE user_id = $1 AND device_id = $2 AND id < $3 -` +// DELETE FROM syncapi_send_to_device +// WHERE user_id = $1 AND device_id = $2 AND id < $3 +const deleteSendToDeviceMessagesSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_syncapi_send_to_device.user_id = @x2 " + + "and c.mx_syncapi_send_to_device.device_id = @x3 " + + "and c.mx_syncapi_send_to_device.id < @x4 " // "SELECT MAX(id) FROM syncapi_send_to_device" const selectMaxSendToDeviceIDSQL = "" + @@ -93,7 +95,7 @@ type sendToDeviceStatements struct { db *SyncServerDatasource // insertSendToDeviceMessageStmt *sql.Stmt selectSendToDeviceMessagesStmt string - deleteSendToDeviceMessagesStmt *sql.Stmt + deleteSendToDeviceMessagesStmt string selectMaxSendToDeviceIDStmt string tableName string } @@ -140,6 +142,21 @@ func querySendToDeviceNumber(s *sendToDeviceStatements, ctx context.Context, qry return response, nil } +func deleteSendToDevice(s *sendToDeviceStatements, ctx context.Context, dbData SendToDeviceCosmosData) error { + var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk) + var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + dbData.Id, + options) + + if err != nil { + return err + } + return err +} + func NewCosmosDBSendToDeviceTable(db *SyncServerDatasource) (tables.SendToDevice, error) { s := &sendToDeviceStatements{ db: db, @@ -148,9 +165,7 @@ func NewCosmosDBSendToDeviceTable(db *SyncServerDatasource) (tables.SendToDevice // return nil, err // } s.selectSendToDeviceMessagesStmt = selectSendToDeviceMessagesSQL - // if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { - // return nil, err - // } + s.deleteSendToDeviceMessagesStmt = deleteSendToDeviceMessagesSQL s.selectMaxSendToDeviceIDStmt = selectMaxSendToDeviceIDSQL s.tableName = "send_to_device" return s, nil @@ -260,7 +275,29 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos) + // DELETE FROM syncapi_send_to_device + // WHERE user_id = $1 AND device_id = $2 AND id < $3 + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": deviceID, + "@x4": pos, + } + + // _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos) + rows, err := querySendToDevice(s, ctx, s.deleteSendToDeviceMessagesStmt, params) + if err != nil { + return err + } + + for _, item := range rows { + err = deleteSendToDevice(s, ctx, item) + if err != nil { + return err + } + } return } diff --git a/userapi/storage/accounts/cosmosdb/key_backup_table.go b/userapi/storage/accounts/cosmosdb/key_backup_table.go new file mode 100644 index 000000000..b88f1835f --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/key_backup_table.go @@ -0,0 +1,414 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +// const keyBackupTableSchema = ` +// CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( +// user_id TEXT NOT NULL, +// room_id TEXT NOT NULL, +// session_id TEXT NOT NULL, + +// version TEXT NOT NULL, +// first_message_index INTEGER NOT NULL, +// forwarded_count INTEGER NOT NULL, +// is_verified BOOLEAN NOT NULL, +// session_data TEXT NOT NULL +// ); +// CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); +// CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); +// ` + +type KeyBackupCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Tn string `json:"_sid"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + KeyBackup KeyBackupCosmos `json:"mx_userapi_account_e2e_room_keys"` +} + +type KeyBackupCosmos struct { + UserId string `json:"user_id"` + RoomId string `json:"room_id"` + SessionId string `json:"session_id"` + Version string `json:"vesion"` + FirstMessageIndex int `json:"first_message_index"` + ForwardedCount int `json:"forwarded_count"` + IsVerified bool `json:"is_verified"` + SessionData []byte `json:"session_data"` +} + +type KeyBackupCosmosNumber struct { + Number int64 `json:"number"` +} + +// const insertBackupKeySQL = "" + +// "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + +// "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + +// const updateBackupKeySQL = "" + +// "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + +// "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" + +// "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" +const countKeysSQL = "" + + "select count(c._ts) as number from c where c._cn = @x1 " + + "and c.mx_userapi_account_e2e_room_keys.user_id = @x2 " + + "and c.mx_userapi_account_e2e_room_keys.version = @x3 " + +// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + +// "WHERE user_id = $1 AND version = $2" +const selectKeysSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_userapi_account_e2e_room_keys.user_id = @x2 " + + "and c.mx_userapi_account_e2e_room_keys.version = @x3 " + +// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + +// "WHERE user_id = $1 AND version = $2 AND room_id = $3" +const selectKeysByRoomIDSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_userapi_account_e2e_room_keys.user_id = @x2 " + + "and c.mx_userapi_account_e2e_room_keys.version = @x3 " + + "and c.mx_userapi_account_e2e_room_keys.room_id = @x4 " + +// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + +// "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" +const selectKeysByRoomIDAndSessionIDSQL = "" + + "select * from c where c._cn = @x1 " + + "and c.mx_userapi_account_e2e_room_keys.user_id = @x2 " + + "and c.mx_userapi_account_e2e_room_keys.version = @x3 " + + "and c.mx_userapi_account_e2e_room_keys.room_id = @x4 " + + "and c.mx_userapi_account_e2e_room_keys.session_id = @x5 " + +type keyBackupStatements struct { + db *Database + // insertBackupKeyStmt *sql.Stmt + // updateBackupKeyStmt *sql.Stmt + countKeysStmt string + selectKeysStmt string + selectKeysByRoomIDStmt string + selectKeysByRoomIDAndSessionIDStmt string + tableName string + serverName gomatrixserverlib.ServerName +} + +func queryKeyBackup(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + var response []KeyBackupCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func queryKeyBackupNumber(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosNumber, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + var response []KeyBackupCosmosNumber + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*KeyBackupCosmosData, error) { + response := KeyBackupCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, nil + } + + return &response, err +} + +func setKeyBackup(s *keyBackupStatements, ctx context.Context, keyBackup KeyBackupCosmosData) (*KeyBackupCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + keyBackup.Id, + &keyBackup, + optionsReplace) + return &keyBackup, ex +} + +func (s *keyBackupStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { + s.db = db + // s.insertBackupKeyStmt = insertBackupKeySQL + // s.updateBackupKeyStmt = updateBackupKeySQL + s.countKeysStmt = countKeysSQL + s.selectKeysStmt = selectKeysSQL + s.selectKeysByRoomIDStmt = selectKeysByRoomIDSQL + s.selectKeysByRoomIDAndSessionIDStmt = selectKeysByRoomIDAndSessionIDSQL + s.tableName = "account_e2e_room_keys" + s.serverName = server + return +} + +func (s keyBackupStatements) countKeys( + ctx context.Context, userID, version string, +) (count int64, err error) { + // "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2" + // err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": version, + } + rows, err := queryKeyBackupNumber(&s, ctx, s.countKeysStmt, params) + + if err != nil { + return -1, err + } + + if len(rows) == 0 { + return -1, nil + } + // err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + count = rows[0].Number + return +} + +func (s *keyBackupStatements) insertBackupKey( + ctx context.Context, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + // "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " + + // "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + // _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( + // ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), + // ) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); + docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + + data := KeyBackupCosmos{ + UserId: userID, + RoomId: key.RoomID, + SessionId: key.SessionID, + Version: version, + FirstMessageIndex: key.FirstMessageIndex, + ForwardedCount: key.ForwardedCount, + IsVerified: key.IsVerified, + SessionData: key.SessionData, + } + + dbData := &KeyBackupCosmosData{ + Id: cosmosDocId, + Tn: s.db.cosmosConfig.TenantName, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + KeyBackup: data, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + + return +} + +func (s *keyBackupStatements) updateBackupKey( + ctx context.Context, userID, version string, key api.InternalKeyBackupSession, +) (err error) { + // "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " + + // "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8" + // _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( + // ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, + // ) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); + docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + + res, err := getKeyBackup(s, ctx, pk, cosmosDocId) + + if err != nil { + return + } + + if res == nil { + return + } + + // ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version, + res.KeyBackup.FirstMessageIndex = key.FirstMessageIndex + res.KeyBackup.ForwardedCount = key.ForwardedCount + res.KeyBackup.IsVerified = key.IsVerified + res.KeyBackup.SessionData = key.SessionData + + _, err = setKeyBackup(s, ctx, *res) + + return +} + +func (s *keyBackupStatements) selectKeys( + ctx context.Context, userID, version string, +) (map[string]map[string]api.KeyBackupSession, error) { + // "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + // "WHERE user_id = $1 AND version = $2" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": version, + } + rows, err := queryKeyBackup(s, ctx, s.selectKeysStmt, params) + + if err != nil { + return nil, err + } + + if len(rows) == 0 { + return nil, nil + } + + // rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) + return unpackKeys(ctx, &rows) +} + +func (s *keyBackupStatements) selectKeysByRoomID( + ctx context.Context, userID, version, roomID string, +) (map[string]map[string]api.KeyBackupSession, error) { + // "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + // "WHERE user_id = $1 AND version = $2 AND room_id = $3" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": version, + "@x4": roomID, + } + rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDStmt, params) + + if err != nil { + return nil, err + } + + if len(rows) == 0 { + return nil, nil + } + // rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, &rows) +} + +func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( + ctx context.Context, userID, version, roomID, sessionID string, +) (map[string]map[string]api.KeyBackupSession, error) { + // "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + // "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": userID, + "@x3": version, + "@x4": roomID, + "@x5": sessionID, + } + rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDAndSessionIDStmt, params) + + if err != nil { + return nil, err + } + + if len(rows) == 0 { + return nil, nil + } + // rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, &rows) +} + +func unpackKeys(ctx context.Context, rows *[]KeyBackupCosmosData) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) + for _, item := range *rows { + var key api.InternalKeyBackupSession + // room_id, session_id, first_message_index, forwarded_count, is_verified, session_data + var sessionDataStr string + // if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil { + // return nil, err + // } + key.RoomID = item.KeyBackup.RoomId + key.SessionID = item.KeyBackup.SessionId + key.FirstMessageIndex = item.KeyBackup.FirstMessageIndex + key.ForwardedCount = item.KeyBackup.ForwardedCount + key.SessionData = json.RawMessage(sessionDataStr) + roomData := result[key.RoomID] + if roomData == nil { + roomData = make(map[string]api.KeyBackupSession) + } + roomData[key.SessionID] = key.KeyBackupSession + result[key.RoomID] = roomData + } + return result, nil +} diff --git a/userapi/storage/accounts/cosmosdb/key_backup_version_table.go b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go new file mode 100644 index 000000000..fe24bbe23 --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/key_backup_version_table.go @@ -0,0 +1,377 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cosmosdb + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + "github.com/matrix-org/gomatrixserverlib" +) + +// const keyBackupVersionTableSchema = ` +// -- the metadata for each generation of encrypted e2e session backups +// CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions ( +// user_id TEXT NOT NULL, +// -- this means no 2 users will ever have the same version of e2e session backups which strictly +// -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. +// version INTEGER PRIMARY KEY AUTOINCREMENT, +// algorithm TEXT NOT NULL, +// auth_data TEXT NOT NULL, +// etag TEXT NOT NULL, +// deleted INTEGER DEFAULT 0 NOT NULL +// ); + +// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); +// ` + +type KeyBackupVersionCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Tn string `json:"_sid"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + KeyBackupVersion KeyBackupVersionCosmos `json:"mx_userapi_account_e2e_room_keys_versions"` +} + +type KeyBackupVersionCosmos struct { + UserId string `json:"user_id"` + Version int64 `json:"vesion"` + Algorithm string `json:"algorithm"` + AuthData []byte `json:"auth_data"` + Etag string `json:"etag"` + Deleted int `json:"deleted"` +} + +type KeyBackupVersionCosmosNumber struct { + Number int64 `json:"number"` +} + +// const insertKeyBackupSQL = "" + +// "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + +// const updateKeyBackupAuthDataSQL = "" + +// "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + +// const updateKeyBackupETagSQL = "" + +// "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + +// const deleteKeyBackupSQL = "" + +// "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + +// const selectKeyBackupSQL = "" + +// "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + +// "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1" +const selectLatestVersionSQL = "" + + "select max(c.mx_userapi_account_e2e_room_keys_versions.version) as number from c where c._sid = @x1 and c._cn = @x2 " + + "and c.mx_userapi_account_e2e_room_keys_versions.user_id = @x3 " + +type keyBackupVersionStatements struct { + db *Database + // insertKeyBackupStmt *sql.Stmt + // updateKeyBackupAuthDataStmt *sql.Stmt + // deleteKeyBackupStmt *sql.Stmt + // selectKeyBackupStmt *sql.Stmt + selectLatestVersionStmt string + // updateKeyBackupETagStmt *sql.Stmt + tableName string + serverName gomatrixserverlib.ServerName +} + +func queryKeyBackupVersionNumber(s *keyBackupVersionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupVersionCosmosNumber, error) { + var response []KeyBackupVersionCosmosNumber + + var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions() + var query = cosmosdbapi.GetQuery(qry, params) + var _, _ = cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + //WHen there is no data these GroupBy queries return errors + // if err != nil { + // return nil, err + // } + + if len(response) == 0 { + return nil, cosmosdbutil.ErrNoRows + } + + return response, nil +} + +func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*KeyBackupVersionCosmosData, error) { + response := KeyBackupVersionCosmosData{} + err := cosmosdbapi.GetDocumentOrNil( + s.db.connection, + s.db.cosmosConfig, + ctx, + pk, + docId, + &response) + + if response.Id == "" { + return nil, nil + } + + return &response, err +} + +func setKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, keyBackup KeyBackupVersionCosmosData) (*KeyBackupVersionCosmosData, error) { + var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag) + var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + keyBackup.Id, + &keyBackup, + optionsReplace) + return &keyBackup, ex +} + +func (s *keyBackupVersionStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) { + s.db = db + // s.insertKeyBackupStmt = insertKeyBackupSQL + // s.updateKeyBackupAuthDataStmt = updateKeyBackupAuthDataSQL + // s.deleteKeyBackupStmt = deleteKeyBackupSQL + // s.selectKeyBackupStmt = selectKeyBackupSQL + s.selectLatestVersionStmt = selectLatestVersionSQL + // s.updateKeyBackupETagStmt = updateKeyBackupETagSQL + s.tableName = "account_e2e_room_keys_versions" + s.serverName = server + return +} + +func (s *keyBackupVersionStatements) insertKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, etag string, +) (version string, err error) { + // "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version" + var versionInt int64 + // -- this means no 2 users will ever have the same version of e2e session backups which strictly + // -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1. + // version INTEGER PRIMARY KEY AUTOINCREMENT, + versionInt, seqErr := GetNextKeyBackupVersionID(s, ctx) + if seqErr != nil { + return "", seqErr + } + // err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt) + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); + docId := fmt.Sprintf("%s_%d", userID, versionInt) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + + data := KeyBackupVersionCosmos{ + UserId: userID, + Version: versionInt, + Algorithm: algorithm, + AuthData: authData, + Etag: etag, + Deleted: 0, + } + + dbData := &KeyBackupVersionCosmosData{ + Id: cosmosDocId, + Tn: s.db.cosmosConfig.TenantName, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + KeyBackupVersion: data, + } + + var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk) + _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + + return strconv.FormatInt(versionInt, 10), err +} + +func (s *keyBackupVersionStatements) updateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) error { + // "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3" + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); + docId := fmt.Sprintf("%s_%d", userID, versionInt) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + + res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId) + + if err != nil { + return err + } + + if res == nil { + return err + } + + // _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt) + res.KeyBackupVersion.AuthData = authData + + _, err = setKeyBackupVersion(s, ctx, *res) + + return err +} + +func (s *keyBackupVersionStatements) updateKeyBackupETag( + ctx context.Context, userID, version, etag string, +) error { + // "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3" + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return fmt.Errorf("invalid version") + } + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); + docId := fmt.Sprintf("%s_%d", userID, versionInt) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + + res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId) + + if err != nil { + return err + } + + if res == nil { + return err + } + + // _, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt) + res.KeyBackupVersion.Etag = etag + + _, err = setKeyBackupVersion(s, ctx, *res) + + return err +} + +func (s *keyBackupVersionStatements) deleteKeyBackup( + ctx context.Context, userID, version string, +) (bool, error) { + // "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2" + versionInt, err := strconv.ParseInt(version, 10, 64) + if err != nil { + return false, fmt.Errorf("invalid version") + } + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); + docId := fmt.Sprintf("%s_%d", userID, versionInt) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + + res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId) + + if err != nil { + return false, err + } + + if res == nil { + return false, err + } + + // result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt) + res.KeyBackupVersion.Deleted = 1 + + _, err = setKeyBackupVersion(s, ctx, *res) + + if err != nil { + return false, err + } + return true, nil +} + +func (s *keyBackupVersionStatements) selectKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + // "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2" + var versionInt int64 + if version == "" { + // var v *int64 // allows nulls + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": s.db.cosmosConfig.TenantName, + "@x2": dbCollectionName, + "@x3": userID, + } + + // err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + response, err1 := queryKeyBackupVersionNumber(s, ctx, s.selectLatestVersionStmt, params) + + if err1 != nil { + if err == cosmosdbutil.ErrNoRows { + err = nil + } + } + // if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil { + // return + // } + if response == nil || len(response) == 0 { + err = cosmosdbutil.ErrNoRows + versionInt = 0 + return + } + versionInt = response[0].Number + } else { + if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil { + return + } + } + versionResult = strconv.FormatInt(versionInt, 10) + if err != nil { + return + } + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version); + docId := fmt.Sprintf("%s_%d", userID, versionInt) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName) + + res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId) + + if err != nil { + return + } + + if res == nil { + return + } + + // var deletedInt int + // var authDataStr string + // err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt) + deleted = res.KeyBackupVersion.Deleted == 1 + authData = res.KeyBackupVersion.AuthData + return +} diff --git a/userapi/storage/accounts/cosmosdb/key_backup_version_table_id_seq.go b/userapi/storage/accounts/cosmosdb/key_backup_version_table_id_seq.go new file mode 100644 index 000000000..5f9060739 --- /dev/null +++ b/userapi/storage/accounts/cosmosdb/key_backup_version_table_id_seq.go @@ -0,0 +1,12 @@ +package cosmosdb + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" +) + +func GetNextKeyBackupVersionID(s *keyBackupVersionStatements, ctx context.Context) (int64, error) { + const docId = "id_seq" + return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1) +} diff --git a/userapi/storage/accounts/cosmosdb/storage.go b/userapi/storage/accounts/cosmosdb/storage.go index c3124fd98..9a0534b85 100644 --- a/userapi/storage/accounts/cosmosdb/storage.go +++ b/userapi/storage/accounts/cosmosdb/storage.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "strconv" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -45,6 +46,8 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements + keyBackupVersions keyBackupVersionStatements + keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -105,6 +108,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(d, serverName); err != nil { return nil, err } + if err = d.keyBackupVersions.prepare(d, serverName); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(d, serverName); err != nil { + return nil, err + } return d, nil } @@ -419,3 +428,150 @@ func (d *Database) GetOpenIDTokenAttributes( ) (*api.OpenIDTokenAttributes, error) { return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) } + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + // err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + version, err = d.keyBackupVersions.insertKeyBackup(ctx, userID, algorithm, authData, "") + return version, err + // }) + // return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + // err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.keyBackupVersions.updateKeyBackupAuthData(ctx, userID, version, authData) + // }) + // return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + // err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, userID, version) + return + // }) + // return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + // err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, userID, version) + return + // }) + // return +} + +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + // err = d.writer.Do(d, nil, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, userID, version, filterRoomID, filterSessionID) + return + } + if filterRoomID != "" { + result, err = d.keyBackups.selectKeysByRoomID(ctx, userID, version, filterRoomID) + return + } + result, err = d.keyBackups.selectKeys(ctx, userID, version) + return + // }) + // return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + // err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + count, err = d.keyBackups.countKeys(ctx, userID, version) + if err != nil { + return + } + return + // }) + // return +} + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + // err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, userID, version) + if err != nil { + return + } + if deleted { + err = fmt.Errorf("backup was deleted") + return + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.keyBackups.selectKeys(ctx, userID, version) + if err != nil { + return + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { + err = d.keyBackups.updateBackupKey(ctx, userID, version, newKey) + changed = true + if err != nil { + err = fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) + return + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.keyBackups.insertBackupKey(ctx, userID, version, newKey) + changed = true + if err != nil { + err = fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) + return + } + } + + count, err = d.keyBackups.countKeys(ctx, userID, version) + if err != nil { + return + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err1 := strconv.ParseInt(oldETag, 10, 64) + if err1 != nil { + err = fmt.Errorf("failed to parse old etag: %s", err1) + return + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + err = d.keyBackupVersions.updateKeyBackupETag(ctx, userID, version, newETag) + } else { + etag = oldETag + } + + return + // }) + // return +}