- Update Config for signing_key_server to use CosmosDB (#10)

- Implement keydb to use Cosmos
This commit is contained in:
alexfca 2021-05-31 13:19:48 +10:00 committed by GitHub
parent db08aa6250
commit e763a6feb9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 207 additions and 108 deletions

View file

@ -302,7 +302,7 @@ signing_key_server:
listen: http://localhost:7780 listen: http://localhost:7780
connect: http://localhost:7780 connect: http://localhost:7780
database: database:
connection_string: file:signingkeyserver.db connection_string: "cosmosdb:AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;DatabaseName=safezone_local;ContainerName=test.criticalarc.com;"
max_open_conns: 10 max_open_conns: 10
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1

View file

@ -18,9 +18,12 @@ package cosmosdb
import ( import (
"context" "context"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -30,8 +33,11 @@ import (
// A Database implements gomatrixserverlib.KeyDatabase and is used to store // A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers. // the public keys for other matrix servers.
type Database struct { type Database struct {
writer sqlutil.Writer writer cosmosdbutil.Writer
statements serverKeyStatements statements serverKeyStatements
connection cosmosdbapi.CosmosConnection
databaseName string
cosmosConfig cosmosdbapi.CosmosConfig
} }
// NewDatabase prepares a new key database. // NewDatabase prepares a new key database.
@ -44,14 +50,16 @@ func NewDatabase(
serverKey ed25519.PublicKey, serverKey ed25519.PublicKey,
serverKeyID gomatrixserverlib.KeyID, serverKeyID gomatrixserverlib.KeyID,
) (*Database, error) { ) (*Database, error) {
db, err := sqlutil.Open(dbProperties) conn := cosmosdbutil.GetCosmosConnection(&dbProperties.ConnectionString)
if err != nil { configCosmos := cosmosdbutil.GetCosmosConfig(&dbProperties.ConnectionString)
return nil, err
}
d := &Database{ d := &Database{
writer: sqlutil.NewExclusiveWriter(), databaseName: "keydb",
writer: cosmosdbutil.NewExclusiveWriterFake(),
connection: conn,
cosmosConfig: configCosmos,
} }
err = d.statements.prepare(db, d.writer) err := d.statements.prepare(d, d.writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,7 +71,7 @@ func NewDatabase(
// FetcherName implements KeyFetcher // FetcherName implements KeyFetcher
func (d Database) FetcherName() string { func (d Database) FetcherName() string {
return "SqliteKeyDatabase" return "CosmosDBKeyDatabase"
} }
// FetchKeys implements gomatrixserverlib.KeyDatabase // FetchKeys implements gomatrixserverlib.KeyDatabase

View file

@ -19,67 +19,104 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/internal/cosmosdbapi"
"github.com/matrix-org/dendrite/internal/cosmosdbutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const serverKeysSchema = ` // const serverKeysSchema = `
-- A cache of signing keys downloaded from remote servers. // -- A cache of signing keys downloaded from remote servers.
CREATE TABLE IF NOT EXISTS keydb_server_keys ( // CREATE TABLE IF NOT EXISTS keydb_server_keys (
-- The name of the matrix server the key is for. // -- The name of the matrix server the key is for.
server_name TEXT NOT NULL, // server_name TEXT NOT NULL,
-- The ID of the server key. // -- The ID of the server key.
server_key_id TEXT NOT NULL, // server_key_id TEXT NOT NULL,
-- Combined server name and key ID separated by the ASCII unit separator // -- Combined server name and key ID separated by the ASCII unit separator
-- to make it easier to run bulk queries. // -- to make it easier to run bulk queries.
server_name_and_key_id TEXT NOT NULL, // server_name_and_key_id TEXT NOT NULL,
-- When the key is valid until as a millisecond timestamp. // -- When the key is valid until as a millisecond timestamp.
-- 0 if this is an expired key (in which case expired_ts will be non-zero) // -- 0 if this is an expired key (in which case expired_ts will be non-zero)
valid_until_ts BIGINT NOT NULL, // valid_until_ts BIGINT NOT NULL,
-- When the key expired as a millisecond timestamp. // -- When the key expired as a millisecond timestamp.
-- 0 if this is an active key (in which case valid_until_ts will be non-zero) // -- 0 if this is an active key (in which case valid_until_ts will be non-zero)
expired_ts BIGINT NOT NULL, // expired_ts BIGINT NOT NULL,
-- The base64-encoded public key. // -- The base64-encoded public key.
server_key TEXT NOT NULL, // server_key TEXT NOT NULL,
UNIQUE (server_name, server_key_id) // UNIQUE (server_name, server_key_id)
); // );
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id); // CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
` // `
const bulkSelectServerKeysSQL = "" + type ServerKeyCosmos struct {
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " + ServerName string `json:"server_name"`
" server_key FROM keydb_server_keys" + ServerKeyID string `json:"server_key_id"`
" WHERE server_name_and_key_id IN ($1)" ServerNameAndKeyID string `json:"server_name_and_key_id"`
ValidUntilTimestamp int64 `json:"valid_until_ts"`
const upsertServerKeysSQL = "" + ExpiredTimestamp int64 `json:"expired_ts"`
"INSERT INTO keydb_server_keys (server_name, server_key_id," + ServerKey string `json:"server_key"`
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (server_name, server_key_id)" +
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverKeyStatements struct {
db *sql.DB
writer sqlutil.Writer
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
} }
func (s *serverKeyStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { type ServerKeyCosmosData struct {
Id string `json:"id"`
Pk string `json:"_pk"`
Cn string `json:"_cn"`
ETag string `json:"_etag"`
Timestamp int64 `json:"_ts"`
ServerKey ServerKeyCosmos `json:"mx_keydb_server_key"`
}
// "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
// " server_key FROM keydb_server_keys" +
// " WHERE server_name_and_key_id IN ($1)"
const bulkSelectServerKeysSQL = "" +
"select * from c where c._cn = @x1 " +
"and ARRAY_CONTAINS(@x2, c.mx_keydb_server_key.server_name_and_key_id) "
// const upsertServerKeysSQL = "" +
// "INSERT INTO keydb_server_keys (server_name, server_key_id," +
// " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
// " VALUES ($1, $2, $3, $4, $5, $6)" +
// " ON CONFLICT (server_name, server_key_id)" +
// " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverKeyStatements struct {
db *Database
writer cosmosdbutil.Writer
bulkSelectServerKeysStmt *sql.Stmt
// upsertServerKeysStmt *sql.Stmt
tableName string
}
func queryServerKey(s *serverKeyStatements, ctx context.Context, qry string, params map[string]interface{}) ([]ServerKeyCosmosData, error) {
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
var pk = cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
var response []ServerKeyCosmosData
var optionsQry = cosmosdbapi.GetQueryDocumentsOptions(pk)
var query = cosmosdbapi.GetQuery(qry, params)
_, err := cosmosdbapi.GetClient(s.db.connection).QueryDocuments(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
query,
&response,
optionsQry)
if err != nil {
return nil, err
}
return response, nil
}
func (s *serverKeyStatements) prepare(db *Database, writer sqlutil.Writer) (err error) {
s.db = db s.db = db
s.writer = writer s.writer = writer
_, err = db.Exec(serverKeysSchema) s.tableName = "server_keys"
if err != nil {
return
}
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil {
return
}
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
return
}
return return
} }
@ -92,46 +129,62 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
} }
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests)) results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) // iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
for i, v := range nameAndKeyIDs { // for i, v := range nameAndKeyIDs {
iKeyIDs[i] = v // iKeyIDs[i] = v
// }
var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
params := map[string]interface{}{
"@x1": dbCollectionName,
"@x2": nameAndKeyIDs,
} }
err := sqlutil.RunLimitedVariablesQuery( // "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, // " server_key FROM keydb_server_keys" +
func(rows *sql.Rows) error { // " WHERE server_name_and_key_id IN ($1)"
for rows.Next() {
var serverName string // err := sqlutil.RunLimitedVariablesQuery(
var keyID string // ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
var key string // func(rows *sql.Rows) error {
var validUntilTS int64
var expiredTS int64 rows, err := queryServerKey(s, ctx, bulkSelectServerKeysSQL, params)
if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err)
}
r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID),
}
vk := gomatrixserverlib.VerifyKey{}
err := vk.Key.Decode(key)
if err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err)
}
results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk,
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
}
return nil
},
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return results, nil
for _, item := range rows {
var serverName string
var keyID string
var key string
var validUntilTS int64
var expiredTS int64
// if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
// return fmt.Errorf("bulkSelectServerKeys: %v", err)
// }
serverName = item.ServerKey.ServerName
keyID = item.ServerKey.ServerKeyID
validUntilTS = item.ServerKey.ValidUntilTimestamp
expiredTS = item.ServerKey.ExpiredTimestamp
key = item.ServerKey.ServerKey
r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID),
}
vk := gomatrixserverlib.VerifyKey{}
err := vk.Key.Decode(key)
if err != nil {
return nil, fmt.Errorf("bulkSelectServerKeys: %v", err)
}
results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk,
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
}
return results, err
} }
func (s *serverKeyStatements) upsertServerKeys( func (s *serverKeyStatements) upsertServerKeys(
@ -139,19 +192,57 @@ func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyLookupRequest, request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult, key gomatrixserverlib.PublicKeyLookupResult,
) error { ) error {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt) // "INSERT INTO keydb_server_keys (server_name, server_key_id," +
_, err := stmt.ExecContext( // " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
ctx, // " VALUES ($1, $2, $3, $4, $5, $6)" +
string(request.ServerName), // " ON CONFLICT (server_name, server_key_id)" +
string(request.KeyID), // " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
nameAndKeyID(request),
key.ValidUntilTS, // stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
key.ExpiredTS,
key.Key.Encode(), var dbCollectionName = cosmosdbapi.GetCollectionName(s.db.databaseName, s.tableName)
) // UNIQUE (server_name, server_key_id)
return err docId := fmt.Sprintf("%s_%s", string(request.ServerName), string(request.KeyID))
}) cosmosDocId := cosmosdbapi.GetDocumentId(s.db.cosmosConfig.ContainerName, dbCollectionName, docId)
pk := cosmosdbapi.GetPartitionKey(s.db.cosmosConfig.ContainerName, dbCollectionName)
data := ServerKeyCosmos{
ServerName: string(request.ServerName),
ServerKeyID: string(request.KeyID),
ServerNameAndKeyID: nameAndKeyID(request),
ValidUntilTimestamp: int64(key.ValidUntilTS),
ExpiredTimestamp: int64(key.ExpiredTS),
ServerKey: key.Key.Encode(),
}
dbData := &ServerKeyCosmosData{
Id: cosmosDocId,
Cn: dbCollectionName,
Pk: pk,
Timestamp: time.Now().Unix(),
ServerKey: data,
}
// _, err := stmt.ExecContext(
// ctx,
// string(request.ServerName),
// string(request.KeyID),
// nameAndKeyID(request),
// key.ValidUntilTS,
// key.ExpiredTS,
// key.Key.Encode(),
// )
var options = cosmosdbapi.GetUpsertDocumentOptions(dbData.Pk)
_, _, err := cosmosdbapi.GetClient(s.db.connection).CreateDocument(
ctx,
s.db.cosmosConfig.DatabaseName,
s.db.cosmosConfig.ContainerName,
&dbData,
options)
return err
} }
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string { func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {