diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index c80a82d1d..4f337a866 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -73,6 +73,26 @@ jobs: timeout-minutes: 5 name: Unit tests (Go ${{ matrix.go }}) runs-on: ubuntu-latest + # Service containers to run with `container-job` + services: + # Label used to access the service container + postgres: + # Docker Hub image + image: postgres:13-alpine + # Provide the password for postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dendrite + ports: + # Maps tcp port 5432 on service container to the host + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 strategy: fail-fast: false matrix: @@ -92,6 +112,11 @@ jobs: restore-keys: | ${{ runner.os }}-go${{ matrix.go }}-test- - run: go test ./... + env: + POSTGRES_HOST: localhost + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dendrite # build Dendrite for linux with different architectures and go versions build: diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go index 4fa966281..81c86ae38 100644 --- a/roomserver/internal/input/input_test.go +++ b/roomserver/internal/input/input_test.go @@ -2,7 +2,6 @@ package input_test import ( "context" - "fmt" "os" "testing" "time" @@ -12,30 +11,22 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" ) -func psqlConnectionString() config.DataSource { - user := os.Getenv("POSTGRES_USER") - if user == "" { - user = "dendrite" - } - dbName := os.Getenv("POSTGRES_DB") - if dbName == "" { - dbName = "dendrite" - } - connStr := fmt.Sprintf( - "user=%s dbname=%s sslmode=disable", user, dbName, - ) - password := os.Getenv("POSTGRES_PASSWORD") - if password != "" { - connStr += fmt.Sprintf(" password=%s", password) - } - host := os.Getenv("POSTGRES_HOST") - if host != "" { - connStr += fmt.Sprintf(" host=%s", host) - } - return config.DataSource(connStr) +var js nats.JetStreamContext +var jc *nats.Conn + +func TestMain(m *testing.M) { + var pc *process.ProcessContext + pc, js, jc = jetstream.PrepareForTests() + code := m.Run() + pc.ShutdownDendrite() + pc.WaitForComponentsToFinish() + os.Exit(code) } func TestSingleTransactionOnInput(t *testing.T) { @@ -63,7 +54,7 @@ func TestSingleTransactionOnInput(t *testing.T) { } db, err := storage.Open( &config.DatabaseOptions{ - ConnectionString: psqlConnectionString(), + ConnectionString: "", MaxOpenConnections: 1, MaxIdleConnections: 1, }, @@ -74,7 +65,9 @@ func TestSingleTransactionOnInput(t *testing.T) { t.SkipNow() } inputter := &input.Inputer{ - DB: db, + DB: db, + JetStream: js, + NATSClient: jc, } res := &api.InputRoomEventsResponse{} inputter.InputRoomEvents( diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 4e4fe7a29..1c8a89e8d 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -13,12 +13,22 @@ import ( "github.com/sirupsen/logrus" natsserver "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" natsclient "github.com/nats-io/nats.go" ) var natsServer *natsserver.Server var natsServerMutex sync.Mutex +func PrepareForTests() (*process.ProcessContext, nats.JetStreamContext, *nats.Conn) { + cfg := &config.Dendrite{} + cfg.Defaults(true) + cfg.Global.JetStream.InMemory = true + pc := process.NewProcessContext() + js, jc := Prepare(pc, &cfg.Global.JetStream) + return pc, js, jc +} + func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) { // check if we need an in-process NATS Server if len(cfg.Addresses) != 0 { diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 864322001..403b50eaa 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -1,121 +1,28 @@ package storage_test -// TODO: Fix these tests -/* import ( "context" - "crypto/ed25519" - "encoding/json" "fmt" - "os" "testing" - "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage" - "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" "github.com/matrix-org/dendrite/syncapi/types" - userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" ) -var ( - ctx = context.Background() - emptyStateKey = "" - testOrigin = gomatrixserverlib.ServerName("hollow.knight") - testRoomID = fmt.Sprintf("!hallownest:%s", testOrigin) - testUserIDA = fmt.Sprintf("@hornet:%s", testOrigin) - testUserIDB = fmt.Sprintf("@paleking:%s", testOrigin) - testUserDeviceA = userapi.Device{ - UserID: testUserIDA, - ID: "device_id_A", - DisplayName: "Device A", - } - testRoomVersion = gomatrixserverlib.RoomVersionV4 - testKeyID = gomatrixserverlib.KeyID("ed25519:storage_test") - testPrivateKey = ed25519.NewKeyFromSeed([]byte{ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, - }) -) +var ctx = context.Background() -func MustCreateEvent(t *testing.T, roomID string, prevs []*gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) *gomatrixserverlib.HeaderedEvent { - b.RoomID = roomID - if prevs != nil { - prevIDs := make([]string, len(prevs)) - for i := range prevs { - prevIDs[i] = prevs[i].EventID() - } - b.PrevEvents = prevIDs - } - e, err := b.Build(time.Now(), testOrigin, testKeyID, testPrivateKey, testRoomVersion) - if err != nil { - t.Fatalf("failed to build event: %s", err) - } - return e.Headered(testRoomVersion) -} - -func MustCreateDatabase(t *testing.T) storage.Database { - dbname := fmt.Sprintf("test_%s.db", t.Name()) - if _, err := os.Stat(dbname); err == nil { - if err = os.Remove(dbname); err != nil { - t.Fatalf("tried to delete stale test database but failed: %s", err) - } - } - db, err := sqlite3.NewDatabase(&config.DatabaseOptions{ - ConnectionString: config.DataSource(fmt.Sprintf("file:%s", dbname)), +func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewSyncServerDatasource(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("NewSyncServerDatasource returned %s", err) } - return db -} - -// Create a list of events which include a create event, join event and some messages. -func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []*gomatrixserverlib.HeaderedEvent, state []*gomatrixserverlib.HeaderedEvent) { - var events []*gomatrixserverlib.HeaderedEvent - events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{ - Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, userA)), - Type: "m.room.create", - StateKey: &emptyStateKey, - Sender: userA, - Depth: int64(len(events) + 1), - })) - state = append(state, events[len(events)-1]) - events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ - Content: []byte(`{"membership":"join"}`), - Type: "m.room.member", - StateKey: &userA, - Sender: userA, - Depth: int64(len(events) + 1), - })) - state = append(state, events[len(events)-1]) - for i := 0; i < 10; i++ { - events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ - Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)), - Type: "m.room.message", - Sender: userA, - Depth: int64(len(events) + 1), - })) - } - events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ - Content: []byte(`{"membership":"join"}`), - Type: "m.room.member", - StateKey: &userB, - Sender: userB, - Depth: int64(len(events) + 1), - })) - state = append(state, events[len(events)-1]) - for i := 0; i < 10; i++ { - events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ - Content: []byte(fmt.Sprintf(`{"body":"Message B %d"}`, i+1)), - Type: "m.room.message", - Sender: userB, - Depth: int64(len(events) + 1), - })) - } - - return events, state + return db, close } func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { @@ -138,111 +45,115 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver } func TestWriteEvents(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - MustWriteEvents(t, db, events) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + t.Parallel() + alice := test.NewUser() + r := test.NewRoom(t, alice) + db, close := MustCreateDatabase(t, dbType) + defer close() + MustWriteEvents(t, db, r.Events()) + }) } -// These tests assert basic functionality of the IncrementalSync and CompleteSync functions. -func TestSyncResponse(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, state := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - positions := MustWriteEvents(t, db, events) - latest, err := db.SyncPosition(ctx) - if err != nil { - t.Fatalf("failed to get SyncPosition: %s", err) - } +// These tests assert basic functionality of RecentEvents for PDUs +func TestRecentEventsPDU(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := MustCreateDatabase(t, dbType) + defer close() + alice := test.NewUser() + var filter gomatrixserverlib.RoomEventFilter + filter.Limit = 100 + r := test.NewRoom(t, alice) + r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) + events := r.Events() + positions := MustWriteEvents(t, db, events) + latest, err := db.MaxStreamPositionForPDUs(ctx) + if err != nil { + t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err) + } - testCases := []struct { - Name string - DoSync func() (*types.Response, error) - WantTimeline []*gomatrixserverlib.HeaderedEvent - WantState []*gomatrixserverlib.HeaderedEvent - }{ - // The purpose of this test is to make sure that incremental syncs are including up to the latest events. - // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. - // It makes sure the response includes the final event. - { - Name: "IncrementalSync penultimate", - DoSync: func() (*types.Response, error) { - from := types.StreamingToken{ // pretend we are at the penultimate event - PDUPosition: positions[len(positions)-2], - } - res := types.NewResponse() - return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) + testCases := []struct { + Name string + From types.StreamPosition + To types.StreamPosition + WantEvents []*gomatrixserverlib.HeaderedEvent + WantLimited bool + }{ + // The purpose of this test is to make sure that incremental syncs are including up to the latest events. + // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. + // It makes sure the response includes the final event. + { + Name: "IncrementalSync penultimate", + From: positions[len(positions)-2], // pretend we are at the penultimate event + To: latest, + WantEvents: events[len(events)-1:], + WantLimited: false, }, - WantTimeline: events[len(events)-1:], - }, - // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the - // number of returned events. This is critical for big rooms hence the test here. - { - Name: "IncrementalSync limited", - DoSync: func() (*types.Response, error) { - from := types.StreamingToken{ // pretend we are 10 events behind - PDUPosition: positions[len(positions)-11], - } - res := types.NewResponse() - // limit is set to 5 - return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) - }, - // want the last 5 events, NOT the last 10. - WantTimeline: events[len(events)-5:], - }, - // The purpose of this test is to check that CompleteSync returns all the current state as well as - // honouring the `numRecentEventsPerRoom` value - { - Name: "CompleteSync limited", - DoSync: func() (*types.Response, error) { - res := types.NewResponse() - // limit set to 5 - return db.CompleteSync(ctx, res, testUserDeviceA, 5) - }, - // want the last 5 events - WantTimeline: events[len(events)-5:], - // want all state for the room - WantState: state, - }, - // The purpose of this test is to check that CompleteSync can return everything with a high enough - // `numRecentEventsPerRoom`. - { - Name: "CompleteSync", - DoSync: func() (*types.Response, error) { - res := types.NewResponse() - return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1) - }, - WantTimeline: events, - // We want no state at all as that field in /sync is the delta between the token (beginning of time) - // and the START of the timeline. - }, - } + /* + // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the + // number of returned events. This is critical for big rooms hence the test here. + { + Name: "IncrementalSync limited", + DoSync: func() (*types.Response, error) { + from := types.StreamingToken{ // pretend we are 10 events behind + PDUPosition: positions[len(positions)-11], + } + res := types.NewResponse() + // limit is set to 5 + return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) + }, + // want the last 5 events, NOT the last 10. + WantTimeline: events[len(events)-5:], + }, + // The purpose of this test is to check that CompleteSync returns all the current state as well as + // honouring the `numRecentEventsPerRoom` value + { + Name: "CompleteSync limited", + DoSync: func() (*types.Response, error) { + res := types.NewResponse() + // limit set to 5 + return db.CompleteSync(ctx, res, testUserDeviceA, 5) + }, + // want the last 5 events + WantTimeline: events[len(events)-5:], + // want all state for the room + WantState: state, + }, + // The purpose of this test is to check that CompleteSync can return everything with a high enough + // `numRecentEventsPerRoom`. + { + Name: "CompleteSync", + DoSync: func() (*types.Response, error) { + res := types.NewResponse() + return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1) + }, + WantTimeline: events, + // We want no state at all as that field in /sync is the delta between the token (beginning of time) + // and the START of the timeline. + }, */ + } - for _, tc := range testCases { - t.Run(tc.Name, func(st *testing.T) { - res, err := tc.DoSync() - if err != nil { - st.Fatalf("failed to do sync: %s", err) - } - next := types.StreamingToken{ - PDUPosition: latest.PDUPosition, - TypingPosition: latest.TypingPosition, - ReceiptPosition: latest.ReceiptPosition, - SendToDevicePosition: latest.SendToDevicePosition, - } - if res.NextBatch.String() != next.String() { - st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) - } - roomRes, ok := res.Rooms.Join[testRoomID] - if !ok { - st.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res) - } - assertEventsEqual(st, "state for "+testRoomID, false, roomRes.State.Events, tc.WantState) - assertEventsEqual(st, "timeline for "+testRoomID, false, roomRes.Timeline.Events, tc.WantTimeline) - }) - } + for _, tc := range testCases { + t.Run(tc.Name, func(st *testing.T) { + gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{ + From: tc.From, + To: tc.To, + }, &filter, true, true) + if err != nil { + st.Fatalf("failed to do sync: %s", err) + } + if limited != tc.WantLimited { + st.Errorf("got limited=%v want %v", limited, tc.WantLimited) + } + if len(gotEvents) != len(tc.WantEvents) { + st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) + } + }) + } + }) } +/* func TestGetEventsInRangeWithPrevBatch(t *testing.T) { t.Parallel() db := MustCreateDatabase(t) diff --git a/test/db.go b/test/db.go new file mode 100644 index 000000000..9deec0a89 --- /dev/null +++ b/test/db.go @@ -0,0 +1,127 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "database/sql" + "fmt" + "os" + "os/exec" + "os/user" + "testing" +) + +type DBType int + +var DBTypeSQLite DBType = 1 +var DBTypePostgres DBType = 2 + +var Quiet = false + +func createLocalDB(dbName string) string { + if !Quiet { + fmt.Println("Note: tests require a postgres install accessible to the current user") + } + createDB := exec.Command("createdb", dbName) + if !Quiet { + createDB.Stdout = os.Stdout + createDB.Stderr = os.Stderr + } + err := createDB.Run() + if err != nil && !Quiet { + fmt.Println("createLocalDB returned error:", err) + } + return dbName +} + +func currentUser() string { + user, err := user.Current() + if err != nil { + if !Quiet { + fmt.Println("cannot get current user: ", err) + } + os.Exit(2) + } + return user.Username +} + +// Prepare a sqlite or postgres connection string for testing. +// Returns the connection string to use and a close function which must be called when the test finishes. +// Calling this function twice will return the same database, which will have data from previous tests +// unless close() is called. +// TODO: namespace for concurrent package tests +func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { + if dbType == DBTypeSQLite { + dbname := "dendrite_test.db" + return fmt.Sprintf("file:%s", dbname), func() { + err := os.Remove(dbname) + if err != nil { + t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err) + } + } + } + + // Required vars: user and db + // We'll try to infer from the local env if they are missing + user := os.Getenv("POSTGRES_USER") + if user == "" { + user = currentUser() + } + dbName := os.Getenv("POSTGRES_DB") + if dbName == "" { + dbName = createLocalDB("dendrite_test") + } + connStr = fmt.Sprintf( + "user=%s dbname=%s sslmode=disable", + user, dbName, + ) + // optional vars, used in CI + password := os.Getenv("POSTGRES_PASSWORD") + if password != "" { + connStr += fmt.Sprintf(" password=%s", password) + } + host := os.Getenv("POSTGRES_HOST") + if host != "" { + connStr += fmt.Sprintf(" host=%s", host) + } + + return connStr, func() { + // Drop all tables on the database to get a fresh instance + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatalf("failed to connect to postgres db '%s': %s", connStr, err) + } + _, err = db.Exec(`DROP SCHEMA public CASCADE; + CREATE SCHEMA public;`) + if err != nil { + t.Fatalf("failed to cleanup postgres db '%s': %s", connStr, err) + } + _ = db.Close() + } +} + +// Creates subtests with each known DBType +func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) { + dbs := map[string]DBType{ + "postgres": DBTypePostgres, + "sqlite": DBTypeSQLite, + } + for dbName, dbType := range dbs { + dbt := dbType + t.Run(dbName, func(tt *testing.T) { + testFn(tt, dbt) + }) + } +} diff --git a/test/event.go b/test/event.go new file mode 100644 index 000000000..487b09364 --- /dev/null +++ b/test/event.go @@ -0,0 +1,51 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "crypto/ed25519" + "time" + + "github.com/matrix-org/gomatrixserverlib" +) + +type eventMods struct { + originServerTS time.Time + origin gomatrixserverlib.ServerName + stateKey *string + unsigned interface{} + keyID gomatrixserverlib.KeyID + privKey ed25519.PrivateKey +} + +type eventModifier func(e *eventMods) + +func WithTimestamp(ts time.Time) eventModifier { + return func(e *eventMods) { + e.originServerTS = ts + } +} + +func WithStateKey(skey string) eventModifier { + return func(e *eventMods) { + e.stateKey = &skey + } +} + +func WithUnsigned(unsigned interface{}) eventModifier { + return func(e *eventMods) { + e.unsigned = unsigned + } +} diff --git a/test/room.go b/test/room.go new file mode 100644 index 000000000..619cb5c9a --- /dev/null +++ b/test/room.go @@ -0,0 +1,223 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "crypto/ed25519" + "encoding/json" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/gomatrixserverlib" +) + +type Preset int + +var ( + PresetNone Preset = 0 + PresetPrivateChat Preset = 1 + PresetPublicChat Preset = 2 + PresetTrustedPrivateChat Preset = 3 + + roomIDCounter = int64(0) + + testKeyID = gomatrixserverlib.KeyID("ed25519:test") + testPrivateKey = ed25519.NewKeyFromSeed([]byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + }) +) + +type Room struct { + ID string + Version gomatrixserverlib.RoomVersion + preset Preset + creator *User + + authEvents gomatrixserverlib.AuthEvents + events []*gomatrixserverlib.HeaderedEvent +} + +// Create a new test room. Automatically creates the initial create events. +func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { + t.Helper() + counter := atomic.AddInt64(&roomIDCounter, 1) + + // set defaults then let roomModifiers override + r := &Room{ + ID: fmt.Sprintf("!%d:localhost", counter), + creator: creator, + authEvents: gomatrixserverlib.NewAuthEvents(nil), + preset: PresetPublicChat, + Version: gomatrixserverlib.RoomVersionV9, + } + for _, m := range modifiers { + m(t, r) + } + r.insertCreateEvents(t) + return r +} + +func (r *Room) insertCreateEvents(t *testing.T) { + t.Helper() + var joinRule gomatrixserverlib.JoinRuleContent + var hisVis gomatrixserverlib.HistoryVisibilityContent + plContent := eventutil.InitialPowerLevelsContent(r.creator.ID) + switch r.preset { + case PresetTrustedPrivateChat: + fallthrough + case PresetPrivateChat: + joinRule.JoinRule = "invite" + hisVis.HistoryVisibility = "shared" + case PresetPublicChat: + joinRule.JoinRule = "public" + hisVis.HistoryVisibility = "shared" + } + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ + "creator": r.creator.ID, + "room_version": r.Version, + }, WithStateKey("")) + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, WithStateKey(r.creator.ID)) + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) +} + +// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. +func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent { + t.Helper() + depth := 1 + len(r.events) // depth starts at 1 + + // possible event modifiers (optional fields) + mod := &eventMods{} + for _, m := range mods { + m(mod) + } + + if mod.privKey == nil { + mod.privKey = testPrivateKey + } + if mod.keyID == "" { + mod.keyID = testKeyID + } + if mod.originServerTS.IsZero() { + mod.originServerTS = time.Now() + } + if mod.origin == "" { + mod.origin = gomatrixserverlib.ServerName("localhost") + } + + var unsigned gomatrixserverlib.RawJSON + var err error + if mod.unsigned != nil { + unsigned, err = json.Marshal(mod.unsigned) + if err != nil { + t.Fatalf("CreateEvent[%s]: failed to marshal unsigned field: %s", eventType, err) + } + } + + builder := &gomatrixserverlib.EventBuilder{ + Sender: creator.ID, + RoomID: r.ID, + Type: eventType, + StateKey: mod.stateKey, + Depth: int64(depth), + Unsigned: unsigned, + } + err = builder.SetContent(content) + if err != nil { + t.Fatalf("CreateEvent[%s]: failed to SetContent: %s", eventType, err) + } + if depth > 1 { + builder.PrevEvents = []gomatrixserverlib.EventReference{r.events[len(r.events)-1].EventReference()} + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) + if err != nil { + t.Fatalf("CreateEvent[%s]: failed to StateNeededForEventBuilder: %s", eventType, err) + } + refs, err := eventsNeeded.AuthEventReferences(&r.authEvents) + if err != nil { + t.Fatalf("CreateEvent[%s]: failed to AuthEventReferences: %s", eventType, err) + } + builder.AuthEvents = refs + ev, err := builder.Build( + mod.originServerTS, mod.origin, mod.keyID, + mod.privKey, r.Version, + ) + if err != nil { + t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err) + } + if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { + t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) + } + return ev.Headered(r.Version) +} + +// Add a new event to this room DAG. Not thread-safe. +func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) { + t.Helper() + // Add the event to the list of auth events + r.events = append(r.events, he) + if he.StateKey() != nil { + err := r.authEvents.AddEvent(he.Unwrap()) + if err != nil { + t.Fatalf("InsertEvent: failed to add event to auth events: %s", err) + } + } +} + +func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent { + return r.events +} + +func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent { + t.Helper() + he := r.CreateEvent(t, creator, eventType, content, mods...) + r.InsertEvent(t, he) + return he +} + +// All room modifiers are below + +type roomModifier func(t *testing.T, r *Room) + +func RoomPreset(p Preset) roomModifier { + return func(t *testing.T, r *Room) { + switch p { + case PresetPrivateChat: + fallthrough + case PresetPublicChat: + fallthrough + case PresetTrustedPrivateChat: + fallthrough + case PresetNone: + r.preset = p + default: + t.Errorf("invalid RoomPreset: %v", p) + } + } +} + +func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { + return func(t *testing.T, r *Room) { + r.Version = ver + } +} diff --git a/test/user.go b/test/user.go new file mode 100644 index 000000000..41a66e1c4 --- /dev/null +++ b/test/user.go @@ -0,0 +1,36 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "sync/atomic" +) + +var ( + userIDCounter = int64(0) +) + +type User struct { + ID string +} + +func NewUser() *User { + counter := atomic.AddInt64(&userIDCounter, 1) + u := &User{ + ID: fmt.Sprintf("@%d:localhost", counter), + } + return u +}