Add keydb_server_keys table tests (#3270)

Also moves some of the variable declarations out of the loop to,
hopefully, reduce allocations.
This commit is contained in:
Till 2023-11-22 13:05:24 +01:00 committed by GitHub
parent 06e079abac
commit 210123bab5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 129 additions and 12 deletions

View file

@ -94,12 +94,14 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys(
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() {
var serverName string var serverName string
var keyID string var keyID string
var key string var key string
var validUntilTS int64 var validUntilTS int64
var expiredTS int64 var expiredTS int64
var vk gomatrixserverlib.VerifyKey
for rows.Next() {
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return nil, err return nil, err
} }
@ -107,7 +109,6 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys(
ServerName: spec.ServerName(serverName), ServerName: spec.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID), KeyID: gomatrixserverlib.KeyID(keyID),
} }
vk := gomatrixserverlib.VerifyKey{}
err = vk.Key.Decode(key) err = vk.Key.Decode(key)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -98,12 +98,13 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys(
err := sqlutil.RunLimitedVariablesQuery( err := sqlutil.RunLimitedVariablesQuery(
ctx, bulkSelectServerSigningKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, ctx, bulkSelectServerSigningKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
func(rows *sql.Rows) error { func(rows *sql.Rows) error {
for rows.Next() {
var serverName string var serverName string
var keyID string var keyID string
var key string var key string
var validUntilTS int64 var validUntilTS int64
var expiredTS int64 var expiredTS int64
var vk gomatrixserverlib.VerifyKey
for rows.Next() {
if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err) return fmt.Errorf("bulkSelectServerKeys: %v", err)
} }
@ -111,7 +112,6 @@ func (s *serverSigningKeyStatements) BulkSelectServerKeys(
ServerName: spec.ServerName(serverName), ServerName: spec.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID), KeyID: gomatrixserverlib.KeyID(keyID),
} }
vk := gomatrixserverlib.VerifyKey{}
err := vk.Key.Decode(key) err := vk.Key.Decode(key)
if err != nil { if err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err) return fmt.Errorf("bulkSelectServerKeys: %v", err)

View file

@ -0,0 +1,116 @@
package tables_test
import (
"context"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
)
func mustCreateServerKeyDB(t *testing.T, dbType test.DBType) (tables.FederationServerSigningKeys, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open database: %s", err)
}
var tab tables.FederationServerSigningKeys
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresServerSigningKeysTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteServerSigningKeysTable(db)
}
if err != nil {
t.Fatalf("failed to create table: %s", err)
}
return tab, close
}
func TestServerKeysTable(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
ctx, cancel := context.WithCancel(context.Background())
tab, close := mustCreateServerKeyDB(t, dbType)
t.Cleanup(func() {
close()
cancel()
})
req := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: "localhost",
KeyID: "ed25519:test",
}
expectedTimestamp := spec.AsTimestamp(time.Now().Add(time.Hour))
res := gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: gomatrixserverlib.VerifyKey{Key: make(spec.Base64Bytes, 0)},
ExpiredTS: 0,
ValidUntilTS: expectedTimestamp,
}
// Insert the key
err := tab.UpsertServerKeys(ctx, nil, req, res)
assert.NoError(t, err)
selectKeys := map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{
req: spec.AsTimestamp(time.Now()),
}
gotKeys, err := tab.BulkSelectServerKeys(ctx, nil, selectKeys)
assert.NoError(t, err)
// Now we should have a key for the req above
assert.NotNil(t, gotKeys[req])
assert.Equal(t, res, gotKeys[req])
// "Expire" the key by setting ExpireTS to a non-zero value and ValidUntilTS to 0
expectedTimestamp = spec.AsTimestamp(time.Now())
res.ExpiredTS = expectedTimestamp
res.ValidUntilTS = 0
// Update the key
err = tab.UpsertServerKeys(ctx, nil, req, res)
assert.NoError(t, err)
gotKeys, err = tab.BulkSelectServerKeys(ctx, nil, selectKeys)
assert.NoError(t, err)
// The key should be expired
assert.NotNil(t, gotKeys[req])
assert.Equal(t, res, gotKeys[req])
// Upsert a different key to validate querying multiple keys
req2 := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: "notlocalhost",
KeyID: "ed25519:test2",
}
expectedTimestamp2 := spec.AsTimestamp(time.Now().Add(time.Hour))
res2 := gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: gomatrixserverlib.VerifyKey{Key: make(spec.Base64Bytes, 0)},
ExpiredTS: 0,
ValidUntilTS: expectedTimestamp2,
}
err = tab.UpsertServerKeys(ctx, nil, req2, res2)
assert.NoError(t, err)
// Select multiple keys
selectKeys[req2] = spec.AsTimestamp(time.Now())
gotKeys, err = tab.BulkSelectServerKeys(ctx, nil, selectKeys)
assert.NoError(t, err)
// We now should receive two keys, one of which is expired
assert.Equal(t, 2, len(gotKeys))
assert.Equal(t, res2, gotKeys[req2])
assert.Equal(t, res, gotKeys[req])
})
}