Store keys rather than json in the keydatabase

Rather than storing the raw JSON returned from a /keys/v1/query call in the
table, store the key itself.

This makes keydb.Database implement the updated KeyDatabase interface.
This commit is contained in:
Richard van der Hoff 2017-11-10 22:03:11 +00:00
parent 9a9724bcb9
commit 9e98cb3740
3 changed files with 36 additions and 29 deletions

View file

@ -48,14 +48,14 @@ func NewDatabase(dataSourceName string) (*Database, error) {
func (d *Database) FetchKeys( func (d *Database) FetchKeys(
ctx context.Context, ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) { ) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
return d.statements.bulkSelectServerKeys(ctx, requests) return d.statements.bulkSelectServerKeys(ctx, requests)
} }
// StoreKeys implements gomatrixserverlib.KeyDatabase // StoreKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) StoreKeys( func (d *Database) StoreKeys(
ctx context.Context, ctx context.Context,
keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, keyMap map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult,
) error { ) error {
// TODO: Inserting all the keys within a single transaction may // TODO: Inserting all the keys within a single transaction may
// be more efficient since the transaction overhead can be quite // be more efficient since the transaction overhead can be quite

View file

@ -17,14 +17,13 @@ package keydb
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
const serverKeysSchema = ` const serverKeysSchema = `
-- A cache of server keys downloaded from remote servers. -- A cache of signing keys downloaded from remote servers.
CREATE TABLE IF NOT EXISTS keydb_server_keys ( CREATE TABLE IF NOT EXISTS keydb_server_keys (
-- The name of the matrix server the key is for. -- The name of the matrix server the key is for.
server_name TEXT NOT NULL, server_name TEXT NOT NULL,
@ -33,10 +32,14 @@ CREATE TABLE IF NOT EXISTS keydb_server_keys (
-- Combined server name and key ID separated by the ASCII unit separator -- Combined server name and key ID separated by the ASCII unit separator
-- to make it easier to run bulk queries. -- to make it easier to run bulk queries.
server_name_and_key_id TEXT NOT NULL, server_name_and_key_id TEXT NOT NULL,
-- When the keys are valid until as a millisecond timestamp. -- When the key is valid until as a millisecond timestamp.
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
valid_until_ts BIGINT NOT NULL, valid_until_ts BIGINT NOT NULL,
-- The raw JSON for the server key. -- When the key expired as a millisecond timestamp.
server_key_json TEXT NOT NULL, -- 0 if this is an active key (in which case valid_until_ts will be non-zero)
expired_ts BIGINT NOT NULL,
-- The base64-encoded public key.
server_key TEXT NOT NULL,
CONSTRAINT keydb_server_keys_unique UNIQUE (server_name, server_key_id) CONSTRAINT keydb_server_keys_unique UNIQUE (server_name, server_key_id)
); );
@ -44,15 +47,16 @@ CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (se
` `
const bulkSelectServerKeysSQL = "" + const bulkSelectServerKeysSQL = "" +
"SELECT server_name, server_key_id, server_key_json FROM keydb_server_keys" + "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
" server_key FROM keydb_server_keys" +
" WHERE server_name_and_key_id = ANY($1)" " WHERE server_name_and_key_id = ANY($1)"
const upsertServerKeysSQL = "" + const upsertServerKeysSQL = "" +
"INSERT INTO keydb_server_keys (server_name, server_key_id," + "INSERT INTO keydb_server_keys (server_name, server_key_id," +
" server_name_and_key_id, valid_until_ts, server_key_json)" + " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT keydb_server_keys_unique" + " ON CONFLICT ON CONSTRAINT keydb_server_keys_unique" +
" DO UPDATE SET valid_until_ts = $4, server_key_json = $5" " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverKeyStatements struct { type serverKeyStatements struct {
bulkSelectServerKeysStmt *sql.Stmt bulkSelectServerKeysStmt *sql.Stmt
@ -76,7 +80,7 @@ func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
func (s *serverKeyStatements) bulkSelectServerKeys( func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context, ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) { ) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
var nameAndKeyIDs []string var nameAndKeyIDs []string
for request := range requests { for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
@ -87,23 +91,30 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
return nil, err return nil, err
} }
defer rows.Close() // nolint: errcheck defer rows.Close() // nolint: errcheck
results := map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys{} results := map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() { for rows.Next() {
var serverName string var serverName string
var keyID string var keyID string
var keyJSON []byte var key string
if err := rows.Scan(&serverName, &keyID, &keyJSON); err != nil { var validUntilTS int64
return nil, err var expiredTS int64
} if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
var serverKeys gomatrixserverlib.ServerKeys
if err := json.Unmarshal(keyJSON, &serverKeys); err != nil {
return nil, err return nil, err
} }
r := gomatrixserverlib.PublicKeyRequest{ r := gomatrixserverlib.PublicKeyRequest{
ServerName: gomatrixserverlib.ServerName(serverName), ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID), KeyID: gomatrixserverlib.KeyID(keyID),
} }
results[r] = serverKeys vk := gomatrixserverlib.VerifyKey{}
err = vk.Key.Decode(key)
if err != nil {
return nil, err
}
results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk,
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
} }
return results, nil return results, nil
} }
@ -111,19 +122,16 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
func (s *serverKeyStatements) upsertServerKeys( func (s *serverKeyStatements) upsertServerKeys(
ctx context.Context, ctx context.Context,
request gomatrixserverlib.PublicKeyRequest, request gomatrixserverlib.PublicKeyRequest,
keys gomatrixserverlib.ServerKeys, key gomatrixserverlib.PublicKeyLookupResult,
) error { ) error {
keyJSON, err := json.Marshal(keys) _, err := s.upsertServerKeysStmt.ExecContext(
if err != nil {
return err
}
_, err = s.upsertServerKeysStmt.ExecContext(
ctx, ctx,
string(request.ServerName), string(request.ServerName),
string(request.KeyID), string(request.KeyID),
nameAndKeyID(request), nameAndKeyID(request),
int64(keys.ValidUntilTS), key.ValidUntilTS,
keyJSON, key.ExpiredTS,
key.Key.Encode(),
) )
return err return err
} }

View file

@ -38,7 +38,6 @@ func localKeys(cfg config.Dendrite, validUntil time.Time) (*gomatrixserverlib.Se
var keys gomatrixserverlib.ServerKeys var keys gomatrixserverlib.ServerKeys
keys.ServerName = cfg.Matrix.ServerName keys.ServerName = cfg.Matrix.ServerName
keys.FromServer = cfg.Matrix.ServerName
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)