Add regression test

This commit is contained in:
Kegan Dougal 2022-05-16 17:15:43 +01:00
parent f34b05d7d2
commit 8a95edad75
12 changed files with 203 additions and 60 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
@ -11,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/internal"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
@ -19,11 +21,13 @@ import (
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
)
type fedRoomserverAPI struct {
rsapi.FederationRoomserverAPI
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
}
// PerformJoin will call this function
@ -34,20 +38,47 @@ func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.Input
f.inputRoomEvents(ctx, req, res)
}
// keychange consumer calls this
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
if f.queryRoomsForUser == nil {
return nil
}
return f.queryRoomsForUser(ctx, req, res)
}
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
type fedClient struct {
api.FederationClient
allowJoins []*test.Room
t *testing.T
keys map[gomatrixserverlib.ServerName]struct {
key ed25519.PrivateKey
keyID gomatrixserverlib.KeyID
}
t *testing.T
sentTxn bool
}
func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) {
fmt.Println("GetServerKeys:", matrixServer)
var keys gomatrixserverlib.ServerKeys
var keyID gomatrixserverlib.KeyID
var pkey ed25519.PrivateKey
for srv, data := range f.keys {
if srv == matrixServer {
pkey = data.key
keyID = data.keyID
break
}
}
if pkey == nil {
return keys, nil
}
keys.ServerName = matrixServer
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour))
publicKey := test.PrivateKey.Public().(ed25519.PublicKey)
publicKey := pkey.Public().(ed25519.PublicKey)
keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{
test.KeyID: {
keyID: {
Key: gomatrixserverlib.Base64Bytes(publicKey),
},
}
@ -57,7 +88,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv
}
keys.Raw, err = gomatrixserverlib.SignJSON(
string(matrixServer), test.KeyID, test.PrivateKey, toSign,
string(matrixServer), keyID, pkey, toSign,
)
if err != nil {
return keys, err
@ -102,6 +133,16 @@ func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName
return
}
func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
for _, edu := range t.EDUs {
if edu.Type == gomatrixserverlib.MDeviceListUpdate {
f.sentTxn = true
}
}
f.t.Logf("got /send")
return
}
// Regression test to make sure that /send_join is updating the destination hosts synchronously and
// isn't relying on the roomserver.
func TestFederationAPIJoinThenKeyUpdate(t *testing.T) {
@ -112,13 +153,22 @@ func TestFederationAPIJoinThenKeyUpdate(t *testing.T) {
func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
base.Cfg.FederationAPI.PreferDirectFetch = true
defer close()
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
user := test.NewUser()
room := test.NewRoom(t, user)
joiningVia := gomatrixserverlib.ServerName("example.localhost")
serverA := gomatrixserverlib.ServerName("server.a")
serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera")
serverAPrivKey := test.PrivateKeyA
creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey))
myServer := base.Cfg.Global.ServerName
myServerKeyID := base.Cfg.Global.KeyID
myServerPrivKey := base.Cfg.Global.PrivateKey
joiningUser := test.NewUser(t, test.WithSigningServer(myServer, myServerKeyID, myServerPrivKey))
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
room := test.NewRoom(t, creator)
rsapi := &fedRoomserverAPI{
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
@ -126,29 +176,75 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
}
},
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
if req.UserID == joiningUser.ID && req.WantMembership == "join" {
res.RoomIDs = []string{room.ID}
return nil
}
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req)
},
}
fsapi := federationapi.NewInternalAPI(base, &fedClient{
fc := &fedClient{
allowJoins: []*test.Room{room},
t: t,
}, rsapi, base.Caches, nil, false)
keys: map[gomatrixserverlib.ServerName]struct {
key ed25519.PrivateKey
keyID gomatrixserverlib.KeyID
}{
serverA: {
key: serverAPrivKey,
keyID: serverAKeyID,
},
myServer: {
key: myServerPrivKey,
keyID: myServerKeyID,
},
},
}
fsapi := federationapi.NewInternalAPI(base, fc, rsapi, base.Caches, nil, false)
var resp api.PerformJoinResponse
fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{
RoomID: room.ID,
UserID: user.ID,
ServerNames: []gomatrixserverlib.ServerName{joiningVia},
UserID: joiningUser.ID,
ServerNames: []gomatrixserverlib.ServerName{serverA},
}, &resp)
if resp.JoinedVia != joiningVia {
t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, joiningVia)
if resp.JoinedVia != serverA {
t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA)
}
if resp.LastError != nil {
t.Fatalf("PerformJoin: returned error: %+v", *resp.LastError)
}
// TODO:
// Inject a keyserver key change event and ensure we try to send it out
// Inject a keyserver key change event and ensure we try to send it out. If we don't, then the
// federationapi is incorrectly waiting for an output room event to arrive to update the joined
// hosts table.
key := keyapi.DeviceMessage{
Type: keyapi.TypeDeviceKeyUpdate,
DeviceKeys: &keyapi.DeviceKeys{
UserID: joiningUser.ID,
DeviceID: "MY_DEVICE",
DisplayName: "BLARGLE",
KeyJSON: []byte(`{}`),
},
}
b, err := json.Marshal(key)
if err != nil {
t.Fatalf("Failed to marshal device message: %s", err)
}
msg := &nats.Msg{
Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
Header: nats.Header{},
Data: b,
}
msg.Header.Set(jetstream.UserID, key.UserID)
testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(500 * time.Millisecond)
if !fc.sentTxn {
t.Fatalf("did not send device list update")
}
}
// Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404.

View file

@ -39,7 +39,7 @@ func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, fun
}
func Test_EventsTable(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {

View file

@ -38,7 +38,7 @@ func mustCreatePreviousEventsTable(t *testing.T, dbType test.DBType) (tab tables
func TestPreviousEventsTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreatePreviousEventsTable(t, dbType)

View file

@ -38,7 +38,7 @@ func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Publ
func TestPublishedTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreatePublishedTable(t, dbType)

View file

@ -47,7 +47,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
alice := test.NewUser()
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType)
defer close()
@ -60,7 +60,7 @@ func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
alice := test.NewUser(t)
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
@ -163,7 +163,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
alice := test.NewUser(t)
r := test.NewRoom(t, alice)
for i := 0; i < 10; i++ {
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})

View file

@ -45,7 +45,7 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events,
func TestOutputRoomEventsTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newOutputRoomEventsTable(t, dbType)

View file

@ -40,7 +40,7 @@ func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.D
func TestTopologyTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newTopologyTable(t, dbType)

View file

@ -87,7 +87,7 @@ func TestSyncAPIAccessTokens(t *testing.T) {
}
func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
user := test.NewUser()
user := test.NewUser(t)
room := test.NewRoom(t, user)
alice := userapi.Device{
ID: "ALICEID",
@ -174,7 +174,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) {
}
func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
user := test.NewUser()
user := test.NewUser(t)
room := test.NewRoom(t, user)
alice := userapi.Device{
ID: "ALICEID",

View file

@ -64,6 +64,12 @@ func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier {
}
}
func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier {
return func(e *eventMods) {
e.origin = origin
}
}
// Reverse a list of events
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))

View file

@ -15,7 +15,6 @@
package test
import (
"crypto/ed25519"
"encoding/json"
"fmt"
"sync/atomic"
@ -35,12 +34,6 @@ var (
PresetTrustedPrivateChat Preset = 3
roomIDCounter = int64(0)
KeyID = gomatrixserverlib.KeyID("ed25519:test")
PrivateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
)
type Room struct {
@ -58,10 +51,11 @@ type Room struct {
func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
t.Helper()
counter := atomic.AddInt64(&roomIDCounter, 1)
// set defaults then let roomModifiers override
if creator.srvName == "" {
t.Fatalf("NewRoom: creator doesn't belong to a server: %+v", *creator)
}
r := &Room{
ID: fmt.Sprintf("!%d:localhost", counter),
ID: fmt.Sprintf("!%d:%s", counter, creator.srvName),
creator: creator,
authEvents: gomatrixserverlib.NewAuthEvents(nil),
preset: PresetPublicChat,
@ -108,16 +102,21 @@ func (r *Room) insertCreateEvents(t *testing.T) {
joinRule.JoinRule = "public"
hisVis.HistoryVisibility = "shared"
}
var mods []eventModifier
if r.creator.keyID != "" && r.creator.privKey != nil {
mods = append(mods, WithKeyID(r.creator.keyID), WithPrivateKey(r.creator.privKey), WithOrigin(r.creator.srvName))
}
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
"creator": r.creator.ID,
"room_version": r.Version,
}, WithStateKey(""))
}, append(mods, WithStateKey(""))...)
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomMember, map[string]interface{}{
"membership": "join",
}, WithStateKey(r.creator.ID))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
}, append(mods, WithStateKey(r.creator.ID))...)
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, append(mods, WithStateKey(""))...)
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, append(mods, WithStateKey(""))...)
r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, append(mods, WithStateKey(""))...)
}
// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
@ -132,16 +131,16 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten
}
if mod.privKey == nil {
mod.privKey = PrivateKey
t.Fatalf("CreateEvent[%s]: missing private key", eventType)
}
if mod.keyID == "" {
mod.keyID = KeyID
t.Fatalf("CreateEvent[%s]: missing key ID", eventType)
}
if mod.originServerTS.IsZero() {
mod.originServerTS = time.Now()
}
if mod.origin == "" {
mod.origin = gomatrixserverlib.ServerName("localhost")
t.Fatalf("CreateEvent[%s]: missing origin", eventType)
}
var unsigned gomatrixserverlib.RawJSON

View file

@ -15,22 +15,64 @@
package test
import (
"crypto/ed25519"
"fmt"
"sync/atomic"
"testing"
"github.com/matrix-org/gomatrixserverlib"
)
var (
userIDCounter = int64(0)
serverName = gomatrixserverlib.ServerName("test")
keyID = gomatrixserverlib.KeyID("ed25519:test")
privateKey = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
})
// private keys that tests can use
PrivateKeyA = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 77,
})
PrivateKeyB = ed25519.NewKeyFromSeed([]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 66,
})
)
type User struct {
ID string
// key ID and private key of the server who has this user, if known.
keyID gomatrixserverlib.KeyID
privKey ed25519.PrivateKey
srvName gomatrixserverlib.ServerName
}
func NewUser() *User {
counter := atomic.AddInt64(&userIDCounter, 1)
u := &User{
ID: fmt.Sprintf("@%d:localhost", counter),
type UserOpt func(*User)
func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt {
return func(u *User) {
u.keyID = keyID
u.privKey = privKey
u.srvName = srvName
}
return u
}
func NewUser(t *testing.T, opts ...UserOpt) *User {
counter := atomic.AddInt64(&userIDCounter, 1)
var u User
for _, opt := range opts {
opt(&u)
}
if u.keyID == "" || u.srvName == "" || u.privKey == nil {
t.Logf("NewUser: missing signing server credentials; using default.")
WithSigningServer(serverName, keyID, privateKey)(&u)
}
u.ID = fmt.Sprintf("@%d:%s", counter, u.srvName)
t.Logf("NewUser: created user %s", u.ID)
return &u
}

View file

@ -43,7 +43,7 @@ func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -74,7 +74,7 @@ func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -128,7 +128,7 @@ func Test_Accounts(t *testing.T) {
}
func Test_Devices(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
deviceID := util.RandomString(8)
@ -212,7 +212,7 @@ func Test_Devices(t *testing.T) {
}
func Test_KeyBackup(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -291,7 +291,7 @@ func Test_KeyBackup(t *testing.T) {
}
func Test_LoginToken(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
@ -321,7 +321,7 @@ func Test_LoginToken(t *testing.T) {
}
func Test_OpenID(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -341,7 +341,7 @@ func Test_OpenID(t *testing.T) {
}
func Test_Profile(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -379,7 +379,7 @@ func Test_Profile(t *testing.T) {
}
func Test_Pusher(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -430,7 +430,7 @@ func Test_Pusher(t *testing.T) {
}
func Test_ThreePID(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
@ -467,7 +467,7 @@ func Test_ThreePID(t *testing.T) {
}
func Test_Notification(t *testing.T) {
alice := test.NewUser()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
room := test.NewRoom(t, alice)