diff --git a/dendrite-config-cosmosdb.yaml b/dendrite-config-cosmosdb.yaml index 014a2ec88..70b6bdc25 100644 --- a/dendrite-config-cosmosdb.yaml +++ b/dendrite-config-cosmosdb.yaml @@ -302,7 +302,7 @@ signing_key_server: listen: http://localhost:7780 connect: http://localhost:7780 database: - connection_string: file:signingkeyserver.db + connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;" max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 diff --git a/signingkeyserver/storage/cosmosdb/keydb.go b/signingkeyserver/storage/cosmosdb/keydb.go index 0f4371bce..5ffc0e09d 100644 --- a/signingkeyserver/storage/cosmosdb/keydb.go +++ b/signingkeyserver/storage/cosmosdb/keydb.go @@ -18,9 +18,12 @@ package cosmosdb import ( "context" + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + + "github.com/matrix-org/dendrite/internal/cosmosdbutil" + "golang.org/x/crypto/ed25519" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" @@ -30,8 +33,11 @@ import ( // A Database implements gomatrixserverlib.KeyDatabase and is used to store // the public keys for other matrix servers. type Database struct { - writer sqlutil.Writer - statements serverKeyStatements + writer cosmosdbutil.Writer + statements serverKeyStatements + connection cosmosdbapi.CosmosConnection + databaseName string + cosmosConfig cosmosdbapi.CosmosConfig } // NewDatabase prepares a new key database. @@ -44,14 +50,16 @@ func NewDatabase( serverKey ed25519.PublicKey, serverKeyID gomatrixserverlib.KeyID, ) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } + conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString) + configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString) + d := &Database{ - writer: sqlutil.NewExclusiveWriter(), + databaseName: "keydb", + writer: cosmosdbutil.NewExclusiveWriterFake(), + connection: conn, + cosmosConfig: configCosmos, } - err = d.statements.prepare(db, d.writer) + err := d.statements.prepare(d, d.writer) if err != nil { return nil, err } @@ -63,7 +71,7 @@ func NewDatabase( // FetcherName implements KeyFetcher func (d Database) FetcherName() string { - return "SqliteKeyDatabase" + return "CosmosDBKeyDatabase" } // FetchKeys implements gomatrixserverlib.KeyDatabase diff --git a/signingkeyserver/storage/cosmosdb/server_key_table.go b/signingkeyserver/storage/cosmosdb/server_key_table.go index e30de0a12..79bb4560c 100644 --- a/signingkeyserver/storage/cosmosdb/server_key_table.go +++ b/signingkeyserver/storage/cosmosdb/server_key_table.go @@ -19,67 +19,104 @@ import ( "context" "database/sql" "fmt" + "time" + + "github.com/matrix-org/dendrite/internal/cosmosdbapi" + "github.com/matrix-org/dendrite/internal/cosmosdbutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) -const serverKeysSchema = ` --- A cache of signing keys downloaded from remote servers. -CREATE TABLE IF NOT EXISTS keydb_server_keys ( - -- The name of the matrix server the key is for. - server_name TEXT NOT NULL, - -- The ID of the server key. - server_key_id TEXT NOT NULL, - -- Combined server name and key ID separated by the ASCII unit separator - -- to make it easier to run bulk queries. - server_name_and_key_id TEXT NOT NULL, - -- When the key is valid until as a millisecond timestamp. - -- 0 if this is an expired key (in which case expired_ts will be non-zero) - valid_until_ts BIGINT NOT NULL, - -- When the key expired as a millisecond timestamp. - -- 0 if this is an active key (in which case valid_until_ts will be non-zero) - expired_ts BIGINT NOT NULL, - -- The base64-encoded public key. - server_key TEXT NOT NULL, - UNIQUE (server_name, server_key_id) -); +// const serverKeysSchema = ` +// -- A cache of signing keys downloaded from remote servers. +// CREATE TABLE IF NOT EXISTS keydb_server_keys ( +// -- The name of the matrix server the key is for. +// server_name TEXT NOT NULL, +// -- The ID of the server key. +// server_key_id TEXT NOT NULL, +// -- Combined server name and key ID separated by the ASCII unit separator +// -- to make it easier to run bulk queries. +// server_name_and_key_id TEXT NOT NULL, +// -- When the key is valid until as a millisecond timestamp. +// -- 0 if this is an expired key (in which case expired_ts will be non-zero) +// valid_until_ts BIGINT NOT NULL, +// -- When the key expired as a millisecond timestamp. +// -- 0 if this is an active key (in which case valid_until_ts will be non-zero) +// expired_ts BIGINT NOT NULL, +// -- The base64-encoded public key. +// server_key TEXT NOT NULL, +// UNIQUE (server_name, server_key_id) +// ); -CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id); -` +// CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id); +// ` -const bulkSelectServerKeysSQL = "" + - "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " + - " server_key FROM keydb_server_keys" + - " WHERE server_name_and_key_id IN ($1)" - -const upsertServerKeysSQL = "" + - "INSERT INTO keydb_server_keys (server_name, server_key_id," + - " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT (server_name, server_key_id)" + - " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6" - -type serverKeyStatements struct { - db *sql.DB - writer sqlutil.Writer - bulkSelectServerKeysStmt *sql.Stmt - upsertServerKeysStmt *sql.Stmt +type ServerKeyCosmos struct { + ServerName string `json:"server_name"` + ServerKeyID string `json:"server_key_id"` + ServerNameAndKeyID string `json:"server_name_and_key_id"` + ValidUntilTimestamp int64 `json:"valid_until_ts"` + ExpiredTimestamp int64 `json:"expired_ts"` + ServerKey string `json:"server_key"` } -func (s *serverKeyStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { +type ServerKeyCosmosData struct { + Id string `json:"id"` + Pk string `json:"_pk"` + Cn string `json:"_cn"` + ETag string `json:"_etag"` + Timestamp int64 `json:"_ts"` + ServerKey ServerKeyCosmos `json:"mx_keydb_server_key"` +} + +// "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " + +// " server_key FROM keydb_server_keys" + +// " WHERE server_name_and_key_id IN ($1)" +const bulkSelectServerKeysSQL = "" + + "select * from c where c._cn = @x1 " + + "and ARRAY_CONTAINS(@x2, c.mx_keydb_server_key.server_name_and_key_id) " + +// const upsertServerKeysSQL = "" + +// "INSERT INTO keydb_server_keys (server_name, server_key_id," + +// " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" + +// " VALUES ($1, $2, $3, $4, $5, $6)" + +// " ON CONFLICT (server_name, server_key_id)" + +// " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6" + +type serverKeyStatements struct { + db *Database + writer cosmosdbutil.Writer + bulkSelectServerKeysStmt *sql.Stmt + // upsertServerKeysStmt *sql.Stmt + tableName string +} + +func queryServerKey(s *serverKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ServerKeyCosmosData, error) { + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + var response []ServerKeyCosmosData + + var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk) + var query = cosmosdbapi.GetQuery(qry, params) + _, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + query, + &response, + optionsQry) + + if err != nil { + return nil, err + } + return response, nil +} + +func (s *serverKeyStatements) prepare(db *Database, writer sqlutil.Writer) (err error) { s.db = db s.writer = writer - _, err = db.Exec(serverKeysSchema) - if err != nil { - return - } - if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil { - return - } - if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil { - return - } + s.tableName = "server_keys" return } @@ -92,46 +129,62 @@ func (s *serverKeyStatements) bulkSelectServerKeys( nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) } results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests)) - iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) - for i, v := range nameAndKeyIDs { - iKeyIDs[i] = v + // iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) + // for i, v := range nameAndKeyIDs { + // iKeyIDs[i] = v + // } + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + params := map[string]interface{}{ + "@x1": dbCollectionName, + "@x2": nameAndKeyIDs, } - err := sqlutil.RunLimitedVariablesQuery( - ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, - func(rows *sql.Rows) error { - for rows.Next() { - var serverName string - var keyID string - var key string - var validUntilTS int64 - var expiredTS int64 - if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { - return fmt.Errorf("bulkSelectServerKeys: %v", err) - } - r := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: gomatrixserverlib.ServerName(serverName), - KeyID: gomatrixserverlib.KeyID(keyID), - } - vk := gomatrixserverlib.VerifyKey{} - err := vk.Key.Decode(key) - if err != nil { - return fmt.Errorf("bulkSelectServerKeys: %v", err) - } - results[r] = gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: vk, - ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), - ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), - } - } - return nil - }, - ) + // "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " + + // " server_key FROM keydb_server_keys" + + // " WHERE server_name_and_key_id IN ($1)" + + // err := sqlutil.RunLimitedVariablesQuery( + // ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, + // func(rows *sql.Rows) error { + + rows, err := queryServerKey(s, ctx, bulkSelectServerKeysSQL, params) if err != nil { return nil, err } - return results, nil + + for _, item := range rows { + var serverName string + var keyID string + var key string + var validUntilTS int64 + var expiredTS int64 + // if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { + // return fmt.Errorf("bulkSelectServerKeys: %v", err) + // } + serverName = item.ServerKey.ServerName + keyID = item.ServerKey.ServerKeyID + validUntilTS = item.ServerKey.ValidUntilTimestamp + expiredTS = item.ServerKey.ExpiredTimestamp + key = item.ServerKey.ServerKey + r := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: gomatrixserverlib.ServerName(serverName), + KeyID: gomatrixserverlib.KeyID(keyID), + } + vk := gomatrixserverlib.VerifyKey{} + err := vk.Key.Decode(key) + if err != nil { + return nil, fmt.Errorf("bulkSelectServerKeys: %v", err) + } + results[r] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: vk, + ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), + ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + } + } + + return results, err } func (s *serverKeyStatements) upsertServerKeys( @@ -139,19 +192,57 @@ func (s *serverKeyStatements) upsertServerKeys( request gomatrixserverlib.PublicKeyLookupRequest, key gomatrixserverlib.PublicKeyLookupResult, ) error { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt) - _, err := stmt.ExecContext( - ctx, - string(request.ServerName), - string(request.KeyID), - nameAndKeyID(request), - key.ValidUntilTS, - key.ExpiredTS, - key.Key.Encode(), - ) - return err - }) + + // "INSERT INTO keydb_server_keys (server_name, server_key_id," + + // " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" + + // " VALUES ($1, $2, $3, $4, $5, $6)" + + // " ON CONFLICT (server_name, server_key_id)" + + // " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6" + + // stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt) + + var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName) + // UNIQUE (server_name, server_key_id) + docId := fmt.Sprintf("%s_%s", string(request.ServerName), string(request.KeyID)) + cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId) + pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName) + + data := ServerKeyCosmos{ + ServerName: string(request.ServerName), + ServerKeyID: string(request.KeyID), + ServerNameAndKeyID: nameAndKeyID(request), + ValidUntilTimestamp: int64(key.ValidUntilTS), + ExpiredTimestamp: int64(key.ExpiredTS), + ServerKey: key.Key.Encode(), + } + + dbData := &ServerKeyCosmosData{ + Id: cosmosDocId, + Cn: dbCollectionName, + Pk: pk, + Timestamp: time.Now().Unix(), + ServerKey: data, + } + + // _, err := stmt.ExecContext( + // ctx, + // string(request.ServerName), + // string(request.KeyID), + // nameAndKeyID(request), + // key.ValidUntilTS, + // key.ExpiredTS, + // key.Key.Encode(), + // ) + + var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk) + _, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument( + ctx, + s.db.cosmosConfig.DatabaseName, + s.db.cosmosConfig.ContainerName, + &dbData, + options) + + return err } func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {