Use new testrig for key changes tests (#2552)
* Use new testrig for tests * Log the error message
This commit is contained in:
parent
43147bd654
commit
f29cdb26f6
|
@ -1,36 +1,26 @@
|
||||||
package storage
|
package storage_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
"github.com/matrix-org/dendrite/keyserver/types"
|
"github.com/matrix-org/dendrite/keyserver/types"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ctx = context.Background()
|
var ctx = context.Background()
|
||||||
|
|
||||||
func MustCreateDatabase(t *testing.T) (Database, func()) {
|
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||||
tmpfile, err := ioutil.TempFile("", "keyserver_storage_test")
|
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
t.Fatalf("failed to create new database: %v", err)
|
||||||
}
|
|
||||||
t.Logf("Database %s", tmpfile.Name())
|
|
||||||
db, err := NewDatabase(nil, &config.DatabaseOptions{
|
|
||||||
ConnectionString: config.DataSource(fmt.Sprintf("file://%s", tmpfile.Name())),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to NewDatabase: %s", err)
|
|
||||||
}
|
|
||||||
return db, func() {
|
|
||||||
os.Remove(tmpfile.Name())
|
|
||||||
}
|
}
|
||||||
|
return db, close
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustNotError(t *testing.T, err error) {
|
func MustNotError(t *testing.T, err error) {
|
||||||
|
@ -42,151 +32,159 @@ func MustNotError(t *testing.T, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKeyChanges(t *testing.T) {
|
func TestKeyChanges(t *testing.T) {
|
||||||
db, clean := MustCreateDatabase(t)
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
defer clean()
|
db, clean := MustCreateDatabase(t, dbType)
|
||||||
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
defer clean()
|
||||||
MustNotError(t, err)
|
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
MustNotError(t, err)
|
||||||
MustNotError(t, err)
|
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||||
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
MustNotError(t, err)
|
||||||
MustNotError(t, err)
|
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
|
MustNotError(t, err)
|
||||||
if err != nil {
|
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
|
||||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
if err != nil {
|
||||||
}
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
if latest != deviceChangeIDC {
|
}
|
||||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
|
if latest != deviceChangeIDC {
|
||||||
}
|
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
|
||||||
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
}
|
||||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||||
}
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKeyChangesNoDupes(t *testing.T) {
|
func TestKeyChangesNoDupes(t *testing.T) {
|
||||||
db, clean := MustCreateDatabase(t)
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
defer clean()
|
db, clean := MustCreateDatabase(t, dbType)
|
||||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
defer clean()
|
||||||
MustNotError(t, err)
|
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
MustNotError(t, err)
|
||||||
MustNotError(t, err)
|
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
if deviceChangeIDA == deviceChangeIDB {
|
MustNotError(t, err)
|
||||||
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
|
if deviceChangeIDA == deviceChangeIDB {
|
||||||
}
|
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
|
||||||
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
}
|
||||||
MustNotError(t, err)
|
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
|
MustNotError(t, err)
|
||||||
if err != nil {
|
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
|
||||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
if err != nil {
|
||||||
}
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
if latest != deviceChangeID {
|
}
|
||||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
|
if latest != deviceChangeID {
|
||||||
}
|
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
|
||||||
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
}
|
||||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||||
}
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKeyChangesUpperLimit(t *testing.T) {
|
func TestKeyChangesUpperLimit(t *testing.T) {
|
||||||
db, clean := MustCreateDatabase(t)
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
defer clean()
|
db, clean := MustCreateDatabase(t, dbType)
|
||||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
defer clean()
|
||||||
MustNotError(t, err)
|
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
MustNotError(t, err)
|
||||||
MustNotError(t, err)
|
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||||
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
|
MustNotError(t, err)
|
||||||
MustNotError(t, err)
|
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
|
MustNotError(t, err)
|
||||||
if err != nil {
|
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
|
||||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
if err != nil {
|
||||||
}
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
if latest != deviceChangeIDB {
|
}
|
||||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
|
if latest != deviceChangeIDB {
|
||||||
}
|
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
|
||||||
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
}
|
||||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
||||||
}
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
||||||
// and that they are returned correctly when querying for device keys.
|
// and that they are returned correctly when querying for device keys.
|
||||||
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
db, clean := MustCreateDatabase(t)
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
defer clean()
|
db, clean := MustCreateDatabase(t, dbType)
|
||||||
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
defer clean()
|
||||||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||||
msgs := []api.DeviceMessage{
|
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||||
{
|
msgs := []api.DeviceMessage{
|
||||||
Type: api.TypeDeviceKeyUpdate,
|
{
|
||||||
DeviceKeys: &api.DeviceKeys{
|
Type: api.TypeDeviceKeyUpdate,
|
||||||
DeviceID: "AAA",
|
DeviceKeys: &api.DeviceKeys{
|
||||||
UserID: alice,
|
DeviceID: "AAA",
|
||||||
KeyJSON: []byte(`{"key":"v1"}`),
|
UserID: alice,
|
||||||
|
KeyJSON: []byte(`{"key":"v1"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 1
|
||||||
},
|
},
|
||||||
// StreamID: 1
|
{
|
||||||
},
|
Type: api.TypeDeviceKeyUpdate,
|
||||||
{
|
DeviceKeys: &api.DeviceKeys{
|
||||||
Type: api.TypeDeviceKeyUpdate,
|
DeviceID: "AAA",
|
||||||
DeviceKeys: &api.DeviceKeys{
|
UserID: bob,
|
||||||
DeviceID: "AAA",
|
KeyJSON: []byte(`{"key":"v1"}`),
|
||||||
UserID: bob,
|
},
|
||||||
KeyJSON: []byte(`{"key":"v1"}`),
|
// StreamID: 1 as this is a different user
|
||||||
},
|
},
|
||||||
// StreamID: 1 as this is a different user
|
{
|
||||||
},
|
Type: api.TypeDeviceKeyUpdate,
|
||||||
{
|
DeviceKeys: &api.DeviceKeys{
|
||||||
Type: api.TypeDeviceKeyUpdate,
|
DeviceID: "another_device",
|
||||||
DeviceKeys: &api.DeviceKeys{
|
UserID: alice,
|
||||||
DeviceID: "another_device",
|
KeyJSON: []byte(`{"key":"v1"}`),
|
||||||
UserID: alice,
|
},
|
||||||
KeyJSON: []byte(`{"key":"v1"}`),
|
// StreamID: 2 as this is a 2nd device key
|
||||||
},
|
},
|
||||||
// StreamID: 2 as this is a 2nd device key
|
|
||||||
},
|
|
||||||
}
|
|
||||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
|
||||||
if msgs[0].StreamID != 1 {
|
|
||||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
|
||||||
}
|
|
||||||
if msgs[1].StreamID != 1 {
|
|
||||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
|
||||||
}
|
|
||||||
if msgs[2].StreamID != 2 {
|
|
||||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updating a device sets the next stream ID for that user
|
|
||||||
msgs = []api.DeviceMessage{
|
|
||||||
{
|
|
||||||
Type: api.TypeDeviceKeyUpdate,
|
|
||||||
DeviceKeys: &api.DeviceKeys{
|
|
||||||
DeviceID: "AAA",
|
|
||||||
UserID: alice,
|
|
||||||
KeyJSON: []byte(`{"key":"v2"}`),
|
|
||||||
},
|
|
||||||
// StreamID: 3
|
|
||||||
},
|
|
||||||
}
|
|
||||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
|
||||||
if msgs[0].StreamID != 3 {
|
|
||||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Querying for device keys returns the latest stream IDs
|
|
||||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
|
||||||
}
|
|
||||||
wantStreamIDs := map[string]int64{
|
|
||||||
"AAA": 3,
|
|
||||||
"another_device": 2,
|
|
||||||
}
|
|
||||||
if len(msgs) != len(wantStreamIDs) {
|
|
||||||
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
|
||||||
}
|
|
||||||
for _, m := range msgs {
|
|
||||||
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
|
||||||
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
|
||||||
}
|
}
|
||||||
}
|
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||||
|
if msgs[0].StreamID != 1 {
|
||||||
|
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||||
|
}
|
||||||
|
if msgs[1].StreamID != 1 {
|
||||||
|
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||||
|
}
|
||||||
|
if msgs[2].StreamID != 2 {
|
||||||
|
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updating a device sets the next stream ID for that user
|
||||||
|
msgs = []api.DeviceMessage{
|
||||||
|
{
|
||||||
|
Type: api.TypeDeviceKeyUpdate,
|
||||||
|
DeviceKeys: &api.DeviceKeys{
|
||||||
|
DeviceID: "AAA",
|
||||||
|
UserID: alice,
|
||||||
|
KeyJSON: []byte(`{"key":"v2"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 3
|
||||||
|
},
|
||||||
|
}
|
||||||
|
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||||
|
if msgs[0].StreamID != 3 {
|
||||||
|
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Querying for device keys returns the latest stream IDs
|
||||||
|
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||||
|
}
|
||||||
|
wantStreamIDs := map[string]int64{
|
||||||
|
"AAA": 3,
|
||||||
|
"another_device": 2,
|
||||||
|
}
|
||||||
|
if len(msgs) != len(wantStreamIDs) {
|
||||||
|
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
||||||
|
}
|
||||||
|
for _, m := range msgs {
|
||||||
|
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
||||||
|
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue