mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-29 01:33:10 -06:00
Upgrade Dendrite 0.5.0 support for CosmosDB (#15)
* - Add CosmosDB back - Add missing methods to blacklist_table.go - Add missing methods to device_keys_table.go - Add missing methods to events_table.go - Add missing methods to membership_table.go - Update state_block_table.go (due to reafctor SQL) - Update state_snapshot_table.go (due to reafctor SQL) - Add new key_backup_table.go - Add new key_backup_version_table.go - Code compiles but has runtime errors * Message sending + receiving working Rooms and DMs working - Add CrossSigningKeys table - Add CrossSigningSigs table - Refactor DeviceKeys yable - Fix OneTimeKeys - Update the KeyServer storage.go to use a PartitionStorer instead of a specific SQL PartitionOffsetStatements - Fix small issues from the previous commit - Implement DeleteSendToDeviceMessages Co-authored-by: alexf@example.com <alexf@example.com>
This commit is contained in:
parent
84c8cb052b
commit
fd7f25479b
|
|
@ -57,12 +57,17 @@ type BlacklistCosmosData struct {
|
||||||
// const deleteBlacklistSQL = "" +
|
// const deleteBlacklistSQL = "" +
|
||||||
// "DELETE FROM federationsender_blacklist WHERE server_name = $1"
|
// "DELETE FROM federationsender_blacklist WHERE server_name = $1"
|
||||||
|
|
||||||
|
// "DELETE FROM federationsender_blacklist"
|
||||||
|
const deleteAllBlacklistSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 "
|
||||||
|
|
||||||
type blacklistStatements struct {
|
type blacklistStatements struct {
|
||||||
db *Database
|
db *Database
|
||||||
// insertBlacklistStmt *sql.Stmt
|
// insertBlacklistStmt *sql.Stmt
|
||||||
// selectBlacklistStmt *sql.Stmt
|
// selectBlacklistStmt *sql.Stmt
|
||||||
// deleteBlacklistStmt *sql.Stmt
|
// deleteBlacklistStmt *sql.Stmt
|
||||||
tableName string
|
deleteAllBlacklistStmt string
|
||||||
|
tableName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*BlacklistCosmosData, error) {
|
func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId string) (*BlacklistCosmosData, error) {
|
||||||
|
|
@ -82,6 +87,27 @@ func getBlacklist(s *blacklistStatements, ctx context.Context, pk string, docId
|
||||||
return &response, err
|
return &response, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func queryBlacklist(s *blacklistStatements, ctx context.Context, qry string, params map[string]interface{}) ([]BlacklistCosmosData, error) {
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
var response []BlacklistCosmosData
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData BlacklistCosmosData) error {
|
func deleteBlacklist(s *blacklistStatements, ctx context.Context, dbData BlacklistCosmosData) error {
|
||||||
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
|
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
|
||||||
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||||
|
|
@ -101,6 +127,7 @@ func NewCosmosDBBlacklistTable(db *Database) (s *blacklistStatements, err error)
|
||||||
s = &blacklistStatements{
|
s = &blacklistStatements{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
s.deleteAllBlacklistStmt = deleteAllBlacklistSQL
|
||||||
s.tableName = "blacklists"
|
s.tableName = "blacklists"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -189,8 +216,36 @@ func (s *blacklistStatements) DeleteBlacklist(
|
||||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
// _, err := stmt.ExecContext(ctx, serverName)
|
// _, err := stmt.ExecContext(ctx, serverName)
|
||||||
res, err := getBlacklist(s, ctx, pk, cosmosDocId)
|
res, err := getBlacklist(s, ctx, pk, cosmosDocId)
|
||||||
if(res != nil) {
|
if res != nil {
|
||||||
_ = deleteBlacklist(s, ctx, *res)
|
_ = deleteBlacklist(s, ctx, *res)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *blacklistStatements) DeleteAllBlacklist(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
) error {
|
||||||
|
// "DELETE FROM federationsender_blacklist"
|
||||||
|
|
||||||
|
// stmt := sqlutil.TxStmt(txn, s.deleteAllBlacklistStmt)
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
}
|
||||||
|
|
||||||
|
// rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
|
||||||
|
rows, err := queryBlacklist(s, ctx, s.deleteAllBlacklistStmt, params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// _, err := stmt.ExecContext(ctx)
|
||||||
|
for _, item := range rows {
|
||||||
|
// stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||||
|
err = deleteBlacklist(s, ctx, item)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
||||||
1
go.mod
1
go.mod
|
|
@ -51,6 +51,7 @@ require (
|
||||||
github.com/tidwall/sjson v1.1.7
|
github.com/tidwall/sjson v1.1.7
|
||||||
github.com/uber/jaeger-client-go v2.29.1+incompatible
|
github.com/uber/jaeger-client-go v2.29.1+incompatible
|
||||||
github.com/uber/jaeger-lib v2.4.1+incompatible
|
github.com/uber/jaeger-lib v2.4.1+incompatible
|
||||||
|
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d // indirect
|
||||||
github.com/yggdrasil-network/yggdrasil-go v0.4.1-0.20210715083903-52309d094c00
|
github.com/yggdrasil-network/yggdrasil-go v0.4.1-0.20210715083903-52309d094c00
|
||||||
go.uber.org/atomic v1.9.0
|
go.uber.org/atomic v1.9.0
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
|
||||||
|
|
|
||||||
4
go.sum
4
go.sum
|
|
@ -100,6 +100,7 @@ github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAU
|
||||||
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
|
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
|
||||||
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||||
github.com/albertorestifo/dijkstra v0.0.0-20160910063646-aba76f725f72/go.mod h1:o+JdB7VetTHjLhU0N57x18B9voDBQe0paApdEAEoEfw=
|
github.com/albertorestifo/dijkstra v0.0.0-20160910063646-aba76f725f72/go.mod h1:o+JdB7VetTHjLhU0N57x18B9voDBQe0paApdEAEoEfw=
|
||||||
|
github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ=
|
||||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||||
|
|
@ -432,6 +433,7 @@ github.com/godbus/dbus v0.0.0-20180201030542-885f9cc04c9c/go.mod h1:/YcGZj5zSblf
|
||||||
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
|
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
|
||||||
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
|
github.com/gofrs/uuid v3.1.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||||
github.com/gogo/googleapis v1.2.0/go.mod h1:Njal3psf3qN6dwBtQfUmBZh2ybovJ0tlu3o/AC7HYjU=
|
github.com/gogo/googleapis v1.2.0/go.mod h1:Njal3psf3qN6dwBtQfUmBZh2ybovJ0tlu3o/AC7HYjU=
|
||||||
github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c=
|
github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c=
|
||||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
|
|
@ -1405,6 +1407,8 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU
|
||||||
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
|
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
|
||||||
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
|
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
|
||||||
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
||||||
|
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d h1:MZRYOouO0snrQyBAf4Wljc3qqaispjzMOhFRQgWfKMo=
|
||||||
|
github.com/vippsas/go-cosmosdb v0.0.0-20200428065936-29dab535353d/go.mod h1:ldPlejlc7ZyiP0QQWGwL9CoZLvEjhD9yzpz0ct7+sXo=
|
||||||
github.com/vishvananda/netlink v0.0.0-20181108222139-023a6dafdcdf/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
github.com/vishvananda/netlink v0.0.0-20181108222139-023a6dafdcdf/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
|
||||||
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
|
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
|
||||||
github.com/vishvananda/netlink v1.1.1-0.20201029203352-d40f9887b852/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
|
github.com/vishvananda/netlink v1.1.1-0.20201029203352-d40f9887b852/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
|
||||||
|
|
|
||||||
179
keyserver/storage/cosmosdb/cross_signing_keys_table.go
Normal file
179
keyserver/storage/cosmosdb/cross_signing_keys_table.go
Normal file
|
|
@ -0,0 +1,179 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package cosmosdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// var crossSigningKeysSchema = `
|
||||||
|
// CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
|
||||||
|
// user_id TEXT NOT NULL,
|
||||||
|
// key_type INTEGER NOT NULL,
|
||||||
|
// key_data TEXT NOT NULL,
|
||||||
|
// PRIMARY KEY (user_id, key_type)
|
||||||
|
// );
|
||||||
|
// `
|
||||||
|
|
||||||
|
type CrossSigningKeysCosmos struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
KeyType int64 `json:"key_type"`
|
||||||
|
KeyData []byte `json:"key_data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CrossSigningKeysCosmosData struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Tn string `json:"_sid"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
CrossSigningKeys CrossSigningKeysCosmos `json:"mx_keyserver_cross_signing_keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
|
||||||
|
// " WHERE user_id = $1"
|
||||||
|
const selectCrossSigningKeysForUserSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_cross_signing_keys.user_id = @x2 "
|
||||||
|
|
||||||
|
// const upsertCrossSigningKeysForUserSQL = "" +
|
||||||
|
// "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
|
||||||
|
// " VALUES($1, $2, $3)"
|
||||||
|
|
||||||
|
type crossSigningKeysStatements struct {
|
||||||
|
db *Database
|
||||||
|
selectCrossSigningKeysForUserStmt string
|
||||||
|
// upsertCrossSigningKeysForUserStmt *sql.Stmt
|
||||||
|
tableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryCrossSigningKeys(s *crossSigningKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CrossSigningKeysCosmosData, error) {
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
var response []CrossSigningKeysCosmosData
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSqliteCrossSigningKeysTable(db *Database) (tables.CrossSigningKeys, error) {
|
||||||
|
s := &crossSigningKeysStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
s.selectCrossSigningKeysForUserStmt = selectCrossSigningKeysForUserSQL
|
||||||
|
// s.upsertCrossSigningKeysForUserStmt = upsertCrossSigningKeysForUserSQL
|
||||||
|
s.tableName = "cross_signing_keys"
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID string,
|
||||||
|
) (r types.CrossSigningKeyMap, err error) {
|
||||||
|
// "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
|
||||||
|
// " WHERE user_id = $1"
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
}
|
||||||
|
rows, err := queryCrossSigningKeys(s, ctx, s.selectCrossSigningKeysForUserStmt, params)
|
||||||
|
// rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
|
||||||
|
r = types.CrossSigningKeyMap{}
|
||||||
|
// for rows.Next() {
|
||||||
|
for _, item := range rows {
|
||||||
|
var keyTypeInt int16
|
||||||
|
var keyData gomatrixserverlib.Base64Bytes
|
||||||
|
// if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
keyData = item.CrossSigningKeys.KeyData
|
||||||
|
keyTypeInt = int16(item.CrossSigningKeys.KeyType)
|
||||||
|
keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
|
||||||
|
}
|
||||||
|
r[keyType] = keyData
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
|
||||||
|
) error {
|
||||||
|
// "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
|
||||||
|
// " VALUES($1, $2, $3)"
|
||||||
|
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unknown key purpose %q", keyType)
|
||||||
|
}
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
// PRIMARY KEY (user_id, key_type)
|
||||||
|
docId := fmt.Sprintf("%s_%s", userID, keyType)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
|
|
||||||
|
data := CrossSigningKeysCosmos{
|
||||||
|
UserID: userID,
|
||||||
|
KeyType: int64(keyTypeInt),
|
||||||
|
KeyData: keyData,
|
||||||
|
}
|
||||||
|
|
||||||
|
dbData := CrossSigningKeysCosmosData{
|
||||||
|
Id: cosmosDocId,
|
||||||
|
Tn: s.db.cosmosConfig.TenantName,
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: pk,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
CrossSigningKeys: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
// if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
|
||||||
|
// return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
|
||||||
|
// }
|
||||||
|
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
241
keyserver/storage/cosmosdb/cross_signing_sigs_table.go
Normal file
241
keyserver/storage/cosmosdb/cross_signing_sigs_table.go
Normal file
|
|
@ -0,0 +1,241 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package cosmosdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// var crossSigningSigsSchema = `
|
||||||
|
// CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
|
||||||
|
// origin_user_id TEXT NOT NULL,
|
||||||
|
// origin_key_id TEXT NOT NULL,
|
||||||
|
// target_user_id TEXT NOT NULL,
|
||||||
|
// target_key_id TEXT NOT NULL,
|
||||||
|
// signature TEXT NOT NULL,
|
||||||
|
// PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
|
||||||
|
// );
|
||||||
|
// `
|
||||||
|
|
||||||
|
type CrossSigningSigsCosmos struct {
|
||||||
|
OriginUserId string `json:"origin_user_id"`
|
||||||
|
OriginKeyId string `json:"origin_key_id"`
|
||||||
|
TargetUserId string `json:"target_user_id"`
|
||||||
|
TargetKeyId string `json:"target_key_id"`
|
||||||
|
Signature []byte `json:"signature"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CrossSigningSigsCosmosData struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Tn string `json:"_sid"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
CrossSigningSigs CrossSigningSigsCosmos `json:"mx_keyserver_cross_signing_sigs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
|
||||||
|
// " WHERE target_user_id = $1 AND target_key_id = $2"
|
||||||
|
const selectCrossSigningSigsForTargetSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_cross_signing_sigs.target_user_id = @x2 " +
|
||||||
|
"and c.mx_keyserver_cross_signing_sigs.target_key_id = @x3 "
|
||||||
|
|
||||||
|
// const upsertCrossSigningSigsForTargetSQL = "" +
|
||||||
|
// "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
|
||||||
|
// " VALUES($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
|
// "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
|
||||||
|
const deleteCrossSigningSigsForTargetSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_cross_signing_sigs.target_user_id = @x2 " +
|
||||||
|
"and c.mx_keyserver_cross_signing_sigs.target_key_id = @x3 "
|
||||||
|
|
||||||
|
type crossSigningSigsStatements struct {
|
||||||
|
db *Database
|
||||||
|
selectCrossSigningSigsForTargetStmt string
|
||||||
|
// upsertCrossSigningSigsForTargetStmt *sql.Stmt
|
||||||
|
deleteCrossSigningSigsForTargetStmt string
|
||||||
|
tableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, qry string, params map[string]interface{}) ([]CrossSigningSigsCosmosData, error) {
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
var response []CrossSigningSigsCosmosData
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteCrossSigningSigs(s *crossSigningSigsStatements, ctx context.Context, dbData CrossSigningSigsCosmosData) error {
|
||||||
|
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData.Id,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSqliteCrossSigningSigsTable(db *Database) (tables.CrossSigningSigs, error) {
|
||||||
|
s := &crossSigningSigsStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
// _, err := db.Exec(crossSigningSigsSchema)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
s.selectCrossSigningSigsForTargetStmt = selectCrossSigningSigsForTargetSQL
|
||||||
|
// s.upsertCrossSigningSigsForTargetStmt = upsertCrossSigningSigsForTargetSQL
|
||||||
|
s.deleteCrossSigningSigsForTargetStmt = deleteCrossSigningSigsForTargetSQL
|
||||||
|
s.tableName = "cross_signing_sigs"
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
|
||||||
|
ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||||
|
) (r types.CrossSigningSigMap, err error) {
|
||||||
|
// "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
|
||||||
|
// " WHERE target_user_id = $1 AND target_key_id = $2"
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": targetUserID,
|
||||||
|
"@x3": targetKeyID,
|
||||||
|
}
|
||||||
|
rows, err := queryCrossSigningSigs(s, ctx, s.selectCrossSigningSigsForTargetStmt, params)
|
||||||
|
// rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
|
||||||
|
r = types.CrossSigningSigMap{}
|
||||||
|
// for rows.Next() {
|
||||||
|
for _, item := range rows {
|
||||||
|
var userID string
|
||||||
|
var keyID gomatrixserverlib.KeyID
|
||||||
|
var signature gomatrixserverlib.Base64Bytes
|
||||||
|
// if err := rows.Scan(&userID, &keyID, &signature); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
userID = item.CrossSigningSigs.OriginUserId
|
||||||
|
keyID = gomatrixserverlib.KeyID(item.CrossSigningSigs.OriginKeyId)
|
||||||
|
signature = gomatrixserverlib.Base64Bytes(item.CrossSigningSigs.Signature)
|
||||||
|
if _, ok := r[userID]; !ok {
|
||||||
|
r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||||
|
}
|
||||||
|
r[userID][keyID] = signature
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
originUserID string, originKeyID gomatrixserverlib.KeyID,
|
||||||
|
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||||
|
signature gomatrixserverlib.Base64Bytes,
|
||||||
|
) error {
|
||||||
|
// "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
|
||||||
|
// " VALUES($1, $2, $3, $4, $5)"
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
// PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
|
||||||
|
docId := fmt.Sprintf("%s_%s_%s", originUserID, targetUserID, targetKeyID)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
|
|
||||||
|
data := CrossSigningSigsCosmos{
|
||||||
|
TargetUserId: targetUserID,
|
||||||
|
TargetKeyId: string(targetKeyID),
|
||||||
|
OriginUserId: originUserID,
|
||||||
|
OriginKeyId: string(originKeyID),
|
||||||
|
Signature: signature,
|
||||||
|
}
|
||||||
|
|
||||||
|
dbData := CrossSigningSigsCosmosData{
|
||||||
|
Id: cosmosDocId,
|
||||||
|
Tn: s.db.cosmosConfig.TenantName,
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: pk,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
CrossSigningSigs: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
// if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
|
||||||
|
// return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
|
||||||
|
// }
|
||||||
|
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||||
|
var _, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData,
|
||||||
|
options)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||||
|
) error {
|
||||||
|
// "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": targetUserID,
|
||||||
|
"@x3": targetKeyID,
|
||||||
|
}
|
||||||
|
rows, err := queryCrossSigningSigs(s, ctx, s.selectCrossSigningSigsForTargetStmt, params)
|
||||||
|
// if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
|
||||||
|
// return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
|
||||||
|
// }
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range rows {
|
||||||
|
errItem := deleteCrossSigningSigs(s, ctx, item)
|
||||||
|
if errItem != nil {
|
||||||
|
return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -94,7 +94,13 @@ const selectAllDeviceKeysSQL = "" +
|
||||||
"select * from c where c._cn = @x1 " +
|
"select * from c where c._cn = @x1 " +
|
||||||
"and c.mx_keyserver_device_key.user_id = @x2 "
|
"and c.mx_keyserver_device_key.user_id = @x2 "
|
||||||
|
|
||||||
// const deleteAllDeviceKeysSQL = "" +
|
// "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
const deleteDeviceKeysSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_keyserver_device_key.user_id = @x2 " +
|
||||||
|
"and c.mx_keyserver_device_key.device_id = @x3 "
|
||||||
|
|
||||||
|
// const deleteAllDeviceKeysSQL = "" +
|
||||||
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
func queryDeviceKey(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosData, error) {
|
func queryDeviceKey(s *deviceKeysStatements, ctx context.Context, qry string, params map[string]interface{}) ([]DeviceKeyCosmosData, error) {
|
||||||
|
|
@ -192,6 +198,7 @@ type deviceKeysStatements struct {
|
||||||
// selectDeviceKeysStmt *sql.Stmt
|
// selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt string
|
selectBatchDeviceKeysStmt string
|
||||||
selectMaxStreamForUserStmt string
|
selectMaxStreamForUserStmt string
|
||||||
|
deleteDeviceKeysStmt string
|
||||||
// deleteAllDeviceKeysStmt *sql.Stmt
|
// deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
tableName string
|
tableName string
|
||||||
}
|
}
|
||||||
|
|
@ -202,6 +209,7 @@ func NewCosmosDBDeviceKeysTable(db *Database) (tables.DeviceKeys, error) {
|
||||||
}
|
}
|
||||||
s.selectBatchDeviceKeysStmt = selectBatchDeviceKeysSQL
|
s.selectBatchDeviceKeysStmt = selectBatchDeviceKeysSQL
|
||||||
s.selectMaxStreamForUserStmt = selectMaxStreamForUserSQL
|
s.selectMaxStreamForUserStmt = selectMaxStreamForUserSQL
|
||||||
|
s.deleteDeviceKeysStmt = deleteDeviceKeysSQL
|
||||||
s.tableName = "device_keys"
|
s.tableName = "device_keys"
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
@ -221,6 +229,30 @@ func deleteDeviceKeyCore(s *deviceKeysStatements, ctx context.Context, dbData De
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||||
|
// "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
// _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
"@x3": deviceID,
|
||||||
|
}
|
||||||
|
response, err := queryDeviceKey(s, ctx, selectAllDeviceKeysSQL, params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range response {
|
||||||
|
errItem := deleteDeviceKeyCore(s, ctx, item)
|
||||||
|
if errItem != nil {
|
||||||
|
return errItem
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
|
||||||
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
// "DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
@ -268,20 +300,25 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
|
|
||||||
var result []api.DeviceMessage
|
var result []api.DeviceMessage
|
||||||
for _, item := range response {
|
for _, item := range response {
|
||||||
var dk api.DeviceMessage
|
dk := api.DeviceMessage{
|
||||||
dk.UserID = userID
|
Type: api.TypeDeviceKeyUpdate,
|
||||||
|
DeviceKeys: &api.DeviceKeys{},
|
||||||
|
}
|
||||||
|
dk.Type = api.TypeDeviceKeyUpdate
|
||||||
|
dk.UserID = item.DeviceKey.UserID
|
||||||
// var keyJSON string
|
// var keyJSON string
|
||||||
var streamID int
|
var streamID int
|
||||||
// var displayName sql.NullString
|
// var displayName sql.NullString
|
||||||
// if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
// if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||||
// return nil, err
|
// return nil, err
|
||||||
// }
|
// }
|
||||||
streamID = item.DeviceKey.StreamID
|
dk.DeviceID = item.DeviceKey.DeviceID
|
||||||
|
|
||||||
dk.KeyJSON = item.DeviceKey.KeyJSON
|
dk.KeyJSON = item.DeviceKey.KeyJSON
|
||||||
|
streamID = item.DeviceKey.StreamID
|
||||||
|
displayName := item.DeviceKey.DisplayName
|
||||||
dk.StreamID = streamID
|
dk.StreamID = streamID
|
||||||
if len(item.DeviceKey.DisplayName) > 0 {
|
if len(displayName) > 0 {
|
||||||
dk.DisplayName = item.DeviceKey.DisplayName
|
dk.DisplayName = displayName
|
||||||
}
|
}
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
|
|
|
||||||
|
|
@ -341,7 +341,7 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||||
) (map[string]json.RawMessage, error) {
|
) (map[string]json.RawMessage, error) {
|
||||||
var keyID string
|
var keyID string
|
||||||
var keyJSON string
|
// var keyJSON string
|
||||||
|
|
||||||
// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
// "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||||
|
|
||||||
|
|
@ -360,14 +360,16 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
keyID = response[0].OneTimeKey.KeyID
|
||||||
|
keyJSONBytes := response[0].OneTimeKey.KeyJSON
|
||||||
err = deleteOneTimeKeyCore(s, ctx, response[0])
|
err = deleteOneTimeKeyCore(s, ctx, response[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if keyJSON == "" {
|
if keyID == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return map[string]json.RawMessage{
|
return map[string]json.RawMessage{
|
||||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
algorithm + ":" + keyID: keyJSONBytes,
|
||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ import (
|
||||||
|
|
||||||
// A Database is used to store room events and stream offsets.
|
// A Database is used to store room events and stream offsets.
|
||||||
type Database struct {
|
type Database struct {
|
||||||
shared.Database
|
database cosmosdbutil.Database
|
||||||
connection cosmosdbapi.CosmosConnection
|
connection cosmosdbapi.CosmosConnection
|
||||||
databaseName string
|
databaseName string
|
||||||
cosmosConfig cosmosdbapi.CosmosConfig
|
cosmosConfig cosmosdbapi.CosmosConfig
|
||||||
|
|
@ -33,38 +33,62 @@ type Database struct {
|
||||||
|
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
|
||||||
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
|
conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
|
||||||
config := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
|
configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
|
||||||
d := &Database{
|
result := &Database{
|
||||||
databaseName: "keyserver",
|
databaseName: "keyserver",
|
||||||
connection: conn,
|
connection: conn,
|
||||||
cosmosConfig: config,
|
cosmosConfig: configCosmos,
|
||||||
|
}
|
||||||
|
|
||||||
|
result.database = cosmosdbutil.Database{
|
||||||
|
Connection: conn,
|
||||||
|
CosmosConfig: configCosmos,
|
||||||
|
DatabaseName: result.databaseName,
|
||||||
}
|
}
|
||||||
|
|
||||||
// db, err := sqlutil.Open(dbProperties)
|
// db, err := sqlutil.Open(dbProperties)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// return nil, err
|
// return nil, err
|
||||||
// }
|
// }
|
||||||
otk, err := NewCosmosDBOneTimeKeysTable(d)
|
otk, err := NewCosmosDBOneTimeKeysTable(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk, err := NewCosmosDBDeviceKeysTable(d)
|
dk, err := NewCosmosDBDeviceKeysTable(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
kc, err := NewCosmosDBKeyChangesTable(d)
|
kc, err := NewCosmosDBKeyChangesTable(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sdl, err := NewCosmosDBStaleDeviceListsTable(d)
|
sdl, err := NewCosmosDBStaleDeviceListsTable(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
csk, err := NewSqliteCrossSigningKeysTable(result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
css, err := NewSqliteCrossSigningSigsTable(result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := cosmosdbutil.NewExclusiveWriterFake()
|
||||||
|
storer := cosmosdbutil.PartitionOffsetStatements{}
|
||||||
|
if err = storer.Prepare(&result.database, writer, "keyserver"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
Writer: cosmosdbutil.NewExclusiveWriterFake(),
|
Writer: writer,
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
StaleDeviceListsTable: sdl,
|
StaleDeviceListsTable: sdl,
|
||||||
|
CrossSigningKeysTable: csk,
|
||||||
|
CrossSigningSigsTable: css,
|
||||||
|
PartitionStorer: &storer,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -61,8 +61,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
CrossSigningKeysTable: csk,
|
CrossSigningKeysTable: csk,
|
||||||
CrossSigningSigsTable: css,
|
CrossSigningSigsTable: css,
|
||||||
}
|
}
|
||||||
if err = d.PartitionOffsetStatements.Prepare(db, d.Writer, "keyserver"); err != nil {
|
storer := sqlutil.PartitionOffsetStatements{}
|
||||||
|
if err = storer.Prepare(db, d.Writer, "keyserver"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.PartitionStorer = &storer
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
|
@ -28,6 +29,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
|
internal.PartitionStorer
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
Writer sqlutil.Writer
|
Writer sqlutil.Writer
|
||||||
OneTimeKeysTable tables.OneTimeKeys
|
OneTimeKeysTable tables.OneTimeKeys
|
||||||
|
|
@ -36,7 +38,6 @@ type Database struct {
|
||||||
StaleDeviceListsTable tables.StaleDeviceLists
|
StaleDeviceListsTable tables.StaleDeviceLists
|
||||||
CrossSigningKeysTable tables.CrossSigningKeys
|
CrossSigningKeysTable tables.CrossSigningKeys
|
||||||
CrossSigningSigsTable tables.CrossSigningSigs
|
CrossSigningSigsTable tables.CrossSigningSigs
|
||||||
sqlutil.PartitionOffsetStatements
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
|
|
||||||
|
|
@ -59,8 +59,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
CrossSigningKeysTable: csk,
|
CrossSigningKeysTable: csk,
|
||||||
CrossSigningSigsTable: css,
|
CrossSigningSigsTable: css,
|
||||||
}
|
}
|
||||||
if err = d.PartitionOffsetStatements.Prepare(db, d.Writer, "keyserver"); err != nil {
|
storer := sqlutil.PartitionOffsetStatements{}
|
||||||
|
if err = storer.Prepare(db, d.Writer, "keyserver"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
d.PartitionStorer = &storer
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
|
|
@ -98,6 +99,14 @@ const bulkSelectStateEventByIDSQL = "" +
|
||||||
// ", c.mx_roomserver_event.event_state_key_nid " +
|
// ", c.mx_roomserver_event.event_state_key_nid " +
|
||||||
"asc"
|
"asc"
|
||||||
|
|
||||||
|
// "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
||||||
|
// " WHERE event_nid IN ($1)"
|
||||||
|
// // Rest of query is built by BulkSelectStateEventByNID
|
||||||
|
const bulkSelectStateEventByNIDSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_event.event_nid) "
|
||||||
|
// Rest of query is built by BulkSelectStateEventByNID
|
||||||
|
|
||||||
// "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
|
// "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
|
||||||
// " WHERE event_id IN ($1)"
|
// " WHERE event_id IN ($1)"
|
||||||
const bulkSelectStateAtEventByIDSQL = "" +
|
const bulkSelectStateAtEventByIDSQL = "" +
|
||||||
|
|
@ -491,6 +500,83 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
return results, err
|
return results, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
|
func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
// "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
||||||
|
// " WHERE event_nid IN ($1)"
|
||||||
|
// // Rest of query is built by BulkSelectStateEventByNID
|
||||||
|
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||||
|
sort.Sort(tuples)
|
||||||
|
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||||
|
// params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": eventNIDs,
|
||||||
|
}
|
||||||
|
// selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
||||||
|
selectOrig := bulkSelectStateEventByNIDSQL
|
||||||
|
// for _, v := range eventNIDs {
|
||||||
|
// params = append(params, v)
|
||||||
|
// }
|
||||||
|
if len(eventTypeNIDArray) > 0 {
|
||||||
|
// selectOrig += " AND event_type_nid IN " + sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(params))
|
||||||
|
selectOrig += " and ARRAY_CONTAINS(@x3, c.mx_roomserver_event.event_type_nid) "
|
||||||
|
// for _, v := range eventTypeNIDArray {
|
||||||
|
// params = append(params, v)
|
||||||
|
// }
|
||||||
|
params["@x3"] = eventTypeNIDArray
|
||||||
|
}
|
||||||
|
if len(eventStateKeyNIDArray) > 0 {
|
||||||
|
// selectOrig += " AND event_state_key_nid IN " + sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(params))
|
||||||
|
selectOrig += " and ARRAY_CONTAINS(@x4, c.mx_roomserver_event.event_state_key_nid) "
|
||||||
|
// for _, v := range eventStateKeyNIDArray {
|
||||||
|
// params = append(params, v)
|
||||||
|
// }
|
||||||
|
params["@x4"] = eventStateKeyNIDArray
|
||||||
|
}
|
||||||
|
// selectOrig += " ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
//No Composite Index so just order by the 1st one
|
||||||
|
selectOrig += " order by c.mx_roomserver_event.event_type_nid asc "
|
||||||
|
// selectStmt, err := s.db.Prepare(selectOrig)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, fmt.Errorf("s.db.Prepare: %w", err)
|
||||||
|
// }
|
||||||
|
// rows, err := selectStmt.QueryContext(ctx, params...)
|
||||||
|
rows, err := queryEvent(s, ctx, selectOrig, params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
|
||||||
|
}
|
||||||
|
// defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed")
|
||||||
|
// We know that we will only get as many results as event IDs
|
||||||
|
// because of the unique constraint on event IDs.
|
||||||
|
// So we can allocate an array of the correct size now.
|
||||||
|
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
||||||
|
results := make([]types.StateEntry, len(eventNIDs))
|
||||||
|
i := 0
|
||||||
|
// for ; rows.Next(); i++ {
|
||||||
|
for _, item := range rows {
|
||||||
|
result := &results[i]
|
||||||
|
result.EventTypeNID = types.EventTypeNID(item.Event.EventTypeNID)
|
||||||
|
result.EventStateKeyNID = types.EventStateKeyNID(item.Event.EventStateKeyNID)
|
||||||
|
result.EventNID = types.EventNID(item.Event.EventNID)
|
||||||
|
// if err = rows.Scan(
|
||||||
|
// &result.EventTypeNID,
|
||||||
|
// &result.EventStateKeyNID,
|
||||||
|
// &result.EventNID,
|
||||||
|
// ); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return results[:i], err
|
||||||
|
}
|
||||||
|
|
||||||
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
|
@ -160,6 +161,32 @@ var selectKnownUsersSQLDistinctRoom = "" +
|
||||||
"and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " " +
|
"and c.mx_roomserver_membership.membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " " +
|
||||||
"and contains(c.mx_roomserver_membership.event_state_key, @x3) "
|
"and contains(c.mx_roomserver_membership.event_state_key, @x3) "
|
||||||
|
|
||||||
|
// selectLocalServerInRoomSQL is an optimised case for checking if we, the local server,
|
||||||
|
// are in the room by using the target_local column of the membership table. Normally when
|
||||||
|
// we want to know if a server is in a room, we have to unmarshal the entire room state which
|
||||||
|
// is expensive. The presence of a single row from this query suggests we're still in the
|
||||||
|
// room, no rows returned suggests we aren't.
|
||||||
|
// "SELECT room_nid FROM roomserver_membership WHERE target_local = 1 AND membership_nid = $1 AND room_nid = $2 LIMIT 1"
|
||||||
|
const selectLocalServerInRoomSQL = "" +
|
||||||
|
"select top 1 * from c where c._cn = @x1 " +
|
||||||
|
" and c.mx_roomserver_membership.target_local = 1" +
|
||||||
|
" and c.mx_roomserver_membership.membership_nid = @x2" +
|
||||||
|
" and c.mx_roomserver_membership.room_nid = @x3"
|
||||||
|
|
||||||
|
// selectServerMembersInRoomSQL is an optimised case for checking for server members in a room.
|
||||||
|
// The JOIN is significantly leaner than the previous case of looking up event NIDs and reading the
|
||||||
|
// membership events from the database, as the JOIN query amounts to little more than two index
|
||||||
|
// scans which are very fast. The presence of a single row from this query suggests the server is
|
||||||
|
// in the room, no rows returned suggests they aren't.
|
||||||
|
// "SELECT room_nid FROM roomserver_membership" +
|
||||||
|
// " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||||
|
// " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
|
||||||
|
const selectServerInRoomSQL = "" +
|
||||||
|
"select top 1 * from c where c._cn = @x1 " +
|
||||||
|
" and c.mx_roomserver_membership.membership_nid = @x2" +
|
||||||
|
" and c.mx_roomserver_membership.room_nid = @x3" +
|
||||||
|
" and contains(c.mx_roomserver_membership.target_nid, @x4) "
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
db *Database
|
db *Database
|
||||||
// insertMembershipStmt *sql.Stmt
|
// insertMembershipStmt *sql.Stmt
|
||||||
|
|
@ -172,6 +199,8 @@ type membershipStatements struct {
|
||||||
selectRoomsWithMembershipStmt string
|
selectRoomsWithMembershipStmt string
|
||||||
// updateMembershipStmt *sql.Stmt
|
// updateMembershipStmt *sql.Stmt
|
||||||
// selectKnownUsersStmt string
|
// selectKnownUsersStmt string
|
||||||
|
selectLocalServerInRoomStmt string
|
||||||
|
selectServerInRoomStmt string
|
||||||
// updateMembershipForgetRoomStmt *sql.Stmt
|
// updateMembershipForgetRoomStmt *sql.Stmt
|
||||||
tableName string
|
tableName string
|
||||||
}
|
}
|
||||||
|
|
@ -242,6 +271,8 @@ func NewCosmosDBMembershipTable(db *Database) (tables.Membership, error) {
|
||||||
// {&s.updateMembershipStmt, updateMembershipSQL},
|
// {&s.updateMembershipStmt, updateMembershipSQL},
|
||||||
s.selectRoomsWithMembershipStmt = selectRoomsWithMembershipSQL
|
s.selectRoomsWithMembershipStmt = selectRoomsWithMembershipSQL
|
||||||
// {&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
// {&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||||
|
s.selectLocalServerInRoomStmt = selectLocalServerInRoomSQL
|
||||||
|
s.selectServerInRoomStmt = selectServerInRoomSQL
|
||||||
// {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
|
// {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
|
||||||
// }.Prepare(db)
|
// }.Prepare(db)
|
||||||
|
|
||||||
|
|
@ -495,6 +526,91 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
||||||
|
// "SELECT room_nid FROM roomserver_membership WHERE target_local = 1 AND membership_nid = $1 AND room_nid = $2 LIMIT 1"
|
||||||
|
|
||||||
|
var nid types.RoomNID
|
||||||
|
// err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
||||||
|
//
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": tables.MembershipStateJoin,
|
||||||
|
"@x3": roomNID,
|
||||||
|
}
|
||||||
|
response, err := queryMembership(s, ctx, s.selectLocalServerInRoomStmt, params)
|
||||||
|
if len(response) == 0 {
|
||||||
|
if err == cosmosdbutil.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
nid = types.RoomNID(response[0].Membership.RoomNID)
|
||||||
|
|
||||||
|
found := nid > 0
|
||||||
|
return found, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||||
|
var nid types.RoomNID
|
||||||
|
// "SELECT room_nid FROM roomserver_membership" +
|
||||||
|
// " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||||
|
// " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
|
||||||
|
|
||||||
|
//First get the JOIN table
|
||||||
|
// SELECT event_state_key_nid FROM roomserver_event_state_keys
|
||||||
|
// WHERE event_state_key LIKE '%:' || $3 LIMIT 1
|
||||||
|
selectEventStateKeyNIDSQL := "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and (endswith(c.mx_roomserver_event_state_keys.event_state_key, \":\") or c.mx_roomserver_event_state_keys.event_state_key = @x2) "
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": serverName,
|
||||||
|
}
|
||||||
|
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
var eventStateKeys []EventStateKeysCosmosData
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(selectEventStateKeyNIDSQL, params) //
|
||||||
|
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&eventStateKeys,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
eventStateKeyNids := []int64{}
|
||||||
|
for _, item := range eventStateKeys {
|
||||||
|
eventStateKeyNids = append(eventStateKeyNids, item.EventStateKeys.EventStateKeyNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Now do the JOIN
|
||||||
|
// "SELECT room_nid FROM roomserver_membership" +
|
||||||
|
// " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||||
|
// " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
|
||||||
|
|
||||||
|
// err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
||||||
|
params = map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": tables.MembershipStateJoin,
|
||||||
|
"@x3": roomNID,
|
||||||
|
"@x4": eventStateKeyNids,
|
||||||
|
}
|
||||||
|
response, err := queryMembership(s, ctx, s.selectServerInRoomStmt, params)
|
||||||
|
if len(response) == 0 {
|
||||||
|
if err == cosmosdbutil.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
nid = types.RoomNID(response[0].Membership.RoomNID)
|
||||||
|
return roomNID == nid, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||||
|
|
||||||
// " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
// " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||||
|
|
|
||||||
13
roomserver/storage/cosmosdb/state_blob_seq.go
Normal file
13
roomserver/storage/cosmosdb/state_blob_seq.go
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
package cosmosdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetNextStateBlockNID(s *stateBlockStatements, ctx context.Context) (int64, error) {
|
||||||
|
const docId = "stateblocknid_seq"
|
||||||
|
//1 insert start at 2
|
||||||
|
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 2)
|
||||||
|
}
|
||||||
|
|
@ -18,12 +18,13 @@ package cosmosdb
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
|
@ -31,19 +32,21 @@ import (
|
||||||
|
|
||||||
// const stateDataSchema = `
|
// const stateDataSchema = `
|
||||||
// CREATE TABLE IF NOT EXISTS roomserver_state_block (
|
// CREATE TABLE IF NOT EXISTS roomserver_state_block (
|
||||||
// state_block_nid INTEGER NOT NULL,
|
// -- The state snapshot NID that identifies this snapshot.
|
||||||
// event_type_nid INTEGER NOT NULL,
|
// state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
// event_state_key_nid INTEGER NOT NULL,
|
// -- The hash of the state block, which is used to enforce uniqueness. The hash is
|
||||||
// event_nid INTEGER NOT NULL,
|
// -- generated in Dendrite and passed through to the database, as a btree index over
|
||||||
// UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
|
// -- this column is cheap and fits within the maximum index size.
|
||||||
|
// state_block_hash BLOB UNIQUE,
|
||||||
|
// -- The event NIDs contained within the state block, encoded as JSON.
|
||||||
|
// event_nids TEXT NOT NULL DEFAULT '[]'
|
||||||
// );
|
// );
|
||||||
// `
|
// `
|
||||||
|
|
||||||
type StateBlockCosmos struct {
|
type StateBlockCosmos struct {
|
||||||
StateBlockNID int64 `json:"state_block_nid"`
|
StateBlockNID int64 `json:"state_block_nid"`
|
||||||
EventTypeNID int64 `json:"event_type_nid"`
|
StateBlockHash []byte `json:"state_block_hash"`
|
||||||
EventStateKeyNID int64 `json:"event_state_key_nid"`
|
EventNIDs []int64 `json:"event_nids"`
|
||||||
EventNID int64 `json:"event_nid"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateBlockCosmosMaxNID struct {
|
type StateBlockCosmosMaxNID struct {
|
||||||
|
|
@ -60,63 +63,29 @@ type StateBlockCosmosData struct {
|
||||||
StateBlock StateBlockCosmos `json:"mx_roomserver_state_block"`
|
StateBlock StateBlockCosmos `json:"mx_roomserver_state_block"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// const insertStateDataSQL = "" +
|
// Insert a new state block. If we conflict on the hash column then
|
||||||
// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
|
// we must perform an update so that the RETURNING statement returns the
|
||||||
// " VALUES ($1, $2, $3, $4)"
|
// ID of the row that we conflicted with, so that we can then refer to
|
||||||
|
// the original block.
|
||||||
|
// const insertStateDataSQL = `
|
||||||
|
// INSERT INTO roomserver_state_block (state_block_hash, event_nids)
|
||||||
|
// VALUES ($1, $2)
|
||||||
|
// ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2
|
||||||
|
// RETURNING state_block_nid
|
||||||
|
// `
|
||||||
|
|
||||||
// SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
|
// "SELECT state_block_nid, event_nids" +
|
||||||
const selectNextStateBlockNIDSQL = "" +
|
// " FROM roomserver_state_block WHERE state_block_nid IN ($1) ORDER BY state_block_nid ASC"
|
||||||
"select sub.maxinner != null ? sub.maxinner + 1 : 1 as maxstateblocknid " +
|
|
||||||
"from " +
|
|
||||||
"(select MAX(c.mx_roomserver_state_block.state_block_nid) maxinner from c where c._sid = @x1 and c._cn = @x2) as sub"
|
|
||||||
|
|
||||||
// Bulk state lookup by numeric state block ID.
|
|
||||||
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
|
|
||||||
// This means that all the entries for a given state_block_nid will appear
|
|
||||||
// together in the list and those entries will sorted by event_type_nid
|
|
||||||
// and event_state_key_nid. This property makes it easier to merge two
|
|
||||||
// state data blocks together.
|
|
||||||
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
|
||||||
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
|
||||||
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
|
||||||
const bulkSelectStateBlockEntriesSQL = "" +
|
const bulkSelectStateBlockEntriesSQL = "" +
|
||||||
"select * from c where c._cn = @x1 " +
|
"select * from c where c._cn = @x1 " +
|
||||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " +
|
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " +
|
||||||
"order by c.mx_roomserver_state_block.state_block_nid " +
|
"order by c.mx_roomserver_state_block.state_block_nid "
|
||||||
// Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from
|
|
||||||
// ", c.mx_roomserver_state_block.event_type_nid " +
|
|
||||||
// ", c.mx_roomserver_state_block.event_state_key_nid " +
|
|
||||||
" asc"
|
|
||||||
|
|
||||||
// Bulk state lookup by numeric state block ID.
|
|
||||||
// Filters the rows in each block to the requested types and state keys.
|
|
||||||
// We would like to restrict to particular type state key pairs but we are
|
|
||||||
// restricted by the query language to pull the cross product of a list
|
|
||||||
// of types and a list state_keys. So we have to filter the result in the
|
|
||||||
// application to restrict it to the list of event types and state keys we
|
|
||||||
// actually wanted.
|
|
||||||
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
|
||||||
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
|
||||||
// " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
|
|
||||||
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
|
||||||
const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
|
||||||
"select * from c where c._cn = @x1 " +
|
|
||||||
"and ARRAY_CONTAINS(@x2, c.mx_roomserver_state_block.state_block_nid) " +
|
|
||||||
"and ARRAY_CONTAINS(@x3, c.mx_roomserver_state_block.event_type_nid) " +
|
|
||||||
"and ARRAY_CONTAINS(@x4, c.mx_roomserver_state_block.event_state_key_nid) " +
|
|
||||||
"order by c.mx_roomserver_state_block.state_block_nid " +
|
|
||||||
// Cant do multi field order by - The order by query does not have a corresponding composite index that it can be served from
|
|
||||||
// ", c.mx_roomserver_state_block.event_type_nid " +
|
|
||||||
// ", c.mx_roomserver_state_block.event_state_key_nid " +
|
|
||||||
"asc"
|
|
||||||
|
|
||||||
type stateBlockStatements struct {
|
type stateBlockStatements struct {
|
||||||
db *Database
|
db *Database
|
||||||
// insertStateDataStmt *sql.Stmt
|
// insertStateDataStmt *sql.Stmt
|
||||||
selectNextStateBlockNIDStmt string
|
bulkSelectStateBlockEntriesStmt string
|
||||||
bulkSelectStateBlockEntriesStmt string
|
tableName string
|
||||||
bulkSelectFilteredStateBlockEntriesStmt string
|
|
||||||
tableName string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StateBlockCosmosData, error) {
|
func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, params map[string]interface{}) ([]StateBlockCosmosData, error) {
|
||||||
|
|
@ -140,39 +109,107 @@ func queryStateBlock(s *stateBlockStatements, ctx context.Context, qry string, p
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getStateBlock(s *stateBlockStatements, ctx context.Context, pk string, docId string) (*StateBlockCosmosData, error) {
|
||||||
|
response := StateBlockCosmosData{}
|
||||||
|
err := cosmosdbapi.GetDocumentOrNil(
|
||||||
|
s.db.connection,
|
||||||
|
s.db.cosmosConfig,
|
||||||
|
ctx,
|
||||||
|
pk,
|
||||||
|
docId,
|
||||||
|
&response)
|
||||||
|
|
||||||
|
if response.Id == "" {
|
||||||
|
return nil, cosmosdbutil.ErrNoRows
|
||||||
|
}
|
||||||
|
|
||||||
|
return &response, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func setStateBlock(s *stateBlockStatements, ctx context.Context, item StateBlockCosmosData) (*StateBlockCosmosData, error) {
|
||||||
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(item.Pk, item.ETag)
|
||||||
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
item.Id,
|
||||||
|
&item,
|
||||||
|
optionsReplace)
|
||||||
|
return &item, ex
|
||||||
|
}
|
||||||
|
|
||||||
func NewCosmosDBStateBlockTable(db *Database) (tables.StateBlock, error) {
|
func NewCosmosDBStateBlockTable(db *Database) (tables.StateBlock, error) {
|
||||||
s := &stateBlockStatements{
|
s := &stateBlockStatements{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
|
||||||
// return s, shared.StatementList{
|
// s.insertStateDataStmt = insertStateDataSQL
|
||||||
// {&s.insertStateDataStmt, insertStateDataSQL},
|
|
||||||
s.selectNextStateBlockNIDStmt = selectNextStateBlockNIDSQL
|
|
||||||
s.bulkSelectStateBlockEntriesStmt = bulkSelectStateBlockEntriesSQL
|
s.bulkSelectStateBlockEntriesStmt = bulkSelectStateBlockEntriesSQL
|
||||||
s.bulkSelectFilteredStateBlockEntriesStmt = bulkSelectFilteredStateBlockEntriesSQL
|
|
||||||
// }.Prepare(db)
|
|
||||||
s.tableName = "state_block"
|
s.tableName = "state_block"
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func inertStateBlockCore(s *stateBlockStatements, ctx context.Context, stateBlockNID types.StateBlockNID, entry types.StateEntry) error {
|
func (s *stateBlockStatements) BulkInsertStateData(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
entries types.StateEntries,
|
||||||
|
) (id types.StateBlockNID, err error) {
|
||||||
|
// INSERT INTO roomserver_state_block (state_block_hash, event_nids)
|
||||||
|
// VALUES ($1, $2)
|
||||||
|
// ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2
|
||||||
|
// RETURNING state_block_nid
|
||||||
|
|
||||||
// "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
|
entries = entries[:util.SortAndUnique(entries)]
|
||||||
// " VALUES ($1, $2, $3, $4)"
|
nids := types.EventNIDs{} // zero slice to not store 'null' in the DB
|
||||||
data := StateBlockCosmos{
|
ids := []int64{}
|
||||||
EventNID: int64(entry.EventNID),
|
for _, e := range entries {
|
||||||
EventStateKeyNID: int64(entry.EventStateKeyNID),
|
nids = append(nids, e.EventNID)
|
||||||
EventTypeNID: int64(entry.EventTypeNID),
|
ids = append(ids, int64(e.EventNID))
|
||||||
StateBlockNID: int64(stateBlockNID),
|
|
||||||
}
|
}
|
||||||
|
// js, err := json.Marshal(nids)
|
||||||
|
// if err != nil {
|
||||||
|
// return 0, fmt.Errorf("json.Marshal: %w", err)
|
||||||
|
// }
|
||||||
|
// err = s.insertStateDataStmt.QueryRowContext(
|
||||||
|
// ctx, nids.Hash(), js,
|
||||||
|
// ).Scan(&id)
|
||||||
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
|
||||||
// UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
|
// state_block_hash BLOB UNIQUE,
|
||||||
docId := fmt.Sprintf("%d_%d_%d", data.StateBlockNID, data.EventTypeNID, data.EventStateKeyNID)
|
docId := hex.EncodeToString(nids.Hash())
|
||||||
|
|
||||||
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
//See if it exists
|
||||||
|
existing, err := getStateBlock(s, ctx, pk, cosmosDocId)
|
||||||
|
if err != nil {
|
||||||
|
if err != cosmosdbutil.ErrNoRows {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if existing != nil {
|
||||||
|
//if exists, just update and dont create a new seq
|
||||||
|
existing.StateBlock.EventNIDs = ids
|
||||||
|
_, err = setStateBlock(s, ctx, *existing)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return types.StateBlockNID(existing.StateBlock.StateBlockNID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//Doesnt exist,create a new one
|
||||||
|
// state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
seq, err := GetNextStateBlockNID(s, ctx)
|
||||||
|
id = types.StateBlockNID(seq)
|
||||||
|
|
||||||
|
data := StateBlockCosmos{
|
||||||
|
StateBlockNID: seq,
|
||||||
|
StateBlockHash: nids.Hash(),
|
||||||
|
EventNIDs: ids,
|
||||||
|
}
|
||||||
|
|
||||||
var dbData = StateBlockCosmosData{
|
var dbData = StateBlockCosmosData{
|
||||||
Id: cosmosDocId,
|
Id: cosmosDocId,
|
||||||
Tn: s.db.cosmosConfig.TenantName,
|
Tn: s.db.cosmosConfig.TenantName,
|
||||||
|
|
@ -182,187 +219,72 @@ func inertStateBlockCore(s *stateBlockStatements, ctx context.Context, stateBloc
|
||||||
StateBlock: data,
|
StateBlock: data,
|
||||||
}
|
}
|
||||||
|
|
||||||
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
|
||||||
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
ctx,
|
ctx,
|
||||||
s.db.cosmosConfig.DatabaseName,
|
s.db.cosmosConfig.DatabaseName,
|
||||||
s.db.cosmosConfig.ContainerName,
|
s.db.cosmosConfig.ContainerName,
|
||||||
&dbData,
|
&dbData,
|
||||||
options)
|
options)
|
||||||
|
|
||||||
return err
|
return
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNextStateBlockNID(s *stateBlockStatements, ctx context.Context) (int64, error) {
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
|
||||||
var stateBlockNext []StateBlockCosmosMaxNID
|
|
||||||
params := map[string]interface{}{
|
|
||||||
"@x1": s.db.cosmosConfig.TenantName,
|
|
||||||
"@x2": dbCollectionName,
|
|
||||||
}
|
|
||||||
|
|
||||||
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
|
||||||
var query = cosmosdbapi.GetQuery(s.selectNextStateBlockNIDStmt, params)
|
|
||||||
var _, err = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
|
||||||
ctx,
|
|
||||||
s.db.cosmosConfig.DatabaseName,
|
|
||||||
s.db.cosmosConfig.ContainerName,
|
|
||||||
query,
|
|
||||||
&stateBlockNext,
|
|
||||||
optionsQry)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return stateBlockNext[0].Max, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkInsertStateData(
|
|
||||||
ctx context.Context, txn *sql.Tx,
|
|
||||||
entries []types.StateEntry,
|
|
||||||
) (types.StateBlockNID, error) {
|
|
||||||
if len(entries) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nextID, errNextID := getNextStateBlockNID(s, ctx)
|
|
||||||
if errNextID != nil {
|
|
||||||
return 0, errNextID
|
|
||||||
}
|
|
||||||
|
|
||||||
stateBlockNID := types.StateBlockNID(nextID)
|
|
||||||
|
|
||||||
for _, entry := range entries {
|
|
||||||
err := inertStateBlockCore(s, ctx, stateBlockNID, entry)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return stateBlockNID, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
|
||||||
) ([]types.StateEntryList, error) {
|
) ([][]types.EventNID, error) {
|
||||||
|
// "SELECT state_block_nid, event_nids" +
|
||||||
|
// " FROM roomserver_state_block WHERE state_block_nid IN ($1) ORDER BY state_block_nid ASC"
|
||||||
|
|
||||||
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
intfs := make([]interface{}, len(stateBlockNIDs))
|
||||||
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
for i := range stateBlockNIDs {
|
||||||
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
intfs[i] = int64(stateBlockNIDs[i])
|
||||||
|
}
|
||||||
|
// selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1)
|
||||||
|
// selectStmt, err := s.db.Prepare(selectOrig)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
var response []StateBlockCosmosData
|
|
||||||
params := map[string]interface{}{
|
params := map[string]interface{}{
|
||||||
"@x1": dbCollectionName,
|
"@x1": dbCollectionName,
|
||||||
"@x2": stateBlockNIDs,
|
"@x2": stateBlockNIDs,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params)
|
// rows, err := selectStmt.QueryContext(ctx, intfs...)
|
||||||
|
rows, err := queryStateBlock(s, ctx, s.bulkSelectStateBlockEntriesStmt, params)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
|
||||||
|
|
||||||
results := make([]types.StateEntryList, len(stateBlockNIDs))
|
results := make([][]types.EventNID, len(stateBlockNIDs))
|
||||||
// current is a pointer to the StateEntryList to append the state entries to.
|
|
||||||
var current *types.StateEntryList
|
|
||||||
i := 0
|
i := 0
|
||||||
for _, item := range response {
|
// for ; rows.Next(); i++ {
|
||||||
entry := types.StateEntry{}
|
for _, item := range rows {
|
||||||
entry.EventTypeNID = types.EventTypeNID(item.StateBlock.EventTypeNID)
|
// var stateBlockNID types.StateBlockNID
|
||||||
entry.EventStateKeyNID = types.EventStateKeyNID(item.StateBlock.EventStateKeyNID)
|
// var result json.RawMessage
|
||||||
entry.EventNID = types.EventNID(item.StateBlock.EventNID)
|
// if err = rows.Scan(&stateBlockNID, &result); err != nil {
|
||||||
|
// return nil, err
|
||||||
if current == nil || types.StateBlockNID(item.StateBlock.StateBlockNID) != current.StateBlockNID {
|
// }
|
||||||
// The state entry row is for a different state data block to the current one.
|
r := []types.EventNID{}
|
||||||
// So we start appending to the next entry in the list.
|
// if err = json.Unmarshal(result, &r); err != nil {
|
||||||
current = &results[i]
|
// return nil, fmt.Errorf("json.Unmarshal: %w", err)
|
||||||
current.StateBlockNID = types.StateBlockNID(item.StateBlock.StateBlockNID)
|
// }
|
||||||
i++
|
for _, eventNID := range item.StateBlock.EventNIDs {
|
||||||
|
r = append(r, types.EventNID(eventNID))
|
||||||
}
|
}
|
||||||
current.StateEntries = append(current.StateEntries, entry)
|
results[i] = r
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
|
// if err = rows.Err(); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
if i != len(stateBlockNIDs) {
|
if i != len(stateBlockNIDs) {
|
||||||
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
|
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateBlockNIDs))
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, err
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
|
|
||||||
ctx context.Context,
|
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
|
||||||
) ([]types.StateEntryList, error) {
|
|
||||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
|
||||||
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
|
|
||||||
sort.Sort(tuples)
|
|
||||||
|
|
||||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
|
||||||
// sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
|
|
||||||
// sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
|
||||||
// sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
|
||||||
|
|
||||||
// "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
|
||||||
// " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
|
||||||
// " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
|
|
||||||
// " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
|
||||||
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
|
||||||
var response []StateBlockCosmosData
|
|
||||||
params := map[string]interface{}{
|
|
||||||
"@x1": dbCollectionName,
|
|
||||||
"@x2": stateBlockNIDs,
|
|
||||||
"@x3": eventTypeNIDArray,
|
|
||||||
"@x4": eventStateKeyNIDArray,
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err := queryStateBlock(s, ctx, s.bulkSelectFilteredStateBlockEntriesStmt, params)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var results []types.StateEntryList
|
|
||||||
var current types.StateEntryList
|
|
||||||
for _, item := range response {
|
|
||||||
var (
|
|
||||||
stateBlockNID int64
|
|
||||||
eventTypeNID int64
|
|
||||||
eventStateKeyNID int64
|
|
||||||
eventNID int64
|
|
||||||
entry types.StateEntry
|
|
||||||
)
|
|
||||||
stateBlockNID = item.StateBlock.StateBlockNID
|
|
||||||
eventTypeNID = item.StateBlock.EventTypeNID
|
|
||||||
eventStateKeyNID = item.StateBlock.EventStateKeyNID
|
|
||||||
eventNID = item.StateBlock.EventNID
|
|
||||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
|
||||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
|
||||||
entry.EventNID = types.EventNID(eventNID)
|
|
||||||
|
|
||||||
// We can use binary search here because we sorted the tuples earlier
|
|
||||||
if !tuples.contains(entry.StateKeyTuple) {
|
|
||||||
// The select will return the cross product of types and state keys.
|
|
||||||
// So we need to check if type of the entry is in the list.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
|
||||||
// The state entry row is for a different state data block to the current one.
|
|
||||||
// So we append the current entry to the results and start adding to a new one.
|
|
||||||
// The first time through the loop current will be empty.
|
|
||||||
if current.StateEntries != nil {
|
|
||||||
results = append(results, current)
|
|
||||||
}
|
|
||||||
current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
|
|
||||||
}
|
|
||||||
current.StateEntries = append(current.StateEntries, entry)
|
|
||||||
}
|
|
||||||
// Add the last entry to the list if it is not empty.
|
|
||||||
if current.StateEntries != nil {
|
|
||||||
results = append(results, current)
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type stateKeyTupleSorter []types.StateKeyTuple
|
type stateKeyTupleSorter []types.StateKeyTuple
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// const stateSnapshotSchema = `
|
// const stateSnapshotSchema = `
|
||||||
|
|
@ -34,10 +35,24 @@ import (
|
||||||
// );
|
// );
|
||||||
// `
|
// `
|
||||||
|
|
||||||
|
// CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
|
||||||
|
// -- The state snapshot NID that identifies this snapshot.
|
||||||
|
// state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
// -- The hash of the state snapshot, which is used to enforce uniqueness. The hash is
|
||||||
|
// -- generated in Dendrite and passed through to the database, as a btree index over
|
||||||
|
// -- this column is cheap and fits within the maximum index size.
|
||||||
|
// state_snapshot_hash BLOB UNIQUE,
|
||||||
|
// -- The room NID that the snapshot belongs to.
|
||||||
|
// room_nid INTEGER NOT NULL,
|
||||||
|
// -- The state blocks contained within this snapshot, encoded as JSON.
|
||||||
|
// state_block_nids TEXT NOT NULL DEFAULT '[]'
|
||||||
|
// );
|
||||||
|
|
||||||
type StateSnapshotCosmos struct {
|
type StateSnapshotCosmos struct {
|
||||||
StateSnapshotNID int64 `json:"state_snapshot_nid"`
|
StateSnapshotNID int64 `json:"state_snapshot_nid"`
|
||||||
RoomNID int64 `json:"room_nid"`
|
StateSnapshotHash []byte `json:"state_snapshot_hash"`
|
||||||
StateBlockNIDs []int64 `json:"state_block_nids"`
|
RoomNID int64 `json:"room_nid"`
|
||||||
|
StateBlockNIDs []int64 `json:"state_block_nids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateSnapshotCosmosData struct {
|
type StateSnapshotCosmosData struct {
|
||||||
|
|
@ -51,8 +66,10 @@ type StateSnapshotCosmosData struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// const insertStateSQL = `
|
// const insertStateSQL = `
|
||||||
// INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
// INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)
|
||||||
// VALUES ($1, $2);`
|
// VALUES ($1, $2, $3)
|
||||||
|
// ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2
|
||||||
|
// RETURNING state_snapshot_nid
|
||||||
|
|
||||||
// Bulk state data NID lookup.
|
// Bulk state data NID lookup.
|
||||||
// Sorting by state_snapshot_nid means we can use binary search over the result
|
// Sorting by state_snapshot_nid means we can use binary search over the result
|
||||||
|
|
@ -101,20 +118,32 @@ func NewCosmosDBStateSnapshotTable(db *Database) (tables.StateSnapshot, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) InsertState(
|
func (s *stateSnapshotStatements) InsertState(
|
||||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
|
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs,
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
|
|
||||||
// INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
// INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)
|
||||||
// VALUES ($1, $2);`
|
// VALUES ($1, $2, $3)
|
||||||
|
// ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2
|
||||||
|
// RETURNING state_snapshot_nid
|
||||||
stateSnapshotNIDSeq, seqErr := GetNextStateSnapshotNID(s, ctx)
|
stateSnapshotNIDSeq, seqErr := GetNextStateSnapshotNID(s, ctx)
|
||||||
if seqErr != nil {
|
if seqErr != nil {
|
||||||
return 0, seqErr
|
return 0, seqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if stateBlockNIDs == nil {
|
||||||
|
stateBlockNIDs = []types.StateBlockNID{} // zero slice to not store 'null' in the DB
|
||||||
|
}
|
||||||
|
stateBlockNIDs = stateBlockNIDs[:util.SortAndUnique(stateBlockNIDs)]
|
||||||
|
// stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
|
||||||
|
// if err != nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
|
||||||
data := StateSnapshotCosmos{
|
data := StateSnapshotCosmos{
|
||||||
RoomNID: int64(roomNID),
|
RoomNID: int64(roomNID),
|
||||||
StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs),
|
StateSnapshotHash: stateBlockNIDs.Hash(),
|
||||||
StateSnapshotNID: int64(stateSnapshotNIDSeq),
|
StateBlockNIDs: mapFromStateBlockNIDArray(stateBlockNIDs),
|
||||||
|
StateSnapshotNID: int64(stateSnapshotNIDSeq),
|
||||||
}
|
}
|
||||||
|
|
||||||
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
@ -80,10 +79,13 @@ const selectSendToDeviceMessagesSQL = "" +
|
||||||
"and c.mx_syncapi_send_to_device.id <= @x5 " +
|
"and c.mx_syncapi_send_to_device.id <= @x5 " +
|
||||||
"order by c.mx_syncapi_send_to_device.id desc "
|
"order by c.mx_syncapi_send_to_device.id desc "
|
||||||
|
|
||||||
const deleteSendToDeviceMessagesSQL = `
|
// DELETE FROM syncapi_send_to_device
|
||||||
DELETE FROM syncapi_send_to_device
|
// WHERE user_id = $1 AND device_id = $2 AND id < $3
|
||||||
WHERE user_id = $1 AND device_id = $2 AND id < $3
|
const deleteSendToDeviceMessagesSQL = "" +
|
||||||
`
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_syncapi_send_to_device.user_id = @x2 " +
|
||||||
|
"and c.mx_syncapi_send_to_device.device_id = @x3 " +
|
||||||
|
"and c.mx_syncapi_send_to_device.id < @x4 "
|
||||||
|
|
||||||
// "SELECT MAX(id) FROM syncapi_send_to_device"
|
// "SELECT MAX(id) FROM syncapi_send_to_device"
|
||||||
const selectMaxSendToDeviceIDSQL = "" +
|
const selectMaxSendToDeviceIDSQL = "" +
|
||||||
|
|
@ -93,7 +95,7 @@ type sendToDeviceStatements struct {
|
||||||
db *SyncServerDatasource
|
db *SyncServerDatasource
|
||||||
// insertSendToDeviceMessageStmt *sql.Stmt
|
// insertSendToDeviceMessageStmt *sql.Stmt
|
||||||
selectSendToDeviceMessagesStmt string
|
selectSendToDeviceMessagesStmt string
|
||||||
deleteSendToDeviceMessagesStmt *sql.Stmt
|
deleteSendToDeviceMessagesStmt string
|
||||||
selectMaxSendToDeviceIDStmt string
|
selectMaxSendToDeviceIDStmt string
|
||||||
tableName string
|
tableName string
|
||||||
}
|
}
|
||||||
|
|
@ -140,6 +142,21 @@ func querySendToDeviceNumber(s *sendToDeviceStatements, ctx context.Context, qry
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func deleteSendToDevice(s *sendToDeviceStatements, ctx context.Context, dbData SendToDeviceCosmosData) error {
|
||||||
|
var options = cosmosdbapi.GetDeleteDocumentOptions(dbData.Pk)
|
||||||
|
var _, err = cosmosdbapi.GetClient(s.db.connection).DeleteDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
dbData.Id,
|
||||||
|
options)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func NewCosmosDBSendToDeviceTable(db *SyncServerDatasource) (tables.SendToDevice, error) {
|
func NewCosmosDBSendToDeviceTable(db *SyncServerDatasource) (tables.SendToDevice, error) {
|
||||||
s := &sendToDeviceStatements{
|
s := &sendToDeviceStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
|
@ -148,9 +165,7 @@ func NewCosmosDBSendToDeviceTable(db *SyncServerDatasource) (tables.SendToDevice
|
||||||
// return nil, err
|
// return nil, err
|
||||||
// }
|
// }
|
||||||
s.selectSendToDeviceMessagesStmt = selectSendToDeviceMessagesSQL
|
s.selectSendToDeviceMessagesStmt = selectSendToDeviceMessagesSQL
|
||||||
// if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
|
s.deleteSendToDeviceMessagesStmt = deleteSendToDeviceMessagesSQL
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
s.selectMaxSendToDeviceIDStmt = selectMaxSendToDeviceIDSQL
|
s.selectMaxSendToDeviceIDStmt = selectMaxSendToDeviceIDSQL
|
||||||
s.tableName = "send_to_device"
|
s.tableName = "send_to_device"
|
||||||
return s, nil
|
return s, nil
|
||||||
|
|
@ -260,7 +275,29 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
|
||||||
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||||
ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
|
ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
|
// DELETE FROM syncapi_send_to_device
|
||||||
|
// WHERE user_id = $1 AND device_id = $2 AND id < $3
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
"@x3": deviceID,
|
||||||
|
"@x4": pos,
|
||||||
|
}
|
||||||
|
|
||||||
|
// _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
|
||||||
|
rows, err := querySendToDevice(s, ctx, s.deleteSendToDeviceMessagesStmt, params)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range rows {
|
||||||
|
err = deleteSendToDevice(s, ctx, item)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
414
userapi/storage/accounts/cosmosdb/key_backup_table.go
Normal file
414
userapi/storage/accounts/cosmosdb/key_backup_table.go
Normal file
|
|
@ -0,0 +1,414 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package cosmosdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// const keyBackupTableSchema = `
|
||||||
|
// CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
|
||||||
|
// user_id TEXT NOT NULL,
|
||||||
|
// room_id TEXT NOT NULL,
|
||||||
|
// session_id TEXT NOT NULL,
|
||||||
|
|
||||||
|
// version TEXT NOT NULL,
|
||||||
|
// first_message_index INTEGER NOT NULL,
|
||||||
|
// forwarded_count INTEGER NOT NULL,
|
||||||
|
// is_verified BOOLEAN NOT NULL,
|
||||||
|
// session_data TEXT NOT NULL
|
||||||
|
// );
|
||||||
|
// CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
|
||||||
|
// CREATE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
|
||||||
|
// `
|
||||||
|
|
||||||
|
type KeyBackupCosmosData struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Tn string `json:"_sid"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
KeyBackup KeyBackupCosmos `json:"mx_userapi_account_e2e_room_keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyBackupCosmos struct {
|
||||||
|
UserId string `json:"user_id"`
|
||||||
|
RoomId string `json:"room_id"`
|
||||||
|
SessionId string `json:"session_id"`
|
||||||
|
Version string `json:"vesion"`
|
||||||
|
FirstMessageIndex int `json:"first_message_index"`
|
||||||
|
ForwardedCount int `json:"forwarded_count"`
|
||||||
|
IsVerified bool `json:"is_verified"`
|
||||||
|
SessionData []byte `json:"session_data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyBackupCosmosNumber struct {
|
||||||
|
Number int64 `json:"number"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// const insertBackupKeySQL = "" +
|
||||||
|
// "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
|
||||||
|
// "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||||
|
|
||||||
|
// const updateBackupKeySQL = "" +
|
||||||
|
// "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
|
||||||
|
// "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
|
||||||
|
|
||||||
|
// "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
|
||||||
|
const countKeysSQL = "" +
|
||||||
|
"select count(c._ts) as number from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.user_id = @x2 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.version = @x3 "
|
||||||
|
|
||||||
|
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
// "WHERE user_id = $1 AND version = $2"
|
||||||
|
const selectKeysSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.user_id = @x2 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.version = @x3 "
|
||||||
|
|
||||||
|
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
// "WHERE user_id = $1 AND version = $2 AND room_id = $3"
|
||||||
|
const selectKeysByRoomIDSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.user_id = @x2 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.version = @x3 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.room_id = @x4 "
|
||||||
|
|
||||||
|
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
// "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
|
||||||
|
const selectKeysByRoomIDAndSessionIDSQL = "" +
|
||||||
|
"select * from c where c._cn = @x1 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.user_id = @x2 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.version = @x3 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.room_id = @x4 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys.session_id = @x5 "
|
||||||
|
|
||||||
|
type keyBackupStatements struct {
|
||||||
|
db *Database
|
||||||
|
// insertBackupKeyStmt *sql.Stmt
|
||||||
|
// updateBackupKeyStmt *sql.Stmt
|
||||||
|
countKeysStmt string
|
||||||
|
selectKeysStmt string
|
||||||
|
selectKeysByRoomIDStmt string
|
||||||
|
selectKeysByRoomIDAndSessionIDStmt string
|
||||||
|
tableName string
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryKeyBackup(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosData, error) {
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
var response []KeyBackupCosmosData
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryKeyBackupNumber(s *keyBackupStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupCosmosNumber, error) {
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
var response []KeyBackupCosmosNumber
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKeyBackup(s *keyBackupStatements, ctx context.Context, pk string, docId string) (*KeyBackupCosmosData, error) {
|
||||||
|
response := KeyBackupCosmosData{}
|
||||||
|
err := cosmosdbapi.GetDocumentOrNil(
|
||||||
|
s.db.connection,
|
||||||
|
s.db.cosmosConfig,
|
||||||
|
ctx,
|
||||||
|
pk,
|
||||||
|
docId,
|
||||||
|
&response)
|
||||||
|
|
||||||
|
if response.Id == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &response, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func setKeyBackup(s *keyBackupStatements, ctx context.Context, keyBackup KeyBackupCosmosData) (*KeyBackupCosmosData, error) {
|
||||||
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag)
|
||||||
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
keyBackup.Id,
|
||||||
|
&keyBackup,
|
||||||
|
optionsReplace)
|
||||||
|
return &keyBackup, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
s.db = db
|
||||||
|
// s.insertBackupKeyStmt = insertBackupKeySQL
|
||||||
|
// s.updateBackupKeyStmt = updateBackupKeySQL
|
||||||
|
s.countKeysStmt = countKeysSQL
|
||||||
|
s.selectKeysStmt = selectKeysSQL
|
||||||
|
s.selectKeysByRoomIDStmt = selectKeysByRoomIDSQL
|
||||||
|
s.selectKeysByRoomIDAndSessionIDStmt = selectKeysByRoomIDAndSessionIDSQL
|
||||||
|
s.tableName = "account_e2e_room_keys"
|
||||||
|
s.serverName = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s keyBackupStatements) countKeys(
|
||||||
|
ctx context.Context, userID, version string,
|
||||||
|
) (count int64, err error) {
|
||||||
|
// "SELECT COUNT(*) FROM account_e2e_room_keys WHERE user_id = $1 AND version = $2"
|
||||||
|
// err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
"@x3": version,
|
||||||
|
}
|
||||||
|
rows, err := queryKeyBackupNumber(&s, ctx, s.countKeysStmt, params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return -1, nil
|
||||||
|
}
|
||||||
|
// err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
|
||||||
|
count = rows[0].Number
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) insertBackupKey(
|
||||||
|
ctx context.Context, userID, version string, key api.InternalKeyBackupSession,
|
||||||
|
) (err error) {
|
||||||
|
// "INSERT INTO account_e2e_room_keys(user_id, room_id, session_id, version, first_message_index, forwarded_count, is_verified, session_data) " +
|
||||||
|
// "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||||
|
// _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
|
||||||
|
// ctx, userID, key.RoomID, key.SessionID, version, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData),
|
||||||
|
// )
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
// CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
|
||||||
|
docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
data := KeyBackupCosmos{
|
||||||
|
UserId: userID,
|
||||||
|
RoomId: key.RoomID,
|
||||||
|
SessionId: key.SessionID,
|
||||||
|
Version: version,
|
||||||
|
FirstMessageIndex: key.FirstMessageIndex,
|
||||||
|
ForwardedCount: key.ForwardedCount,
|
||||||
|
IsVerified: key.IsVerified,
|
||||||
|
SessionData: key.SessionData,
|
||||||
|
}
|
||||||
|
|
||||||
|
dbData := &KeyBackupCosmosData{
|
||||||
|
Id: cosmosDocId,
|
||||||
|
Tn: s.db.cosmosConfig.TenantName,
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: pk,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
KeyBackup: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
|
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
&dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) updateBackupKey(
|
||||||
|
ctx context.Context, userID, version string, key api.InternalKeyBackupSession,
|
||||||
|
) (err error) {
|
||||||
|
// "UPDATE account_e2e_room_keys SET first_message_index=$1, forwarded_count=$2, is_verified=$3, session_data=$4 " +
|
||||||
|
// "WHERE user_id=$5 AND room_id=$6 AND session_id=$7 AND version=$8"
|
||||||
|
// _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
|
||||||
|
// ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version,
|
||||||
|
// )
|
||||||
|
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
// CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
|
||||||
|
docId := fmt.Sprintf("%s_%s_%s_%s", userID, key.RoomID, key.SessionID, version)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
res, err := getKeyBackup(s, ctx, pk, cosmosDocId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if res == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ctx, key.FirstMessageIndex, key.ForwardedCount, key.IsVerified, string(key.SessionData), userID, key.RoomID, key.SessionID, version,
|
||||||
|
res.KeyBackup.FirstMessageIndex = key.FirstMessageIndex
|
||||||
|
res.KeyBackup.ForwardedCount = key.ForwardedCount
|
||||||
|
res.KeyBackup.IsVerified = key.IsVerified
|
||||||
|
res.KeyBackup.SessionData = key.SessionData
|
||||||
|
|
||||||
|
_, err = setKeyBackup(s, ctx, *res)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) selectKeys(
|
||||||
|
ctx context.Context, userID, version string,
|
||||||
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
// "WHERE user_id = $1 AND version = $2"
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
"@x3": version,
|
||||||
|
}
|
||||||
|
rows, err := queryKeyBackup(s, ctx, s.selectKeysStmt, params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
||||||
|
return unpackKeys(ctx, &rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) selectKeysByRoomID(
|
||||||
|
ctx context.Context, userID, version, roomID string,
|
||||||
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
// "WHERE user_id = $1 AND version = $2 AND room_id = $3"
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
"@x3": version,
|
||||||
|
"@x4": roomID,
|
||||||
|
}
|
||||||
|
rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDStmt, params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return unpackKeys(ctx, &rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
||||||
|
ctx context.Context, userID, version, roomID, sessionID string,
|
||||||
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
// "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
// "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.db.openIDTokens.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": dbCollectionName,
|
||||||
|
"@x2": userID,
|
||||||
|
"@x3": version,
|
||||||
|
"@x4": roomID,
|
||||||
|
"@x5": sessionID,
|
||||||
|
}
|
||||||
|
rows, err := queryKeyBackup(s, ctx, s.selectKeysByRoomIDAndSessionIDStmt, params)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return unpackKeys(ctx, &rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func unpackKeys(ctx context.Context, rows *[]KeyBackupCosmosData) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
result := make(map[string]map[string]api.KeyBackupSession)
|
||||||
|
for _, item := range *rows {
|
||||||
|
var key api.InternalKeyBackupSession
|
||||||
|
// room_id, session_id, first_message_index, forwarded_count, is_verified, session_data
|
||||||
|
var sessionDataStr string
|
||||||
|
// if err := rows.Scan(&key.RoomID, &key.SessionID, &key.FirstMessageIndex, &key.ForwardedCount, &key.IsVerified, &sessionDataStr); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
key.RoomID = item.KeyBackup.RoomId
|
||||||
|
key.SessionID = item.KeyBackup.SessionId
|
||||||
|
key.FirstMessageIndex = item.KeyBackup.FirstMessageIndex
|
||||||
|
key.ForwardedCount = item.KeyBackup.ForwardedCount
|
||||||
|
key.SessionData = json.RawMessage(sessionDataStr)
|
||||||
|
roomData := result[key.RoomID]
|
||||||
|
if roomData == nil {
|
||||||
|
roomData = make(map[string]api.KeyBackupSession)
|
||||||
|
}
|
||||||
|
roomData[key.SessionID] = key.KeyBackupSession
|
||||||
|
result[key.RoomID] = roomData
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
377
userapi/storage/accounts/cosmosdb/key_backup_version_table.go
Normal file
377
userapi/storage/accounts/cosmosdb/key_backup_version_table.go
Normal file
|
|
@ -0,0 +1,377 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package cosmosdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// const keyBackupVersionTableSchema = `
|
||||||
|
// -- the metadata for each generation of encrypted e2e session backups
|
||||||
|
// CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
|
||||||
|
// user_id TEXT NOT NULL,
|
||||||
|
// -- this means no 2 users will ever have the same version of e2e session backups which strictly
|
||||||
|
// -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
|
||||||
|
// version INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
// algorithm TEXT NOT NULL,
|
||||||
|
// auth_data TEXT NOT NULL,
|
||||||
|
// etag TEXT NOT NULL,
|
||||||
|
// deleted INTEGER DEFAULT 0 NOT NULL
|
||||||
|
// );
|
||||||
|
|
||||||
|
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
|
||||||
|
// `
|
||||||
|
|
||||||
|
type KeyBackupVersionCosmosData struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Pk string `json:"_pk"`
|
||||||
|
Tn string `json:"_sid"`
|
||||||
|
Cn string `json:"_cn"`
|
||||||
|
ETag string `json:"_etag"`
|
||||||
|
Timestamp int64 `json:"_ts"`
|
||||||
|
KeyBackupVersion KeyBackupVersionCosmos `json:"mx_userapi_account_e2e_room_keys_versions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyBackupVersionCosmos struct {
|
||||||
|
UserId string `json:"user_id"`
|
||||||
|
Version int64 `json:"vesion"`
|
||||||
|
Algorithm string `json:"algorithm"`
|
||||||
|
AuthData []byte `json:"auth_data"`
|
||||||
|
Etag string `json:"etag"`
|
||||||
|
Deleted int `json:"deleted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyBackupVersionCosmosNumber struct {
|
||||||
|
Number int64 `json:"number"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// const insertKeyBackupSQL = "" +
|
||||||
|
// "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
|
||||||
|
|
||||||
|
// const updateKeyBackupAuthDataSQL = "" +
|
||||||
|
// "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
|
||||||
|
|
||||||
|
// const updateKeyBackupETagSQL = "" +
|
||||||
|
// "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
|
||||||
|
|
||||||
|
// const deleteKeyBackupSQL = "" +
|
||||||
|
// "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
|
||||||
|
|
||||||
|
// const selectKeyBackupSQL = "" +
|
||||||
|
// "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
|
||||||
|
|
||||||
|
// "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
|
||||||
|
const selectLatestVersionSQL = "" +
|
||||||
|
"select max(c.mx_userapi_account_e2e_room_keys_versions.version) as number from c where c._sid = @x1 and c._cn = @x2 " +
|
||||||
|
"and c.mx_userapi_account_e2e_room_keys_versions.user_id = @x3 "
|
||||||
|
|
||||||
|
type keyBackupVersionStatements struct {
|
||||||
|
db *Database
|
||||||
|
// insertKeyBackupStmt *sql.Stmt
|
||||||
|
// updateKeyBackupAuthDataStmt *sql.Stmt
|
||||||
|
// deleteKeyBackupStmt *sql.Stmt
|
||||||
|
// selectKeyBackupStmt *sql.Stmt
|
||||||
|
selectLatestVersionStmt string
|
||||||
|
// updateKeyBackupETagStmt *sql.Stmt
|
||||||
|
tableName string
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryKeyBackupVersionNumber(s *keyBackupVersionStatements, ctx context.Context, qry string, params map[string]interface{}) ([]KeyBackupVersionCosmosNumber, error) {
|
||||||
|
var response []KeyBackupVersionCosmosNumber
|
||||||
|
|
||||||
|
var optionsQry = cosmosdbapi.GetQueryAllPartitionsDocumentsOptions()
|
||||||
|
var query = cosmosdbapi.GetQuery(qry, params)
|
||||||
|
var _, _ = cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
query,
|
||||||
|
&response,
|
||||||
|
optionsQry)
|
||||||
|
|
||||||
|
//WHen there is no data these GroupBy queries return errors
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
|
||||||
|
if len(response) == 0 {
|
||||||
|
return nil, cosmosdbutil.ErrNoRows
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, pk string, docId string) (*KeyBackupVersionCosmosData, error) {
|
||||||
|
response := KeyBackupVersionCosmosData{}
|
||||||
|
err := cosmosdbapi.GetDocumentOrNil(
|
||||||
|
s.db.connection,
|
||||||
|
s.db.cosmosConfig,
|
||||||
|
ctx,
|
||||||
|
pk,
|
||||||
|
docId,
|
||||||
|
&response)
|
||||||
|
|
||||||
|
if response.Id == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &response, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func setKeyBackupVersion(s *keyBackupVersionStatements, ctx context.Context, keyBackup KeyBackupVersionCosmosData) (*KeyBackupVersionCosmosData, error) {
|
||||||
|
var optionsReplace = cosmosdbapi.GetReplaceDocumentOptions(keyBackup.Pk, keyBackup.ETag)
|
||||||
|
var _, _, ex = cosmosdbapi.GetClient(s.db.connection).ReplaceDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
keyBackup.Id,
|
||||||
|
&keyBackup,
|
||||||
|
optionsReplace)
|
||||||
|
return &keyBackup, ex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupVersionStatements) prepare(db *Database, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
s.db = db
|
||||||
|
// s.insertKeyBackupStmt = insertKeyBackupSQL
|
||||||
|
// s.updateKeyBackupAuthDataStmt = updateKeyBackupAuthDataSQL
|
||||||
|
// s.deleteKeyBackupStmt = deleteKeyBackupSQL
|
||||||
|
// s.selectKeyBackupStmt = selectKeyBackupSQL
|
||||||
|
s.selectLatestVersionStmt = selectLatestVersionSQL
|
||||||
|
// s.updateKeyBackupETagStmt = updateKeyBackupETagSQL
|
||||||
|
s.tableName = "account_e2e_room_keys_versions"
|
||||||
|
s.serverName = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupVersionStatements) insertKeyBackup(
|
||||||
|
ctx context.Context, userID, algorithm string, authData json.RawMessage, etag string,
|
||||||
|
) (version string, err error) {
|
||||||
|
// "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data, etag) VALUES ($1, $2, $3, $4) RETURNING version"
|
||||||
|
var versionInt int64
|
||||||
|
// -- this means no 2 users will ever have the same version of e2e session backups which strictly
|
||||||
|
// -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
|
||||||
|
// version INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
versionInt, seqErr := GetNextKeyBackupVersionID(s, ctx)
|
||||||
|
if seqErr != nil {
|
||||||
|
return "", seqErr
|
||||||
|
}
|
||||||
|
// err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData), etag).Scan(&versionInt)
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
|
||||||
|
docId := fmt.Sprintf("%s_%d", userID, versionInt)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
|
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
data := KeyBackupVersionCosmos{
|
||||||
|
UserId: userID,
|
||||||
|
Version: versionInt,
|
||||||
|
Algorithm: algorithm,
|
||||||
|
AuthData: authData,
|
||||||
|
Etag: etag,
|
||||||
|
Deleted: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
dbData := &KeyBackupVersionCosmosData{
|
||||||
|
Id: cosmosDocId,
|
||||||
|
Tn: s.db.cosmosConfig.TenantName,
|
||||||
|
Cn: dbCollectionName,
|
||||||
|
Pk: pk,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
KeyBackupVersion: data,
|
||||||
|
}
|
||||||
|
|
||||||
|
var options = cosmosdbapi.GetCreateDocumentOptions(dbData.Pk)
|
||||||
|
_, _, err = cosmosdbapi.GetClient(s.db.connection).CreateDocument(
|
||||||
|
ctx,
|
||||||
|
s.db.cosmosConfig.DatabaseName,
|
||||||
|
s.db.cosmosConfig.ContainerName,
|
||||||
|
&dbData,
|
||||||
|
options)
|
||||||
|
|
||||||
|
return strconv.FormatInt(versionInt, 10), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
||||||
|
ctx context.Context, userID, version string, authData json.RawMessage,
|
||||||
|
) error {
|
||||||
|
// "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
|
||||||
|
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid version")
|
||||||
|
}
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
|
||||||
|
docId := fmt.Sprintf("%s_%d", userID, versionInt)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt)
|
||||||
|
res.KeyBackupVersion.AuthData = authData
|
||||||
|
|
||||||
|
_, err = setKeyBackupVersion(s, ctx, *res)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
||||||
|
ctx context.Context, userID, version, etag string,
|
||||||
|
) error {
|
||||||
|
// "UPDATE account_e2e_room_keys_versions SET etag = $1 WHERE user_id = $2 AND version = $3"
|
||||||
|
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid version")
|
||||||
|
}
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
|
||||||
|
docId := fmt.Sprintf("%s_%d", userID, versionInt)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// _, err = txn.Stmt(s.updateKeyBackupETagStmt).ExecContext(ctx, etag, userID, versionInt)
|
||||||
|
res.KeyBackupVersion.Etag = etag
|
||||||
|
|
||||||
|
_, err = setKeyBackupVersion(s, ctx, *res)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupVersionStatements) deleteKeyBackup(
|
||||||
|
ctx context.Context, userID, version string,
|
||||||
|
) (bool, error) {
|
||||||
|
// "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
|
||||||
|
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("invalid version")
|
||||||
|
}
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
|
||||||
|
docId := fmt.Sprintf("%s_%d", userID, versionInt)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if res == nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt)
|
||||||
|
res.KeyBackupVersion.Deleted = 1
|
||||||
|
|
||||||
|
_, err = setKeyBackupVersion(s, ctx, *res)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupVersionStatements) selectKeyBackup(
|
||||||
|
ctx context.Context, userID, version string,
|
||||||
|
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
||||||
|
// "SELECT algorithm, auth_data, etag, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
|
||||||
|
var versionInt int64
|
||||||
|
if version == "" {
|
||||||
|
// var v *int64 // allows nulls
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
params := map[string]interface{}{
|
||||||
|
"@x1": s.db.cosmosConfig.TenantName,
|
||||||
|
"@x2": dbCollectionName,
|
||||||
|
"@x3": userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||||
|
response, err1 := queryKeyBackupVersionNumber(s, ctx, s.selectLatestVersionStmt, params)
|
||||||
|
|
||||||
|
if err1 != nil {
|
||||||
|
if err == cosmosdbutil.ErrNoRows {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&v); err != nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
if response == nil || len(response) == 0 {
|
||||||
|
err = cosmosdbutil.ErrNoRows
|
||||||
|
versionInt = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
versionInt = response[0].Number
|
||||||
|
} else {
|
||||||
|
if versionInt, err = strconv.ParseInt(version, 10, 64); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
versionResult = strconv.FormatInt(versionInt, 10)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
|
||||||
|
// CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
|
||||||
|
docId := fmt.Sprintf("%s_%d", userID, versionInt)
|
||||||
|
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.TenantName, dbCollectionName, docId)
|
||||||
|
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.TenantName, dbCollectionName)
|
||||||
|
|
||||||
|
res, err := getKeyBackupVersion(s, ctx, pk, cosmosDocId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if res == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// var deletedInt int
|
||||||
|
// var authDataStr string
|
||||||
|
// err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &etag, &deletedInt)
|
||||||
|
deleted = res.KeyBackupVersion.Deleted == 1
|
||||||
|
authData = res.KeyBackupVersion.AuthData
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
package cosmosdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetNextKeyBackupVersionID(s *keyBackupVersionStatements, ctx context.Context) (int64, error) {
|
||||||
|
const docId = "id_seq"
|
||||||
|
return cosmosdbutil.GetNextSequence(ctx, s.db.connection, s.db.cosmosConfig, s.db.databaseName, s.tableName, docId, 1)
|
||||||
|
}
|
||||||
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
|
@ -45,6 +46,8 @@ type Database struct {
|
||||||
accountDatas accountDataStatements
|
accountDatas accountDataStatements
|
||||||
threepids threepidStatements
|
threepids threepidStatements
|
||||||
openIDTokens tokenStatements
|
openIDTokens tokenStatements
|
||||||
|
keyBackupVersions keyBackupVersionStatements
|
||||||
|
keyBackups keyBackupStatements
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
bcryptCost int
|
bcryptCost int
|
||||||
openIDTokenLifetimeMS int64
|
openIDTokenLifetimeMS int64
|
||||||
|
|
@ -105,6 +108,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = d.openIDTokens.prepare(d, serverName); err != nil {
|
if err = d.openIDTokens.prepare(d, serverName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err = d.keyBackupVersions.prepare(d, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = d.keyBackups.prepare(d, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
@ -419,3 +428,150 @@ func (d *Database) GetOpenIDTokenAttributes(
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) CreateKeyBackup(
|
||||||
|
ctx context.Context, userID, algorithm string, authData json.RawMessage,
|
||||||
|
) (version string, err error) {
|
||||||
|
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
version, err = d.keyBackupVersions.insertKeyBackup(ctx, userID, algorithm, authData, "")
|
||||||
|
return version, err
|
||||||
|
// })
|
||||||
|
// return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) UpdateKeyBackupAuthData(
|
||||||
|
ctx context.Context, userID, version string, authData json.RawMessage,
|
||||||
|
) (err error) {
|
||||||
|
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, userID, version, authData)
|
||||||
|
// })
|
||||||
|
// return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) DeleteKeyBackup(
|
||||||
|
ctx context.Context, userID, version string,
|
||||||
|
) (exists bool, err error) {
|
||||||
|
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, userID, version)
|
||||||
|
return
|
||||||
|
// })
|
||||||
|
// return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetKeyBackup(
|
||||||
|
ctx context.Context, userID, version string,
|
||||||
|
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
||||||
|
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, userID, version)
|
||||||
|
return
|
||||||
|
// })
|
||||||
|
// return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetBackupKeys(
|
||||||
|
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
||||||
|
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
||||||
|
// err = d.writer.Do(d, nil, func(txn *sql.Tx) error {
|
||||||
|
if filterSessionID != "" {
|
||||||
|
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, userID, version, filterRoomID, filterSessionID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if filterRoomID != "" {
|
||||||
|
result, err = d.keyBackups.selectKeysByRoomID(ctx, userID, version, filterRoomID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err = d.keyBackups.selectKeys(ctx, userID, version)
|
||||||
|
return
|
||||||
|
// })
|
||||||
|
// return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) CountBackupKeys(
|
||||||
|
ctx context.Context, version, userID string,
|
||||||
|
) (count int64, err error) {
|
||||||
|
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
count, err = d.keyBackups.countKeys(ctx, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
// })
|
||||||
|
// return
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:nakedret
|
||||||
|
func (d *Database) UpsertBackupKeys(
|
||||||
|
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
||||||
|
) (count int64, etag string, err error) {
|
||||||
|
// wrap the following logic in a txn to ensure we atomically upload keys
|
||||||
|
// err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if deleted {
|
||||||
|
err = fmt.Errorf("backup was deleted")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// pull out all keys for this (user_id, version)
|
||||||
|
existingKeys, err := d.keyBackups.selectKeys(ctx, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
changed := false
|
||||||
|
// loop over all the new keys (which should be smaller than the set of backed up keys)
|
||||||
|
for _, newKey := range uploads {
|
||||||
|
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
|
||||||
|
existingRoom := existingKeys[newKey.RoomID]
|
||||||
|
if existingRoom != nil {
|
||||||
|
existingSession, ok := existingRoom[newKey.SessionID]
|
||||||
|
if ok {
|
||||||
|
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
||||||
|
err = d.keyBackups.updateBackupKey(ctx, userID, version, newKey)
|
||||||
|
changed = true
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if we shouldn't replace the key we do nothing with it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if we're here, either the room or session are new, either way, we insert
|
||||||
|
err = d.keyBackups.insertBackupKey(ctx, userID, version, newKey)
|
||||||
|
changed = true
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err = d.keyBackups.countKeys(ctx, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if changed {
|
||||||
|
// update the etag
|
||||||
|
var newETag string
|
||||||
|
if oldETag == "" {
|
||||||
|
newETag = "1"
|
||||||
|
} else {
|
||||||
|
oldETagInt, err1 := strconv.ParseInt(oldETag, 10, 64)
|
||||||
|
if err1 != nil {
|
||||||
|
err = fmt.Errorf("failed to parse old etag: %s", err1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newETag = strconv.FormatInt(oldETagInt+1, 10)
|
||||||
|
}
|
||||||
|
etag = newETag
|
||||||
|
err = d.keyBackupVersions.updateKeyBackupETag(ctx, userID, version, newETag)
|
||||||
|
} else {
|
||||||
|
etag = oldETag
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
// })
|
||||||
|
// return
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue