- Update Config for signing_key_server to use CosmosDB (#10)

- Implement keydb to use Cosmos
This commit is contained in:
alexfca 2021-05-31 13:19:48 +10:00 committed by GitHub
parent db08aa6250
commit e763a6feb9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 207 additions and 108 deletions

View file

@ -302,7 +302,7 @@ signing_key_server:
listen: http://localhost:7780 listen: http://localhost:7780
connect: http://localhost:7780 connect: http://localhost:7780
database: database:
connection_string: file:signingkeyserver.db connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -18,9 +18,12 @@ package cosmosdb
import ( import (
"context" "context"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -30,8 +33,11 @@ import (
// A Database implements gomatrixserverlib.KeyDatabase and is used to store // A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers. // the public keys for other matrix servers.
type Database struct { type Database struct {
writer sqlutil.Writer writer cosmosdbutil.Writer
statements serverKeyStatements statements serverKeyStatements
connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
} }
// NewDatabase prepares a new key database. // NewDatabase prepares a new key database.
@ -44,14 +50,16 @@ func NewDatabase(
serverKey ed25519.PublicKey, serverKey ed25519.PublicKey,
serverKeyID gomatrixserverlib.KeyID, serverKeyID gomatrixserverlib.KeyID,
) (*Database, error) { ) (*Database, error) {
db, err := sqlutil.Open(dbProperties) conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
if err != nil { configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
return nil, err
}
d := &Database{ d := &Database{
writer: sqlutil.NewExclusiveWriter(), databaseName: "keydb",
writer: cosmosdbutil.NewExclusiveWriterFake(),
connection: conn,
cosmosConfig: configCosmos,
} }
err = d.statements.prepare(db, d.writer) err := d.statements.prepare(d, d.writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,7 +71,7 @@ func NewDatabase(
// FetcherName implements KeyFetcher // FetcherName implements KeyFetcher
func (d Database) FetcherName() string { func (d Database) FetcherName() string {
return "SqliteKeyDatabase" return "CosmosDBKeyDatabase"
} }
// FetchKeys implements gomatrixserverlib.KeyDatabase // FetchKeys implements gomatrixserverlib.KeyDatabase

View file

@ -19,67 +19,104 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const serverKeysSchema = ` // const serverKeysSchema = `
-- A cache of signing keys downloaded from remote servers. // -- A cache of signing keys downloaded from remote servers.
CREATE TABLE IF NOT EXISTS keydb_server_keys ( // CREATE TABLE IF NOT EXISTS keydb_server_keys (
-- The name of the matrix server the key is for. // -- The name of the matrix server the key is for.
server_name TEXT NOT NULL, // server_name TEXT NOT NULL,
-- The ID of the server key. // -- The ID of the server key.
server_key_id TEXT NOT NULL, // server_key_id TEXT NOT NULL,
-- Combined server name and key ID separated by the ASCII unit separator // -- Combined server name and key ID separated by the ASCII unit separator
-- to make it easier to run bulk queries. // -- to make it easier to run bulk queries.
server_name_and_key_id TEXT NOT NULL, // server_name_and_key_id TEXT NOT NULL,
-- When the key is valid until as a millisecond timestamp. // -- When the key is valid until as a millisecond timestamp.
-- 0 if this is an expired key (in which case expired_ts will be non-zero) // -- 0 if this is an expired key (in which case expired_ts will be non-zero)
valid_until_ts BIGINT NOT NULL, // valid_until_ts BIGINT NOT NULL,
-- When the key expired as a millisecond timestamp. // -- When the key expired as a millisecond timestamp.
-- 0 if this is an active key (in which case valid_until_ts will be non-zero) // -- 0 if this is an active key (in which case valid_until_ts will be non-zero)
expired_ts BIGINT NOT NULL, // expired_ts BIGINT NOT NULL,
-- The base64-encoded public key. // -- The base64-encoded public key.
server_key TEXT NOT NULL, // server_key TEXT NOT NULL,
UNIQUE (server_name, server_key_id) // UNIQUE (server_name, server_key_id)
); // );
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id); // CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
` // `
const bulkSelectServerKeysSQL = "" + type ServerKeyCosmos struct {
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " + ServerName string `json:"server_name"`
" server_key FROM keydb_server_keys" + ServerKeyID string `json:"server_key_id"`
" WHERE server_name_and_key_id IN ($1)" ServerNameAndKeyID string `json:"server_name_and_key_id"`
ValidUntilTimestamp int64 `json:"valid_until_ts"`
const upsertServerKeysSQL = "" + ExpiredTimestamp int64 `json:"expired_ts"`
"INSERT INTO keydb_server_keys (server_name, server_key_id," + ServerKey string `json:"server_key"`
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (server_name, server_key_id)" +
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverKeyStatements struct {
db *sql.DB
writer sqlutil.Writer
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
} }
func (s *serverKeyStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { type ServerKeyCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
ServerKey ServerKeyCosmos `json:"mx_keydb_server_key"`
}
// "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
// " server_key FROM keydb_server_keys" +
// " WHERE server_name_and_key_id IN ($1)"
const bulkSelectServerKeysSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_keydb_server_key.server_name_and_key_id) "
// const upsertServerKeysSQL = "" +
// "INSERT INTO keydb_server_keys (server_name, server_key_id," +
// " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (server_name, server_key_id)" +
// " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverKeyStatements struct {
db *Database
writer cosmosdbutil.Writer
bulkSelectServerKeysStmt *sql.Stmt
// upsertServerKeysStmt *sql.Stmt
tableName string
}
func queryServerKey(s *serverKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ServerKeyCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []ServerKeyCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *serverKeyStatements) prepare(db *Database, writer sqlutil.Writer) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = writer
_, err = db.Exec(serverKeysSchema) s.tableName = "server_keys"
if err != nil {
return
}
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil {
return
}
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
return
}
return return
} }
@ -92,23 +129,45 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
} }
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests)) results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) // iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
for i, v := range nameAndKeyIDs { // for i, v := range nameAndKeyIDs {
iKeyIDs[i] = v // iKeyIDs[i] = v
// }
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": nameAndKeyIDs,
} }
err := sqlutil.RunLimitedVariablesQuery( // "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, // " server_key FROM keydb_server_keys" +
func(rows *sql.Rows) error { // " WHERE server_name_and_key_id IN ($1)"
for rows.Next() {
// err := sqlutil.RunLimitedVariablesQuery(
// ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
// func(rows *sql.Rows) error {
rows, err := queryServerKey(s, ctx, bulkSelectServerKeysSQL, params)
if err != nil {
return nil, err
}
for _, item := range rows {
var serverName string var serverName string
var keyID string var keyID string
var key string var key string
var validUntilTS int64 var validUntilTS int64
var expiredTS int64 var expiredTS int64
if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { // if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err) // return fmt.Errorf("bulkSelectServerKeys: %v", err)
} // }
serverName = item.ServerKey.ServerName
keyID = item.ServerKey.ServerKeyID
validUntilTS = item.ServerKey.ValidUntilTimestamp
expiredTS = item.ServerKey.ExpiredTimestamp
key = item.ServerKey.ServerKey
r := gomatrixserverlib.PublicKeyLookupRequest{ r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName), ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID), KeyID: gomatrixserverlib.KeyID(keyID),
@ -116,7 +175,7 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
vk := gomatrixserverlib.VerifyKey{} vk := gomatrixserverlib.VerifyKey{}
err := vk.Key.Decode(key) err := vk.Key.Decode(key)
if err != nil { if err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err) return nil, fmt.Errorf("bulkSelectServerKeys: %v", err)
} }
results[r] = gomatrixserverlib.PublicKeyLookupResult{ results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk, VerifyKey: vk,
@ -124,14 +183,8 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
} }
} }
return nil
},
)
if err != nil { return results, err
return nil, err
}
return results, nil
} }
func (s *serverKeyStatements) upsertServerKeys( func (s *serverKeyStatements) upsertServerKeys(
@ -139,19 +192,57 @@ func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyLookupRequest, request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult, key gomatrixserverlib.PublicKeyLookupResult,
) error { ) error {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt) // "INSERT INTO keydb_server_keys (server_name, server_key_id," +
_, err := stmt.ExecContext( // " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (server_name, server_key_id)" +
// " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
// stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
// UNIQUE (server_name, server_key_id)
docId := fmt.Sprintf("%s_%s", string(request.ServerName), string(request.KeyID))
cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := ServerKeyCosmos{
ServerName: string(request.ServerName),
ServerKeyID: string(request.KeyID),
ServerNameAndKeyID: nameAndKeyID(request),
ValidUntilTimestamp: int64(key.ValidUntilTS),
ExpiredTimestamp: int64(key.ExpiredTS),
ServerKey: key.Key.Encode(),
}
dbData := &ServerKeyCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
ServerKey: data,
}
// _, err := stmt.ExecContext(
// ctx,
// string(request.ServerName),
// string(request.KeyID),
// nameAndKeyID(request),
// key.ValidUntilTS,
// key.ExpiredTS,
// key.Key.Encode(),
// )
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx, ctx,
string(request.ServerName), s.db.cosmosConfig.DatabaseName,
string(request.KeyID), s.db.cosmosConfig.ContainerName,
nameAndKeyID(request), &dbData,
key.ValidUntilTS, options)
key.ExpiredTS,
key.Key.Encode(),
)
return err return err
})
} }
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string { func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {