From 05607d6b8734738bd5c32288e3d0ef8e827d11d0 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 16 May 2022 19:33:16 +0200 Subject: [PATCH 1/7] Add roomserver tests (3/4) (#2447) * Add Room Aliases tests * Add Rooms table test * Move StateKeyTuplerSorter to the types package * Add StateBlock tests Some optimizations * Add State Snapshot tests Some optimization * Return []int64 and convert to pq.Int64Array for postgres * Move []types.EventNID back to rows.Next() * Update tests, rename SelectRoomIDs --- roomserver/storage/postgres/events_table.go | 6 +- .../storage/postgres/room_aliases_table.go | 6 +- roomserver/storage/postgres/rooms_table.go | 16 +-- .../storage/postgres/state_block_table.go | 53 ++------ .../postgres/state_block_table_test.go | 86 ------------ .../storage/postgres/state_snapshot_table.go | 10 +- roomserver/storage/postgres/storage.go | 16 +-- roomserver/storage/shared/storage.go | 2 +- roomserver/storage/sqlite3/events_table.go | 4 +- .../storage/sqlite3/room_aliases_table.go | 6 +- roomserver/storage/sqlite3/rooms_table.go | 16 +-- .../storage/sqlite3/state_block_table.go | 49 ++----- .../storage/sqlite3/state_block_table_test.go | 86 ------------ .../storage/sqlite3/state_snapshot_table.go | 10 +- roomserver/storage/sqlite3/storage.go | 16 +-- roomserver/storage/tables/interface.go | 2 +- .../storage/tables/room_aliases_table_test.go | 96 +++++++++++++ roomserver/storage/tables/rooms_table_test.go | 128 ++++++++++++++++++ .../storage/tables/state_block_table_test.go | 92 +++++++++++++ .../tables/state_snapshot_table_test.go | 86 ++++++++++++ roomserver/types/types.go | 33 +++++ roomserver/types/types_test.go | 64 +++++++++ 22 files changed, 570 insertions(+), 313 deletions(-) delete mode 100644 roomserver/storage/postgres/state_block_table_test.go delete mode 100644 roomserver/storage/sqlite3/state_block_table_test.go create mode 100644 roomserver/storage/tables/room_aliases_table_test.go create mode 100644 roomserver/storage/tables/rooms_table_test.go create mode 100644 roomserver/storage/tables/state_block_table_test.go create mode 100644 roomserver/storage/tables/state_snapshot_table_test.go diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 86d226ce7..a4d05756d 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -264,11 +264,11 @@ func (s *eventStatements) BulkSelectStateEventByNID( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) + tuples := types.StateKeyTupleSorter(stateKeyTuples) sort.Sort(tuples) - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays() stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt) - rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), pq.Int64Array(eventTypeNIDArray), pq.Int64Array(eventStateKeyNIDArray)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index d13df8e7f..a84929f61 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -61,12 +61,12 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func createRoomAliasesTable(db *sql.DB) error { +func CreateRoomAliasesTable(db *sql.DB) error { _, err := db.Exec(roomAliasesSchema) return err } -func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { +func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { s := &roomAliasesStatements{} return s, sqlutil.StatementList{ @@ -108,8 +108,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") var aliases []string + var alias string for rows.Next() { - var alias string if err = rows.Scan(&alias); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index b2685084d..24362af74 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -95,12 +95,12 @@ type roomStatements struct { bulkSelectRoomNIDsStmt *sql.Stmt } -func createRoomsTable(db *sql.DB) error { +func CreateRoomsTable(db *sql.DB) error { _, err := db.Exec(roomsSchema) return err } -func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { +func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { s := &roomStatements{} return s, sqlutil.StatementList{ @@ -117,7 +117,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { +func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) { stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) rows, err := stmt.QueryContext(ctx) if err != nil { @@ -125,8 +125,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri } defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") var roomIDs []string + var roomID string for rows.Next() { - var roomID string if err = rows.Scan(&roomID); err != nil { return nil, err } @@ -231,9 +231,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( } defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed") result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) + var roomNID types.RoomNID + var roomVersion gomatrixserverlib.RoomVersion for rows.Next() { - var roomNID types.RoomNID - var roomVersion gomatrixserverlib.RoomVersion if err = rows.Scan(&roomNID, &roomVersion); err != nil { return nil, err } @@ -254,8 +254,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") var roomIDs []string + var roomID string for rows.Next() { - var roomID string if err = rows.Scan(&roomID); err != nil { return nil, err } @@ -276,8 +276,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") var roomNIDs []types.RoomNID + var roomNID types.RoomNID for rows.Next() { - var roomNID types.RoomNID if err = rows.Scan(&roomNID); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 6f8f9e1b5..5af48f031 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -19,7 +19,6 @@ import ( "context" "database/sql" "fmt" - "sort" "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" @@ -71,12 +70,12 @@ type stateBlockStatements struct { bulkSelectStateBlockEntriesStmt *sql.Stmt } -func createStateBlockTable(db *sql.DB) error { +func CreateStateBlockTable(db *sql.DB) error { _, err := db.Exec(stateDataSchema) return err } -func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{} return s, sqlutil.StatementList{ @@ -90,9 +89,9 @@ func (s *stateBlockStatements) BulkInsertStateData( entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] - var nids types.EventNIDs - for _, e := range entries { - nids = append(nids, e.EventNID) + nids := make(types.EventNIDs, entries.Len()) + for i := range entries { + nids[i] = entries[i].EventNID } stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) err = stmt.QueryRowContext( @@ -113,15 +112,15 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 + var stateBlockNID types.StateBlockNID + var result pq.Int64Array for ; rows.Next(); i++ { - var stateBlockNID types.StateBlockNID - var result pq.Int64Array if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - r := []types.EventNID{} - for _, e := range result { - r = append(r, types.EventNID(e)) + r := make([]types.EventNID, len(result)) + for x := range result { + r[x] = types.EventNID(result[x]) } results[i] = r } @@ -141,35 +140,3 @@ func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { } return pq.Int64Array(nids) } - -type stateKeyTupleSorter []types.StateKeyTuple - -func (s stateKeyTupleSorter) Len() int { return len(s) } -func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } -func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -// Check whether a tuple is in the list. Assumes that the list is sorted. -func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { - i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) - return i < len(s) && s[i] == value -} - -// List the unique eventTypeNIDs and eventStateKeyNIDs. -// Assumes that the list is sorted. -func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) { - eventTypeNIDs = make(pq.Int64Array, len(s)) - eventStateKeyNIDs = make(pq.Int64Array, len(s)) - for i := range s { - eventTypeNIDs[i] = int64(s[i].EventTypeNID) - eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) - } - eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] - eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] - return -} - -type int64Sorter []int64 - -func (s int64Sorter) Len() int { return len(s) } -func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } -func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/postgres/state_block_table_test.go b/roomserver/storage/postgres/state_block_table_test.go deleted file mode 100644 index a0e2ec952..000000000 --- a/roomserver/storage/postgres/state_block_table_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "sort" - "testing" - - "github.com/matrix-org/dendrite/roomserver/types" -) - -func TestStateKeyTupleSorter(t *testing.T) { - input := stateKeyTupleSorter{ - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 1}, - } - want := []types.StateKeyTuple{ - {EventTypeNID: 1, EventStateKeyNID: 1}, - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - } - doNotWant := []types.StateKeyTuple{ - {EventTypeNID: 0, EventStateKeyNID: 0}, - {EventTypeNID: 1, EventStateKeyNID: 3}, - {EventTypeNID: 2, EventStateKeyNID: 1}, - {EventTypeNID: 3, EventStateKeyNID: 1}, - } - wantTypeNIDs := []int64{1, 2} - wantStateKeyNIDs := []int64{1, 2, 4} - - // Sort the input and check it's in the right order. - sort.Sort(input) - gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() - - for i := range want { - if input[i] != want[i] { - t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) - } - - if !input.contains(want[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) - } - } - - for i := range doNotWant { - if input.contains(doNotWant[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) - } - } - - if len(wantTypeNIDs) != len(gotTypeNIDs) { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - - for i := range wantTypeNIDs { - if wantTypeNIDs[i] != gotTypeNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } - - if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { - t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) - } - - for i := range wantStateKeyNIDs { - if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } -} diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 8ed886030..a24b7f3f0 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -77,12 +77,12 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func createStateSnapshotTable(db *sql.DB) error { +func CreateStateSnapshotTable(db *sql.DB) error { _, err := db.Exec(stateSnapshotSchema) return err } -func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{} return s, sqlutil.StatementList{ @@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { nids = nids[:util.SortAndUnique(nids)] - var id int64 - err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id) + err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID) if err != nil { return 0, err } - stateNID = types.StateSnapshotNID(id) return } @@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( defer rows.Close() // nolint: errcheck results := make([]types.StateBlockNIDList, len(stateNIDs)) i := 0 + var stateBlockNIDs pq.Int64Array for ; rows.Next(); i++ { result := &results[i] - var stateBlockNIDs pq.Int64Array if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 88df72009..70ea4d8ba 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -80,19 +80,19 @@ func (d *Database) create(db *sql.DB) error { if err := CreateEventsTable(db); err != nil { return err } - if err := createRoomsTable(db); err != nil { + if err := CreateRoomsTable(db); err != nil { return err } - if err := createStateBlockTable(db); err != nil { + if err := CreateStateBlockTable(db); err != nil { return err } - if err := createStateSnapshotTable(db); err != nil { + if err := CreateStateSnapshotTable(db); err != nil { return err } if err := CreatePrevEventsTable(db); err != nil { return err } - if err := createRoomAliasesTable(db); err != nil { + if err := CreateRoomAliasesTable(db); err != nil { return err } if err := CreateInvitesTable(db); err != nil { @@ -128,15 +128,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - rooms, err := prepareRoomsTable(db) + rooms, err := PrepareRoomsTable(db) if err != nil { return err } - stateBlock, err := prepareStateBlockTable(db) + stateBlock, err := PrepareStateBlockTable(db) if err != nil { return err } - stateSnapshot, err := prepareStateSnapshotTable(db) + stateSnapshot, err := PrepareStateSnapshotTable(db) if err != nil { return err } @@ -144,7 +144,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - roomAliases, err := prepareRoomAliasesTable(db) + roomAliases, err := PrepareRoomAliasesTable(db) if err != nil { return err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 252e94c7e..cc4a9fff5 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1216,7 +1216,7 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { - return d.RoomsTable.SelectRoomIDs(ctx, nil) + return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) } // ForgetRoom sets a users room to forgotten diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index feb06150a..1dda34c36 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -247,9 +247,9 @@ func (s *eventStatements) BulkSelectStateEventByNID( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) + tuples := types.StateKeyTupleSorter(stateKeyTuples) sort.Sort(tuples) - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays() params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray)) selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) for _, v := range eventNIDs { diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 7c7bead95..3bdbbaa35 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -63,12 +63,12 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func createRoomAliasesTable(db *sql.DB) error { +func CreateRoomAliasesTable(db *sql.DB) error { _, err := db.Exec(roomAliasesSchema) return err } -func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { +func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { s := &roomAliasesStatements{ db: db, } @@ -113,8 +113,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed") + var alias string for rows.Next() { - var alias string if err = rows.Scan(&alias); err != nil { return } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index cd60c6785..03ad4b3d0 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -86,12 +86,12 @@ type roomStatements struct { selectRoomIDsStmt *sql.Stmt } -func createRoomsTable(db *sql.DB) error { +func CreateRoomsTable(db *sql.DB) error { _, err := db.Exec(roomsSchema) return err } -func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { +func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { s := &roomStatements{ db: db, } @@ -108,7 +108,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { +func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) { stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) rows, err := stmt.QueryContext(ctx) if err != nil { @@ -116,8 +116,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri } defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") var roomIDs []string + var roomID string for rows.Next() { - var roomID string if err = rows.Scan(&roomID); err != nil { return nil, err } @@ -241,9 +241,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( } defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed") result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion) + var roomNID types.RoomNID + var roomVersion gomatrixserverlib.RoomVersion for rows.Next() { - var roomNID types.RoomNID - var roomVersion gomatrixserverlib.RoomVersion if err = rows.Scan(&roomNID, &roomVersion); err != nil { return nil, err } @@ -270,8 +270,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") var roomIDs []string + var roomID string for rows.Next() { - var roomID string if err = rows.Scan(&roomID); err != nil { return nil, err } @@ -298,8 +298,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") var roomNIDs []types.RoomNID + var roomNID types.RoomNID for rows.Next() { - var roomNID types.RoomNID if err = rows.Scan(&roomNID); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 3c829cdcd..4e67d4da1 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "encoding/json" "fmt" - "sort" "strings" "github.com/matrix-org/dendrite/internal" @@ -64,12 +63,12 @@ type stateBlockStatements struct { bulkSelectStateBlockEntriesStmt *sql.Stmt } -func createStateBlockTable(db *sql.DB) error { +func CreateStateBlockTable(db *sql.DB) error { _, err := db.Exec(stateDataSchema) return err } -func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{ db: db, } @@ -85,9 +84,9 @@ func (s *stateBlockStatements) BulkInsertStateData( entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] - nids := types.EventNIDs{} // zero slice to not store 'null' in the DB - for _, e := range entries { - nids = append(nids, e.EventNID) + nids := make(types.EventNIDs, entries.Len()) + for i := range entries { + nids[i] = entries[i].EventNID } js, err := json.Marshal(nids) if err != nil { @@ -122,13 +121,13 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 + var stateBlockNID types.StateBlockNID + var result json.RawMessage for ; rows.Next(); i++ { - var stateBlockNID types.StateBlockNID - var result json.RawMessage if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - r := []types.EventNID{} + var r []types.EventNID if err = json.Unmarshal(result, &r); err != nil { return nil, fmt.Errorf("json.Unmarshal: %w", err) } @@ -142,35 +141,3 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( } return results, err } - -type stateKeyTupleSorter []types.StateKeyTuple - -func (s stateKeyTupleSorter) Len() int { return len(s) } -func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } -func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -// Check whether a tuple is in the list. Assumes that the list is sorted. -func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { - i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) - return i < len(s) && s[i] == value -} - -// List the unique eventTypeNIDs and eventStateKeyNIDs. -// Assumes that the list is sorted. -func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) { - eventTypeNIDs = make([]int64, len(s)) - eventStateKeyNIDs = make([]int64, len(s)) - for i := range s { - eventTypeNIDs[i] = int64(s[i].EventTypeNID) - eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) - } - eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] - eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] - return -} - -type int64Sorter []int64 - -func (s int64Sorter) Len() int { return len(s) } -func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } -func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/state_block_table_test.go b/roomserver/storage/sqlite3/state_block_table_test.go deleted file mode 100644 index 98439f5c0..000000000 --- a/roomserver/storage/sqlite3/state_block_table_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "sort" - "testing" - - "github.com/matrix-org/dendrite/roomserver/types" -) - -func TestStateKeyTupleSorter(t *testing.T) { - input := stateKeyTupleSorter{ - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 1}, - } - want := []types.StateKeyTuple{ - {EventTypeNID: 1, EventStateKeyNID: 1}, - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - } - doNotWant := []types.StateKeyTuple{ - {EventTypeNID: 0, EventStateKeyNID: 0}, - {EventTypeNID: 1, EventStateKeyNID: 3}, - {EventTypeNID: 2, EventStateKeyNID: 1}, - {EventTypeNID: 3, EventStateKeyNID: 1}, - } - wantTypeNIDs := []int64{1, 2} - wantStateKeyNIDs := []int64{1, 2, 4} - - // Sort the input and check it's in the right order. - sort.Sort(input) - gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() - - for i := range want { - if input[i] != want[i] { - t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) - } - - if !input.contains(want[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) - } - } - - for i := range doNotWant { - if input.contains(doNotWant[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) - } - } - - if len(wantTypeNIDs) != len(gotTypeNIDs) { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - - for i := range wantTypeNIDs { - if wantTypeNIDs[i] != gotTypeNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } - - if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { - t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) - } - - for i := range wantStateKeyNIDs { - if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } -} diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 1f5e9ee3b..b8136b758 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -68,12 +68,12 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func createStateSnapshotTable(db *sql.DB) error { +func CreateStateSnapshotTable(db *sql.DB) error { _, err := db.Exec(stateSnapshotSchema) return err } -func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{ db: db, } @@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState( return } insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) - var id int64 - err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id) + err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID) if err != nil { return 0, err } - stateNID = types.StateSnapshotNID(id) return } @@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed") results := make([]types.StateBlockNIDList, len(stateNIDs)) i := 0 + var stateBlockNIDsJSON string for ; rows.Next(); i++ { result := &results[i] - var stateBlockNIDsJSON string if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index a4e32d528..8325fdad5 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -89,19 +89,19 @@ func (d *Database) create(db *sql.DB) error { if err := CreateEventsTable(db); err != nil { return err } - if err := createRoomsTable(db); err != nil { + if err := CreateRoomsTable(db); err != nil { return err } - if err := createStateBlockTable(db); err != nil { + if err := CreateStateBlockTable(db); err != nil { return err } - if err := createStateSnapshotTable(db); err != nil { + if err := CreateStateSnapshotTable(db); err != nil { return err } if err := CreatePrevEventsTable(db); err != nil { return err } - if err := createRoomAliasesTable(db); err != nil { + if err := CreateRoomAliasesTable(db); err != nil { return err } if err := CreateInvitesTable(db); err != nil { @@ -137,15 +137,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - rooms, err := prepareRoomsTable(db) + rooms, err := PrepareRoomsTable(db) if err != nil { return err } - stateBlock, err := prepareStateBlockTable(db) + stateBlock, err := PrepareStateBlockTable(db) if err != nil { return err } - stateSnapshot, err := prepareStateSnapshotTable(db) + stateSnapshot, err := PrepareStateSnapshotTable(db) if err != nil { return err } @@ -153,7 +153,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - roomAliases, err := prepareRoomAliasesTable(db) + roomAliases, err := PrepareRoomAliasesTable(db) if err != nil { return err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 95609787a..116e11c4e 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -72,7 +72,7 @@ type Rooms interface { UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) - SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) + SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) } diff --git a/roomserver/storage/tables/room_aliases_table_test.go b/roomserver/storage/tables/room_aliases_table_test.go new file mode 100644 index 000000000..8fb57d5a4 --- /dev/null +++ b/roomserver/storage/tables/room_aliases_table_test.go @@ -0,0 +1,96 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.RoomAliases, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateRoomAliasesTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareRoomAliasesTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateRoomAliasesTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareRoomAliasesTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func TestRoomAliasesTable(t *testing.T) { + alice := test.NewUser() + room := test.NewRoom(t, alice) + room2 := test.NewRoom(t, alice) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateRoomAliasesTable(t, dbType) + defer close() + alias, alias2, alias3 := "#alias:localhost", "#alias2:localhost", "#alias3:localhost" + // insert aliases + err := tab.InsertRoomAlias(ctx, nil, alias, room.ID, alice.ID) + assert.NoError(t, err) + + err = tab.InsertRoomAlias(ctx, nil, alias2, room.ID, alice.ID) + assert.NoError(t, err) + + err = tab.InsertRoomAlias(ctx, nil, alias3, room2.ID, alice.ID) + assert.NoError(t, err) + + // verify we can get the roomID for the alias + roomID, err := tab.SelectRoomIDFromAlias(ctx, nil, alias) + assert.NoError(t, err) + assert.Equal(t, room.ID, roomID) + + // .. and the creator + creator, err := tab.SelectCreatorIDFromAlias(ctx, nil, alias) + assert.NoError(t, err) + assert.Equal(t, alice.ID, creator) + + creator, err = tab.SelectCreatorIDFromAlias(ctx, nil, "#doesntexist:localhost") + assert.NoError(t, err) + assert.Equal(t, "", creator) + + roomID, err = tab.SelectRoomIDFromAlias(ctx, nil, "#doesntexist:localhost") + assert.NoError(t, err) + assert.Equal(t, "", roomID) + + // get all aliases for a room + aliases, err := tab.SelectAliasesFromRoomID(ctx, nil, room.ID) + assert.NoError(t, err) + assert.Equal(t, []string{alias, alias2}, aliases) + + // delete an alias and verify it's deleted + err = tab.DeleteRoomAlias(ctx, nil, alias2) + assert.NoError(t, err) + + aliases, err = tab.SelectAliasesFromRoomID(ctx, nil, room.ID) + assert.NoError(t, err) + assert.Equal(t, []string{alias}, aliases) + + // deleting the same alias should be a no-op + err = tab.DeleteRoomAlias(ctx, nil, alias2) + assert.NoError(t, err) + + // Delete non-existent alias should be a no-op + err = tab.DeleteRoomAlias(ctx, nil, "#doesntexist:localhost") + assert.NoError(t, err) + }) +} diff --git a/roomserver/storage/tables/rooms_table_test.go b/roomserver/storage/tables/rooms_table_test.go new file mode 100644 index 000000000..9872fb800 --- /dev/null +++ b/roomserver/storage/tables/rooms_table_test.go @@ -0,0 +1,128 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" +) + +func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateRoomsTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareRoomsTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateRoomsTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareRoomsTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func TestRoomsTable(t *testing.T) { + alice := test.NewUser() + room := test.NewRoom(t, alice) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateRoomsTable(t, dbType) + defer close() + + wantRoomNID, err := tab.InsertRoomNID(ctx, nil, room.ID, room.Version) + assert.NoError(t, err) + + // Create dummy room + _, err = tab.InsertRoomNID(ctx, nil, util.RandomString(16), room.Version) + assert.NoError(t, err) + + gotRoomNID, err := tab.SelectRoomNID(ctx, nil, room.ID) + assert.NoError(t, err) + assert.Equal(t, wantRoomNID, gotRoomNID) + + // Ensure non existent roomNID errors + roomNID, err := tab.SelectRoomNID(ctx, nil, "!doesnotexist:localhost") + assert.Error(t, err) + assert.Equal(t, types.RoomNID(0), roomNID) + + roomInfo, err := tab.SelectRoomInfo(ctx, nil, room.ID) + assert.NoError(t, err) + assert.Equal(t, &types.RoomInfo{ + RoomNID: wantRoomNID, + RoomVersion: room.Version, + StateSnapshotNID: 0, + IsStub: true, // there are no latestEventNIDs + }, roomInfo) + + roomInfo, err = tab.SelectRoomInfo(ctx, nil, "!doesnotexist:localhost") + assert.NoError(t, err) + assert.Nil(t, roomInfo) + + // There are no rooms with latestEventNIDs yet + roomIDs, err := tab.SelectRoomIDsWithEvents(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, 0, len(roomIDs)) + + roomVersions, err := tab.SelectRoomVersionsForRoomNIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337}) + assert.NoError(t, err) + assert.Equal(t, roomVersions[wantRoomNID], room.Version) + // Room does not exist + _, ok := roomVersions[1337] + assert.False(t, ok) + + roomIDs, err = tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337}) + assert.NoError(t, err) + assert.Equal(t, []string{room.ID}, roomIDs) + + roomNIDs, err := tab.BulkSelectRoomNIDs(ctx, nil, []string{room.ID, "!doesnotexist:localhost"}) + assert.NoError(t, err) + assert.Equal(t, []types.RoomNID{wantRoomNID}, roomNIDs) + + wantEventNIDs := []types.EventNID{1, 2, 3} + lastEventSentNID := types.EventNID(3) + stateSnapshotNID := types.StateSnapshotNID(1) + // make the room "usable" + err = tab.UpdateLatestEventNIDs(ctx, nil, wantRoomNID, wantEventNIDs, lastEventSentNID, stateSnapshotNID) + assert.NoError(t, err) + + roomInfo, err = tab.SelectRoomInfo(ctx, nil, room.ID) + assert.NoError(t, err) + assert.Equal(t, &types.RoomInfo{ + RoomNID: wantRoomNID, + RoomVersion: room.Version, + StateSnapshotNID: 1, + IsStub: false, + }, roomInfo) + + eventNIDs, snapshotNID, err := tab.SelectLatestEventNIDs(ctx, nil, wantRoomNID) + assert.NoError(t, err) + assert.Equal(t, wantEventNIDs, eventNIDs) + assert.Equal(t, types.StateSnapshotNID(1), snapshotNID) + + // Again, doesn't exist + _, _, err = tab.SelectLatestEventNIDs(ctx, nil, 1337) + assert.Error(t, err) + + eventNIDs, eventNID, snapshotNID, err := tab.SelectLatestEventsNIDsForUpdate(ctx, nil, wantRoomNID) + assert.NoError(t, err) + assert.Equal(t, wantEventNIDs, eventNIDs) + assert.Equal(t, types.EventNID(3), eventNID) + assert.Equal(t, types.StateSnapshotNID(1), snapshotNID) + }) +} diff --git a/roomserver/storage/tables/state_block_table_test.go b/roomserver/storage/tables/state_block_table_test.go new file mode 100644 index 000000000..de0b420bc --- /dev/null +++ b/roomserver/storage/tables/state_block_table_test.go @@ -0,0 +1,92 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreateStateBlockTable(t *testing.T, dbType test.DBType) (tab tables.StateBlock, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateStateBlockTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareStateBlockTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateStateBlockTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareStateBlockTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func TestStateBlockTable(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateStateBlockTable(t, dbType) + defer close() + + // generate some dummy data + var entries types.StateEntries + for i := 0; i < 100; i++ { + entry := types.StateEntry{ + EventNID: types.EventNID(i), + } + entries = append(entries, entry) + } + stateBlockNID, err := tab.BulkInsertStateData(ctx, nil, entries) + assert.NoError(t, err) + assert.Equal(t, types.StateBlockNID(1), stateBlockNID) + + // generate a different hash, to get a new StateBlockNID + var entries2 types.StateEntries + for i := 100; i < 300; i++ { + entry := types.StateEntry{ + EventNID: types.EventNID(i), + } + entries2 = append(entries2, entry) + } + stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2) + assert.NoError(t, err) + assert.Equal(t, types.StateBlockNID(2), stateBlockNID) + + eventNIDs, err := tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 2}) + assert.NoError(t, err) + assert.Equal(t, len(entries), len(eventNIDs[0])) + assert.Equal(t, len(entries2), len(eventNIDs[1])) + + // try to get a StateBlockNID which does not exist + _, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{5}) + assert.Error(t, err) + + // This should return an error, since we can only retrieve 1 StateBlock + _, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 5}) + assert.Error(t, err) + + for i := 0; i < 65555; i++ { + entry := types.StateEntry{ + EventNID: types.EventNID(i), + } + entries2 = append(entries2, entry) + } + stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2) + assert.NoError(t, err) + assert.Equal(t, types.StateBlockNID(3), stateBlockNID) + }) +} diff --git a/roomserver/storage/tables/state_snapshot_table_test.go b/roomserver/storage/tables/state_snapshot_table_test.go new file mode 100644 index 000000000..dcdb5d8f1 --- /dev/null +++ b/roomserver/storage/tables/state_snapshot_table_test.go @@ -0,0 +1,86 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateStateSnapshotTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareStateSnapshotTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateStateSnapshotTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareStateSnapshotTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func TestStateSnapshotTable(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateStateSnapshotTable(t, dbType) + defer close() + + // generate some dummy data + var stateBlockNIDs types.StateBlockNIDs + for i := 0; i < 100; i++ { + stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i)) + } + stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs) + assert.NoError(t, err) + assert.Equal(t, types.StateSnapshotNID(1), stateNID) + + // verify ON CONFLICT; Note: this updates the sequence! + stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs) + assert.NoError(t, err) + assert.Equal(t, types.StateSnapshotNID(1), stateNID) + + // create a second snapshot + var stateBlockNIDs2 types.StateBlockNIDs + for i := 100; i < 150; i++ { + stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i)) + } + + stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2) + assert.NoError(t, err) + // StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence + assert.Equal(t, types.StateSnapshotNID(3), stateNID) + + nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3}) + assert.NoError(t, err) + assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs)) + assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs)) + + // check we get an error if the state snapshot does not exist + _, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2}) + assert.Error(t, err) + + // create a second snapshot + for i := 0; i < 65555; i++ { + stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i)) + } + _, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2) + assert.NoError(t, err) + }) +} diff --git a/roomserver/types/types.go b/roomserver/types/types.go index ce4e5fd1e..62695aaee 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "golang.org/x/crypto/blake2b" ) @@ -97,6 +98,38 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool { return a.EventStateKeyNID < b.EventStateKeyNID } +type StateKeyTupleSorter []StateKeyTuple + +func (s StateKeyTupleSorter) Len() int { return len(s) } +func (s StateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s StateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Check whether a tuple is in the list. Assumes that the list is sorted. +func (s StateKeyTupleSorter) contains(value StateKeyTuple) bool { + i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) + return i < len(s) && s[i] == value +} + +// List the unique eventTypeNIDs and eventStateKeyNIDs. +// Assumes that the list is sorted. +func (s StateKeyTupleSorter) TypesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) { + eventTypeNIDs = make([]int64, len(s)) + eventStateKeyNIDs = make([]int64, len(s)) + for i := range s { + eventTypeNIDs[i] = int64(s[i].EventTypeNID) + eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) + } + eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] + eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] + return +} + +type int64Sorter []int64 + +func (s int64Sorter) Len() int { return len(s) } +func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } +func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + // A StateEntry is an entry in the room state of a matrix room. type StateEntry struct { StateKeyTuple diff --git a/roomserver/types/types_test.go b/roomserver/types/types_test.go index b1e84b821..a26b80f74 100644 --- a/roomserver/types/types_test.go +++ b/roomserver/types/types_test.go @@ -1,6 +1,7 @@ package types import ( + "sort" "testing" ) @@ -24,3 +25,66 @@ func TestDeduplicateStateEntries(t *testing.T) { } } } + +func TestStateKeyTupleSorter(t *testing.T) { + input := StateKeyTupleSorter{ + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 1}, + } + want := []StateKeyTuple{ + {EventTypeNID: 1, EventStateKeyNID: 1}, + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + } + doNotWant := []StateKeyTuple{ + {EventTypeNID: 0, EventStateKeyNID: 0}, + {EventTypeNID: 1, EventStateKeyNID: 3}, + {EventTypeNID: 2, EventStateKeyNID: 1}, + {EventTypeNID: 3, EventStateKeyNID: 1}, + } + wantTypeNIDs := []int64{1, 2} + wantStateKeyNIDs := []int64{1, 2, 4} + + // Sort the input and check it's in the right order. + sort.Sort(input) + gotTypeNIDs, gotStateKeyNIDs := input.TypesAndStateKeysAsArrays() + + for i := range want { + if input[i] != want[i] { + t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) + } + + if !input.contains(want[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) + } + } + + for i := range doNotWant { + if input.contains(doNotWant[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) + } + } + + if len(wantTypeNIDs) != len(gotTypeNIDs) { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + + for i := range wantTypeNIDs { + if wantTypeNIDs[i] != gotTypeNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } + + if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { + t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) + } + + for i := range wantStateKeyNIDs { + if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } +} From cd82460513d5abf04e56c01667d56499d4c354be Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 17 May 2022 10:45:50 +0100 Subject: [PATCH 2/7] Add docs which explain how to calculate coverage (#2468) --- docs/coverage.md | 84 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 docs/coverage.md diff --git a/docs/coverage.md b/docs/coverage.md new file mode 100644 index 000000000..7a3b7cb9e --- /dev/null +++ b/docs/coverage.md @@ -0,0 +1,84 @@ +--- +title: Coverage +parent: Development +permalink: /development/coverage +--- + +To generate a test coverage report for Sytest, a small patch needs to be applied to the Sytest repository to compile and use the instrumented binary: +```patch +diff --git a/lib/SyTest/Homeserver/Dendrite.pm b/lib/SyTest/Homeserver/Dendrite.pm +index 8f0e209c..ad057e52 100644 +--- a/lib/SyTest/Homeserver/Dendrite.pm ++++ b/lib/SyTest/Homeserver/Dendrite.pm +@@ -337,7 +337,7 @@ sub _start_monolith + + $output->diag( "Starting monolith server" ); + my @command = ( +- $self->{bindir} . '/dendrite-monolith-server', ++ $self->{bindir} . '/dendrite-monolith-server', '--test.coverprofile=' . $self->{hs_dir} . '/integrationcover.log', "DEVEL", + '--config', $self->{paths}{config}, + '--http-bind-address', $self->{bind_host} . ':' . $self->unsecure_port, + '--https-bind-address', $self->{bind_host} . ':' . $self->secure_port, +diff --git a/scripts/dendrite_sytest.sh b/scripts/dendrite_sytest.sh +index f009332b..7ea79869 100755 +--- a/scripts/dendrite_sytest.sh ++++ b/scripts/dendrite_sytest.sh +@@ -34,7 +34,8 @@ export GOBIN=/tmp/bin + echo >&2 "--- Building dendrite from source" + cd /src + mkdir -p $GOBIN +-go install -v ./cmd/dendrite-monolith-server ++# go install -v ./cmd/dendrite-monolith-server ++go test -c -cover -covermode=atomic -o $GOBIN/dendrite-monolith-server -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server + go install -v ./cmd/generate-keys + cd - + ``` + + Then run Sytest. This will generate a new file `integrationcover.log` in each server's directory e.g `server-0/integrationcover.log`. To parse it, + ensure your working directory is under the Dendrite repository then run: + ```bash + go tool cover -func=/path/to/server-0/integrationcover.log + ``` + which will produce an output like: + ``` + ... + github.com/matrix-org/util/json.go:83: NewJSONRequestHandler 100.0% +github.com/matrix-org/util/json.go:90: Protect 57.1% +github.com/matrix-org/util/json.go:110: RequestWithLogging 100.0% +github.com/matrix-org/util/json.go:132: MakeJSONAPI 70.0% +github.com/matrix-org/util/json.go:151: respond 61.5% +github.com/matrix-org/util/json.go:180: WithCORSOptions 0.0% +github.com/matrix-org/util/json.go:191: SetCORSHeaders 100.0% +github.com/matrix-org/util/json.go:202: RandomString 100.0% +github.com/matrix-org/util/json.go:210: init 100.0% +github.com/matrix-org/util/unique.go:13: Unique 91.7% +github.com/matrix-org/util/unique.go:48: SortAndUnique 100.0% +github.com/matrix-org/util/unique.go:55: UniqueStrings 100.0% +total: (statements) 53.7% +``` +The total coverage for this run is the last line at the bottom. However, this value is misleading because Dendrite can run in many different configurations, +which will never be tested in a single test run (e.g sqlite or postgres, monolith or polylith). To get a more accurate value, additional processing is required +to remove packages which will never be tested and extension MSCs: +```bash +# These commands are all similar but change which package paths are _removed_ from the output. + +# For Postgres (monolith) +go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|sqlite|setup/mscs|api_trace' > coverage.txt + +# For Postgres (polylith) +go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'sqlite|setup/mscs|api_trace' > coverage.txt + +# For SQLite (monolith) +go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'inthttp|postgres|setup/mscs|api_trace' > coverage.txt + +# For SQLite (polylith) +go tool cover -func=/path/to/server-0/integrationcover.log | grep 'github.com/matrix-org/dendrite' | grep -Ev 'postgres|setup/mscs|api_trace' > coverage.txt +``` + +A total value can then be calculated using: +```bash +cat coverage.txt | awk -F '\t+' '{x = x + $3} END {print x/NR}' +``` + + +We currently do not have a way to combine Sytest/Complement/Unit Tests into a single coverage report. \ No newline at end of file From 6de29c1cd23d218f04d2e570932db8967d6adc4f Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 17 May 2022 13:23:35 +0100 Subject: [PATCH 3/7] bugfix: E2EE device keys could sometimes not be sent to remote servers (#2466) * Fix flakey sytest 'Local device key changes get to remote servers' * Debug logs * Remove internal/test and use /test only Remove a lot of ancient code too. * Use FederationRoomserverAPI in more places * Use more interfaces in federationapi; begin adding regression test * Linting * Add regression test * Unbreak tests * ALL THE LOGS * Fix a race condition which could cause events to not be sent to servers If a new room event which rewrites state arrives, we remove all joined hosts then re-calculate them. This wasn't done in a transaction so for a brief period we would have no joined hosts. During this interim, key change events which arrive would not be sent to destination servers. This would sporadically fail on sytest. * Unbreak new tests * Linting --- cmd/generate-keys/main.go | 2 +- federationapi/api/api.go | 40 ++- federationapi/consumers/keychange.go | 8 +- federationapi/consumers/roomserver.go | 29 +-- federationapi/federationapi.go | 4 +- federationapi/federationapi_test.go | 236 +++++++++++++++++- federationapi/internal/api.go | 8 +- federationapi/internal/perform.go | 18 +- federationapi/queue/destinationqueue.go | 31 +-- federationapi/queue/queue.go | 9 +- federationapi/routing/query.go | 2 +- federationapi/routing/routing.go | 2 +- federationapi/routing/send.go | 2 +- federationapi/routing/send_test.go | 2 +- federationapi/routing/threepid.go | 16 +- federationapi/storage/interface.go | 3 +- .../storage/postgres/joined_hosts_table.go | 5 + federationapi/storage/shared/storage.go | 26 +- internal/caching/cache_typing_test.go | 2 +- internal/test/client.go | 158 ------------ internal/test/kafka.go | 76 ------ internal/test/server.go | 152 ----------- keyserver/internal/device_list_update.go | 4 +- keyserver/internal/internal.go | 2 +- keyserver/keyserver.go | 2 +- roomserver/api/api.go | 1 + roomserver/internal/input/input_test.go | 4 +- roomserver/internal/query/query_test.go | 2 +- .../storage/tables/events_table_test.go | 2 +- .../tables/previous_events_table_test.go | 2 +- .../storage/tables/published_table_test.go | 2 +- .../storage/tables/room_aliases_table_test.go | 2 +- roomserver/storage/tables/rooms_table_test.go | 2 +- syncapi/storage/storage_test.go | 6 +- .../storage/tables/output_room_events_test.go | 2 +- syncapi/storage/tables/topology_test.go | 2 +- syncapi/syncapi_test.go | 15 +- test/event.go | 18 ++ test/http.go | 47 ++++ {internal/test => test}/keyring.go | 0 internal/test/config.go => test/keys.go | 84 ------- test/room.go | 66 +++-- {internal/test => test}/slice.go | 0 test/{ => testrig}/base.go | 11 +- test/{ => testrig}/jetstream.go | 2 +- test/user.go | 52 +++- userapi/storage/storage_test.go | 20 +- userapi/userapi_test.go | 3 +- 48 files changed, 566 insertions(+), 618 deletions(-) delete mode 100644 internal/test/client.go delete mode 100644 internal/test/kafka.go delete mode 100644 internal/test/server.go rename {internal/test => test}/keyring.go (100%) rename internal/test/config.go => test/keys.go (61%) rename {internal/test => test}/slice.go (100%) rename test/{ => testrig}/base.go (92%) rename test/{ => testrig}/jetstream.go (98%) diff --git a/cmd/generate-keys/main.go b/cmd/generate-keys/main.go index bddf219dc..8acd28be0 100644 --- a/cmd/generate-keys/main.go +++ b/cmd/generate-keys/main.go @@ -20,7 +20,7 @@ import ( "log" "os" - "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/dendrite/test" ) const usage = `Usage: %s diff --git a/federationapi/api/api.go b/federationapi/api/api.go index fc25194e0..53d4701f3 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -12,12 +12,16 @@ import ( // FederationInternalAPI is used to query information from the federation sender. type FederationInternalAPI interface { - FederationClient + gomatrixserverlib.FederatedStateClient + KeyserverFederationAPI gomatrixserverlib.KeyDatabase ClientFederationAPI RoomserverFederationAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error + LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) + MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. PerformBroadcastEDU( @@ -60,17 +64,43 @@ type RoomserverFederationAPI interface { LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } -// FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender +// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError -type FederationClient interface { - gomatrixserverlib.FederatedStateClient +type KeyserverFederationAPI interface { GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) +} + +// an interface for gmsl.FederationClient - contains functions called by federationapi only. +type FederationClient interface { + gomatrixserverlib.KeyClient + SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) + + // Perform operations + LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) + Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) + MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) + SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) + MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) + SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) + SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) + + GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + + GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) + ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) + QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) + Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) - LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) + + ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) + LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) + LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 0ece18e97..95c9a7fdd 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -39,7 +39,7 @@ type KeyChangeConsumer struct { db storage.Database queues *queue.OutgoingQueues serverName gomatrixserverlib.ServerName - rsAPI roomserverAPI.RoomserverInternalAPI + rsAPI roomserverAPI.FederationRoomserverAPI topic string } @@ -50,7 +50,7 @@ func NewKeyChangeConsumer( js nats.JetStreamContext, queues *queue.OutgoingQueues, store storage.Database, - rsAPI roomserverAPI.RoomserverInternalAPI, + rsAPI roomserverAPI.FederationRoomserverAPI, ) *KeyChangeConsumer { return &KeyChangeConsumer{ ctx: process.Context(), @@ -120,6 +120,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { logger.WithError(err).Error("failed to calculate joined rooms for user") return true } + logrus.Infof("DEBUG: %v joined rooms for user %v", queryRes.RoomIDs, m.UserID) // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { @@ -128,6 +129,9 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } if len(destinations) == 0 { + logger.WithField("num_rooms", len(queryRes.RoomIDs)).Debug("user is in no federated rooms") + destinations, err = t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, false) + logrus.Infof("GetJoinedHostsForRooms exclude self=false -> %v %v", destinations, err) return true } // Pack the EDU and marshal it diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 80317ee69..7a0816ff2 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/queue" @@ -36,7 +37,7 @@ import ( type OutputRoomEventConsumer struct { ctx context.Context cfg *config.FederationAPI - rsAPI api.RoomserverInternalAPI + rsAPI api.FederationRoomserverAPI jetstream nats.JetStreamContext durable string db storage.Database @@ -51,7 +52,7 @@ func NewOutputRoomEventConsumer( js nats.JetStreamContext, queues *queue.OutgoingQueues, store storage.Database, - rsAPI api.RoomserverInternalAPI, + rsAPI api.FederationRoomserverAPI, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -89,15 +90,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) switch output.Type { case api.OutputTypeNewRoomEvent: ev := output.NewRoomEvent.Event - - if output.NewRoomEvent.RewritesState { - if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil { - log.WithError(err).Errorf("roomserver output log: purge room state failure") - return false - } - } - - if err := s.processMessage(*output.NewRoomEvent); err != nil { + if err := s.processMessage(*output.NewRoomEvent, output.NewRoomEvent.RewritesState); err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ "event_id": ev.EventID(), @@ -145,7 +138,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. -func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error { +func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error { addsStateEvents, missingEventIDs := ore.NeededStateEventIDs() // Ask the roomserver and add in the rest of the results into the set. @@ -164,7 +157,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err addsStateEvents = append(addsStateEvents, eventsRes.Events...) } - addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents)) + addsJoinedHosts, err := JoinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(addsStateEvents)) if err != nil { return err } @@ -173,13 +166,13 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err // expressed as a delta against the current state. // TODO(#290): handle EventIDMismatchError and recover the current state by // talking to the roomserver + logrus.Infof("room %s adds joined hosts: %v removes %v", ore.Event.RoomID(), addsJoinedHosts, ore.RemovesStateEventIDs) oldJoinedHosts, err := s.db.UpdateRoom( s.ctx, ore.Event.RoomID(), - ore.LastSentEventID, - ore.Event.EventID(), addsJoinedHosts, ore.RemovesStateEventIDs, + rewritesState, // if we're re-writing state, nuke all joined hosts before adding ) if err != nil { return err @@ -238,7 +231,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( return nil, err } - combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents) + combinedAddsJoinedHosts, err := JoinedHostsFromEvents(combinedAddsEvents) if err != nil { return nil, err } @@ -284,10 +277,10 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( return result, nil } -// joinedHostsFromEvents turns a list of state events into a list of joined hosts. +// JoinedHostsFromEvents turns a list of state events into a list of joined hosts. // This errors if one of the events was invalid. // It should be impossible for an invalid event to get this far in the pipeline. -func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) { +func JoinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) { var joinedHosts []types.JoinedHost for _, ev := range evs { if ev.Type() != "m.room.member" || ev.StateKey() == nil { diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index bec9ac777..ff159beea 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -93,8 +93,8 @@ func AddPublicRoutes( // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, - federation *gomatrixserverlib.FederationClient, - rsAPI roomserverAPI.RoomserverInternalAPI, + federation api.FederationClient, + rsAPI roomserverAPI.FederationRoomserverAPI, caches *caching.Caches, keyRing *gomatrixserverlib.KeyRing, resetBlacklist bool, diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index eedebc6cd..ae244c566 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -3,18 +3,250 @@ package federationapi_test import ( "context" "crypto/ed25519" + "encoding/json" + "fmt" "strings" "testing" + "time" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/internal" - "github.com/matrix-org/dendrite/internal/test" + keyapi "github.com/matrix-org/dendrite/keyserver/api" + rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" ) +type fedRoomserverAPI struct { + rsapi.FederationRoomserverAPI + inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) + queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error +} + +// PerformJoin will call this function +func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { + if f.inputRoomEvents == nil { + return + } + f.inputRoomEvents(ctx, req, res) +} + +// keychange consumer calls this +func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { + if f.queryRoomsForUser == nil { + return nil + } + return f.queryRoomsForUser(ctx, req, res) +} + +// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate +type fedClient struct { + api.FederationClient + allowJoins []*test.Room + keys map[gomatrixserverlib.ServerName]struct { + key ed25519.PrivateKey + keyID gomatrixserverlib.KeyID + } + t *testing.T + sentTxn bool +} + +func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) { + fmt.Println("GetServerKeys:", matrixServer) + var keys gomatrixserverlib.ServerKeys + var keyID gomatrixserverlib.KeyID + var pkey ed25519.PrivateKey + for srv, data := range f.keys { + if srv == matrixServer { + pkey = data.key + keyID = data.keyID + break + } + } + if pkey == nil { + return keys, nil + } + + keys.ServerName = matrixServer + keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(10 * time.Hour)) + publicKey := pkey.Public().(ed25519.PublicKey) + keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ + keyID: { + Key: gomatrixserverlib.Base64Bytes(publicKey), + }, + } + toSign, err := json.Marshal(keys.ServerKeyFields) + if err != nil { + return keys, err + } + + keys.Raw, err = gomatrixserverlib.SignJSON( + string(matrixServer), keyID, pkey, toSign, + ) + if err != nil { + return keys, err + } + + return keys, nil +} + +func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { + for _, r := range f.allowJoins { + if r.ID == roomID { + res.RoomVersion = r.Version + res.JoinEvent = gomatrixserverlib.EventBuilder{ + Sender: userID, + RoomID: roomID, + Type: "m.room.member", + StateKey: &userID, + Content: gomatrixserverlib.RawJSON([]byte(`{"membership":"join"}`)), + PrevEvents: r.ForwardExtremities(), + } + var needed gomatrixserverlib.StateNeeded + needed, err = gomatrixserverlib.StateNeededForEventBuilder(&res.JoinEvent) + if err != nil { + f.t.Errorf("StateNeededForEventBuilder: %v", err) + return + } + res.JoinEvent.AuthEvents = r.MustGetAuthEventRefsForEvent(f.t, needed) + return + } + } + return +} +func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { + for _, r := range f.allowJoins { + if r.ID == event.RoomID() { + r.InsertEvent(f.t, event.Headered(r.Version)) + f.t.Logf("Join event: %v", event.EventID()) + res.StateEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.CurrentState()) + res.AuthEvents = gomatrixserverlib.NewEventJSONsFromHeaderedEvents(r.Events()) + } + } + return +} + +func (f *fedClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { + for _, edu := range t.EDUs { + if edu.Type == gomatrixserverlib.MDeviceListUpdate { + f.sentTxn = true + } + } + f.t.Logf("got /send") + return +} + +// Regression test to make sure that /send_join is updating the destination hosts synchronously and +// isn't relying on the roomserver. +func TestFederationAPIJoinThenKeyUpdate(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + testFederationAPIJoinThenKeyUpdate(t, dbType) + }) +} + +func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + base.Cfg.FederationAPI.PreferDirectFetch = true + defer close() + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + serverA := gomatrixserverlib.ServerName("server.a") + serverAKeyID := gomatrixserverlib.KeyID("ed25519:servera") + serverAPrivKey := test.PrivateKeyA + creator := test.NewUser(t, test.WithSigningServer(serverA, serverAKeyID, serverAPrivKey)) + + myServer := base.Cfg.Global.ServerName + myServerKeyID := base.Cfg.Global.KeyID + myServerPrivKey := base.Cfg.Global.PrivateKey + joiningUser := test.NewUser(t, test.WithSigningServer(myServer, myServerKeyID, myServerPrivKey)) + fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID) + room := test.NewRoom(t, creator) + + rsapi := &fedRoomserverAPI{ + inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { + if req.Asynchronous { + t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous") + } + }, + queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { + if req.UserID == joiningUser.ID && req.WantMembership == "join" { + res.RoomIDs = []string{room.ID} + return nil + } + return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req) + }, + } + fc := &fedClient{ + allowJoins: []*test.Room{room}, + t: t, + keys: map[gomatrixserverlib.ServerName]struct { + key ed25519.PrivateKey + keyID gomatrixserverlib.KeyID + }{ + serverA: { + key: serverAPrivKey, + keyID: serverAKeyID, + }, + myServer: { + key: myServerPrivKey, + keyID: myServerKeyID, + }, + }, + } + fsapi := federationapi.NewInternalAPI(base, fc, rsapi, base.Caches, nil, false) + + var resp api.PerformJoinResponse + fsapi.PerformJoin(context.Background(), &api.PerformJoinRequest{ + RoomID: room.ID, + UserID: joiningUser.ID, + ServerNames: []gomatrixserverlib.ServerName{serverA}, + }, &resp) + if resp.JoinedVia != serverA { + t.Errorf("PerformJoin: joined via %v want %v", resp.JoinedVia, serverA) + } + if resp.LastError != nil { + t.Fatalf("PerformJoin: returned error: %+v", *resp.LastError) + } + + // Inject a keyserver key change event and ensure we try to send it out. If we don't, then the + // federationapi is incorrectly waiting for an output room event to arrive to update the joined + // hosts table. + key := keyapi.DeviceMessage{ + Type: keyapi.TypeDeviceKeyUpdate, + DeviceKeys: &keyapi.DeviceKeys{ + UserID: joiningUser.ID, + DeviceID: "MY_DEVICE", + DisplayName: "BLARGLE", + KeyJSON: []byte(`{}`), + }, + } + b, err := json.Marshal(key) + if err != nil { + t.Fatalf("Failed to marshal device message: %s", err) + } + + msg := &nats.Msg{ + Subject: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + Header: nats.Header{}, + Data: b, + } + msg.Header.Set(jetstream.UserID, key.UserID) + + testrig.MustPublishMsgs(t, jsctx, msg) + time.Sleep(500 * time.Millisecond) + if !fc.sentTxn { + t.Fatalf("did not send device list update") + } +} + // Tests that event IDs with '/' in them (escaped as %2F) are correctly passed to the right handler and don't 404. // Relevant for v3 rooms and a cause of flakey sytests as the IDs are randomly generated. func TestRoomsV3URLEscapeDoNot404(t *testing.T) { @@ -86,7 +318,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { } gerr, ok := err.(gomatrix.HTTPError) if !ok { - t.Errorf("failed to cast response error as gomatrix.HTTPError") + t.Errorf("failed to cast response error as gomatrix.HTTPError: %s", err) continue } t.Logf("Error: %+v", gerr) diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 4e9fa8410..14056eafc 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -25,8 +25,8 @@ type FederationInternalAPI struct { db storage.Database cfg *config.FederationAPI statistics *statistics.Statistics - rsAPI roomserverAPI.RoomserverInternalAPI - federation *gomatrixserverlib.FederationClient + rsAPI roomserverAPI.FederationRoomserverAPI + federation api.FederationClient keyRing *gomatrixserverlib.KeyRing queues *queue.OutgoingQueues joins sync.Map // joins currently in progress @@ -34,8 +34,8 @@ type FederationInternalAPI struct { func NewFederationInternalAPI( db storage.Database, cfg *config.FederationAPI, - rsAPI roomserverAPI.RoomserverInternalAPI, - federation *gomatrixserverlib.FederationClient, + rsAPI roomserverAPI.FederationRoomserverAPI, + federation api.FederationClient, statistics *statistics.Statistics, caches *caching.Caches, queues *queue.OutgoingQueues, diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 577cb70e0..7ccd68ef0 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -8,6 +8,7 @@ import ( "time" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/consumers" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrix" @@ -235,6 +236,21 @@ func (r *FederationInternalAPI) performJoinUsingServer( return fmt.Errorf("respSendJoin.Check: %w", err) } + // We need to immediately update our list of joined hosts for this room now as we are technically + // joined. We must do this synchronously: we cannot rely on the roomserver output events as they + // will happen asyncly. If we don't update this table, you can end up with bad failure modes like + // joining a room, waiting for 200 OK then changing device keys and have those keys not be sent + // to other servers (this was a cause of a flakey sytest "Local device key changes get to remote servers") + // The events are trusted now as we performed auth checks above. + joinedHosts, err := consumers.JoinedHostsFromEvents(respState.StateEvents.TrustedEvents(respMakeJoin.RoomVersion, false)) + if err != nil { + return fmt.Errorf("JoinedHostsFromEvents: failed to get joined hosts: %s", err) + } + logrus.WithField("hosts", joinedHosts).WithField("room", roomID).Info("Joined federated room with hosts") + if _, err = r.db.UpdateRoom(context.Background(), roomID, joinedHosts, nil, true); err != nil { + return fmt.Errorf("UpdatedRoom: failed to update room with joined hosts: %s", err) + } + // If we successfully performed a send_join above then the other // server now thinks we're a part of the room. Send the newly // returned state to the roomserver to update our local view. @@ -650,7 +666,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder // FederatedAuthProvider is an auth chain provider which fetches events from the server provided func federatedAuthProvider( - ctx context.Context, federation *gomatrixserverlib.FederationClient, + ctx context.Context, federation api.FederationClient, keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, ) gomatrixserverlib.AuthChainProvider { // A list of events that we have retried, if they were not included in diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 747940403..b6edec5da 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -21,6 +21,7 @@ import ( "sync" "time" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared" @@ -49,21 +50,21 @@ type destinationQueue struct { db storage.Database process *process.ProcessContext signing *SigningInfo - rsAPI api.RoomserverInternalAPI - client *gomatrixserverlib.FederationClient // federation client - origin gomatrixserverlib.ServerName // origin of requests - destination gomatrixserverlib.ServerName // destination of requests - running atomic.Bool // is the queue worker running? - backingOff atomic.Bool // true if we're backing off - overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more - statistics *statistics.ServerStatistics // statistics about this remote server - transactionIDMutex sync.Mutex // protects transactionID - transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful - notify chan struct{} // interrupts idle wait pending PDUs/EDUs - pendingPDUs []*queuedPDU // PDUs waiting to be sent - pendingEDUs []*queuedEDU // EDUs waiting to be sent - pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs - interruptBackoff chan bool // interrupts backoff + rsAPI api.FederationRoomserverAPI + client fedapi.FederationClient // federation client + origin gomatrixserverlib.ServerName // origin of requests + destination gomatrixserverlib.ServerName // destination of requests + running atomic.Bool // is the queue worker running? + backingOff atomic.Bool // true if we're backing off + overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more + statistics *statistics.ServerStatistics // statistics about this remote server + transactionIDMutex sync.Mutex // protects transactionID + transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful + notify chan struct{} // interrupts idle wait pending PDUs/EDUs + pendingPDUs []*queuedPDU // PDUs waiting to be sent + pendingEDUs []*queuedEDU // EDUs waiting to be sent + pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs + interruptBackoff chan bool // interrupts backoff } // Send event adds the event to the pending queue for the destination. diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index d152886f5..4c25c4ce6 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -26,6 +26,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared" @@ -39,9 +40,9 @@ type OutgoingQueues struct { db storage.Database process *process.ProcessContext disabled bool - rsAPI api.RoomserverInternalAPI + rsAPI api.FederationRoomserverAPI origin gomatrixserverlib.ServerName - client *gomatrixserverlib.FederationClient + client fedapi.FederationClient statistics *statistics.Statistics signing *SigningInfo queuesMutex sync.Mutex // protects the below @@ -85,8 +86,8 @@ func NewOutgoingQueues( process *process.ProcessContext, disabled bool, origin gomatrixserverlib.ServerName, - client *gomatrixserverlib.FederationClient, - rsAPI api.RoomserverInternalAPI, + client fedapi.FederationClient, + rsAPI api.FederationRoomserverAPI, statistics *statistics.Statistics, signing *SigningInfo, ) *OutgoingQueues { diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 707b7b019..316c61a14 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -30,7 +30,7 @@ import ( // RoomAliasToID converts the queried alias into a room ID and returns it func RoomAliasToID( httpReq *http.Request, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, cfg *config.FederationAPI, rsAPI roomserverAPI.FederationRoomserverAPI, senderAPI federationAPI.FederationInternalAPI, diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 9f95ed07e..e25f9866e 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -54,7 +54,7 @@ func Setup( rsAPI roomserverAPI.FederationRoomserverAPI, fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, keyAPI keyserverAPI.FederationKeyAPI, mscCfg *config.MSCs, diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 55a113675..c25dabce9 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -85,7 +85,7 @@ func Send( rsAPI api.FederationRoomserverAPI, keyAPI keyapi.FederationKeyAPI, keys gomatrixserverlib.JSONVerifier, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, mu *internal.MutexByRoom, servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 011d4e342..a111580c7 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -8,8 +8,8 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 16f245cee..ccde9168e 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -57,7 +58,7 @@ var ( func CreateInvitesFrom3PIDInvites( req *http.Request, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, ) util.JSONResponse { var body invites @@ -107,7 +108,7 @@ func ExchangeThirdPartyInvite( roomID string, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, - federation *gomatrixserverlib.FederationClient, + federation federationAPI.FederationClient, ) util.JSONResponse { var builder gomatrixserverlib.EventBuilder if err := json.Unmarshal(request.Content(), &builder); err != nil { @@ -165,7 +166,12 @@ func ExchangeThirdPartyInvite( // Ask the requesting server to sign the newly created event so we know it // acknowledged it - signedEvent, err := federation.SendInvite(httpReq.Context(), request.Origin(), event) + inviteReq, err := gomatrixserverlib.NewInviteV2Request(event.Headered(verRes.RoomVersion), nil) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") + return jsonerror.InternalServerError() + } + signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() @@ -205,7 +211,7 @@ func ExchangeThirdPartyInvite( func createInviteFrom3PIDInvite( ctx context.Context, rsAPI api.FederationRoomserverAPI, cfg *config.FederationAPI, - inv invite, federation *gomatrixserverlib.FederationClient, + inv invite, federation federationAPI.FederationClient, userAPI userapi.FederationUserAPI, ) (*gomatrixserverlib.Event, error) { verReq := api.QueryRoomVersionForRoomRequest{RoomID: inv.RoomID} @@ -335,7 +341,7 @@ func buildMembershipEvent( // them responded with an error. func sendToRemoteServer( ctx context.Context, inv invite, - federation *gomatrixserverlib.FederationClient, _ *config.FederationAPI, + federation federationAPI.FederationClient, _ *config.FederationAPI, builder gomatrixserverlib.EventBuilder, ) (err error) { remoteServers := make([]gomatrixserverlib.ServerName, 2) diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index e3038651b..29254948b 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -25,13 +25,12 @@ import ( type Database interface { gomatrixserverlib.KeyDatabase - UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) + UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) - PurgeRoomState(ctx context.Context, roomID string) error StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationapi/storage/postgres/joined_hosts_table.go b/federationapi/storage/postgres/joined_hosts_table.go index 5c95b72a8..bb6f6bfa3 100644 --- a/federationapi/storage/postgres/joined_hosts_table.go +++ b/federationapi/storage/postgres/joined_hosts_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) const joinedHostsSchema = ` @@ -111,6 +112,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { + logrus.Debugf("FederationJoinedHosts: INSERT %v %v %v", roomID, eventID, serverName) stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) return err @@ -119,6 +121,7 @@ func (s *joinedHostsStatements) InsertJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { + logrus.Debugf("FederationJoinedHosts: DELETE WITH EVENTS %v", eventIDs) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) return err @@ -127,6 +130,7 @@ func (s *joinedHostsStatements) DeleteJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { + logrus.Debugf("FederationJoinedHosts: DELETE ALL IN ROOM %v", roomID) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) _, err := stmt.ExecContext(ctx, roomID) return err @@ -207,6 +211,7 @@ func joinedHostsFromStmt( ServerName: gomatrixserverlib.ServerName(serverName), }) } + logrus.Debugf("FederationJoinedHosts: SELECT %v => %+v", roomID, result) return result, rows.Err() } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 160c7f6fa..a00d782f1 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -63,11 +63,21 @@ func (r *Receipt) String() string { // this isn't a duplicate message. func (d *Database) UpdateRoom( ctx context.Context, - roomID, oldEventID, newEventID string, + roomID string, addHosts []types.JoinedHost, removeHosts []string, + purgeRoomFirst bool, ) (joinedHosts []types.JoinedHost, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if purgeRoomFirst { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err = d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err) + } + } + joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID) if err != nil { return err @@ -138,20 +148,6 @@ func (d *Database) StoreJSON( }, nil } -func (d *Database) PurgeRoomState( - ctx context.Context, roomID string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. - if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { - return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err) - } - return nil - }) -} - func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) diff --git a/internal/caching/cache_typing_test.go b/internal/caching/cache_typing_test.go index c03d89bc3..2cef32d3e 100644 --- a/internal/caching/cache_typing_test.go +++ b/internal/caching/cache_typing_test.go @@ -20,7 +20,7 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/internal/test" + "github.com/matrix-org/dendrite/test" ) func TestEDUCache(t *testing.T) { diff --git a/internal/test/client.go b/internal/test/client.go deleted file mode 100644 index a38540ac9..000000000 --- a/internal/test/client.go +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package test - -import ( - "crypto/tls" - "fmt" - "io" - "io/ioutil" - "net/http" - "sync" - "time" - - "github.com/matrix-org/gomatrixserverlib" -) - -// Request contains the information necessary to issue a request and test its result -type Request struct { - Req *http.Request - WantedBody string - WantedStatusCode int - LastErr *LastRequestErr -} - -// LastRequestErr is a synchronised error wrapper -// Useful for obtaining the last error from a set of requests -type LastRequestErr struct { - sync.Mutex - Err error -} - -// Set sets the error -func (r *LastRequestErr) Set(err error) { - r.Lock() - defer r.Unlock() - r.Err = err -} - -// Get gets the error -func (r *LastRequestErr) Get() error { - r.Lock() - defer r.Unlock() - return r.Err -} - -// CanonicalJSONInput canonicalises a slice of JSON strings -// Useful for test input -func CanonicalJSONInput(jsonData []string) []string { - for i := range jsonData { - jsonBytes, err := gomatrixserverlib.CanonicalJSON([]byte(jsonData[i])) - if err != nil && err != io.EOF { - panic(err) - } - jsonData[i] = string(jsonBytes) - } - return jsonData -} - -// Do issues a request and checks the status code and body of the response -func (r *Request) Do() (err error) { - client := &http.Client{ - Timeout: 5 * time.Second, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - } - res, err := client.Do(r.Req) - if err != nil { - return err - } - defer (func() { err = res.Body.Close() })() - - if res.StatusCode != r.WantedStatusCode { - return fmt.Errorf("incorrect status code. Expected: %d Got: %d", r.WantedStatusCode, res.StatusCode) - } - - if r.WantedBody != "" { - resBytes, err := ioutil.ReadAll(res.Body) - if err != nil { - return err - } - jsonBytes, err := gomatrixserverlib.CanonicalJSON(resBytes) - if err != nil { - return err - } - if string(jsonBytes) != r.WantedBody { - return fmt.Errorf("returned wrong bytes. Expected:\n%s\n\nGot:\n%s", r.WantedBody, string(jsonBytes)) - } - } - - return nil -} - -// DoUntilSuccess blocks and repeats the same request until the response returns the desired status code and body. -// It then closes the given channel and returns. -func (r *Request) DoUntilSuccess(done chan error) { - r.LastErr = &LastRequestErr{} - for { - if err := r.Do(); err != nil { - r.LastErr.Set(err) - time.Sleep(1 * time.Second) // don't tightloop - continue - } - close(done) - return - } -} - -// Run repeatedly issues a request until success, error or a timeout is reached -func (r *Request) Run(label string, timeout time.Duration, serverCmdChan chan error) { - fmt.Printf("==TESTING== %v (timeout: %v)\n", label, timeout) - done := make(chan error, 1) - - // We need to wait for the server to: - // - have connected to the database - // - have created the tables - // - be listening on the given port - go r.DoUntilSuccess(done) - - // wait for one of: - // - the test to pass (done channel is closed) - // - the server to exit with an error (error sent on serverCmdChan) - // - our test timeout to expire - // We don't need to clean up since the main() function handles that in the event we panic - select { - case <-time.After(timeout): - fmt.Printf("==TESTING== %v TIMEOUT\n", label) - if reqErr := r.LastErr.Get(); reqErr != nil { - fmt.Println("Last /sync request error:") - fmt.Println(reqErr) - } - panic(fmt.Sprintf("%v server timed out", label)) - case err := <-serverCmdChan: - if err != nil { - fmt.Println("=============================================================================================") - fmt.Printf("%v server failed to run. If failing with 'pq: password authentication failed for user' try:", label) - fmt.Println(" export PGHOST=/var/run/postgresql") - fmt.Println("=============================================================================================") - panic(err) - } - case <-done: - fmt.Printf("==TESTING== %v PASSED\n", label) - } -} diff --git a/internal/test/kafka.go b/internal/test/kafka.go deleted file mode 100644 index cbf246304..000000000 --- a/internal/test/kafka.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package test - -import ( - "io" - "os/exec" - "path/filepath" - "strings" -) - -// KafkaExecutor executes kafka scripts. -type KafkaExecutor struct { - // The location of Zookeeper. Typically this is `localhost:2181`. - ZookeeperURI string - // The directory where Kafka is installed to. Used to locate kafka scripts. - KafkaDirectory string - // The location of the Kafka logs. Typically this is `localhost:9092`. - KafkaURI string - // Where stdout and stderr should be written to. Typically this is `os.Stderr`. - OutputWriter io.Writer -} - -// CreateTopic creates a new kafka topic. This is created with a single partition. -func (e *KafkaExecutor) CreateTopic(topic string) error { - cmd := exec.Command( - filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"), - "--create", - "--zookeeper", e.ZookeeperURI, - "--replication-factor", "1", - "--partitions", "1", - "--topic", topic, - ) - cmd.Stdout = e.OutputWriter - cmd.Stderr = e.OutputWriter - return cmd.Run() -} - -// WriteToTopic writes data to a kafka topic. -func (e *KafkaExecutor) WriteToTopic(topic string, data []string) error { - cmd := exec.Command( - filepath.Join(e.KafkaDirectory, "bin", "kafka-console-producer.sh"), - "--broker-list", e.KafkaURI, - "--topic", topic, - ) - cmd.Stdout = e.OutputWriter - cmd.Stderr = e.OutputWriter - cmd.Stdin = strings.NewReader(strings.Join(data, "\n")) - return cmd.Run() -} - -// DeleteTopic deletes a given kafka topic if it exists. -func (e *KafkaExecutor) DeleteTopic(topic string) error { - cmd := exec.Command( - filepath.Join(e.KafkaDirectory, "bin", "kafka-topics.sh"), - "--delete", - "--if-exists", - "--zookeeper", e.ZookeeperURI, - "--topic", topic, - ) - cmd.Stderr = e.OutputWriter - cmd.Stdout = e.OutputWriter - return cmd.Run() -} diff --git a/internal/test/server.go b/internal/test/server.go deleted file mode 100644 index ca14ea1bf..000000000 --- a/internal/test/server.go +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package test - -import ( - "context" - "fmt" - "net" - "net/http" - "os" - "os/exec" - "path/filepath" - "strings" - "sync" - "testing" - - "github.com/matrix-org/dendrite/setup/config" -) - -// Defaulting allows assignment of string variables with a fallback default value -// Useful for use with os.Getenv() for example -func Defaulting(value, defaultValue string) string { - if value == "" { - value = defaultValue - } - return value -} - -// CreateDatabase creates a new database, dropping it first if it exists -func CreateDatabase(command string, args []string, database string) error { - cmd := exec.Command(command, args...) - cmd.Stdin = strings.NewReader( - fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", database, database), - ) - // Send stdout and stderr to our stderr so that we see error messages from - // the psql process - cmd.Stdout = os.Stderr - cmd.Stderr = os.Stderr - return cmd.Run() -} - -// CreateBackgroundCommand creates an executable command -// The Cmd being executed is returned. A channel is also returned, -// which will have any termination errors sent down it, followed immediately by the channel being closed. -func CreateBackgroundCommand(command string, args []string) (*exec.Cmd, chan error) { - cmd := exec.Command(command, args...) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stderr - - if err := cmd.Start(); err != nil { - panic("failed to start server: " + err.Error()) - } - cmdChan := make(chan error, 1) - go func() { - cmdChan <- cmd.Wait() - close(cmdChan) - }() - return cmd, cmdChan -} - -// InitDatabase creates the database and config file needed for the server to run -func InitDatabase(postgresDatabase, postgresContainerName string, databases []string) { - if len(databases) > 0 { - var dbCmd string - var dbArgs []string - if postgresContainerName == "" { - dbCmd = "psql" - dbArgs = []string{postgresDatabase} - } else { - dbCmd = "docker" - dbArgs = []string{ - "exec", "-i", postgresContainerName, "psql", "-U", "postgres", postgresDatabase, - } - } - for _, database := range databases { - if err := CreateDatabase(dbCmd, dbArgs, database); err != nil { - panic(err) - } - } - } -} - -// StartProxy creates a reverse proxy -func StartProxy(bindAddr string, cfg *config.Dendrite) (*exec.Cmd, chan error) { - proxyArgs := []string{ - "--bind-address", bindAddr, - "--sync-api-server-url", "http://" + string(cfg.SyncAPI.InternalAPI.Connect), - "--client-api-server-url", "http://" + string(cfg.ClientAPI.InternalAPI.Connect), - "--media-api-server-url", "http://" + string(cfg.MediaAPI.InternalAPI.Connect), - "--tls-cert", "server.crt", - "--tls-key", "server.key", - } - return CreateBackgroundCommand( - filepath.Join(filepath.Dir(os.Args[0]), "client-api-proxy"), - proxyArgs, - ) -} - -// ListenAndServe will listen on a random high-numbered port and attach the given router. -// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed. -func ListenAndServe(t *testing.T, router http.Handler, useTLS bool) (apiURL string, cancel func()) { - listener, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("failed to listen: %s", err) - } - port := listener.Addr().(*net.TCPAddr).Port - srv := http.Server{} - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - srv.Handler = router - var err error - if useTLS { - certFile := filepath.Join(os.TempDir(), "dendrite.cert") - keyFile := filepath.Join(os.TempDir(), "dendrite.key") - err = NewTLSKey(keyFile, certFile) - if err != nil { - t.Logf("failed to generate tls key/cert: %s", err) - return - } - err = srv.ServeTLS(listener, certFile, keyFile) - } else { - err = srv.Serve(listener) - } - if err != nil && err != http.ErrServerClosed { - t.Logf("Listen failed: %s", err) - } - }() - - secure := "" - if useTLS { - secure = "s" - } - return fmt.Sprintf("http%s://localhost:%d", secure, port), func() { - _ = srv.Shutdown(context.Background()) - wg.Wait() - } -} diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 561c9a163..23f3e1a67 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -84,7 +84,7 @@ type DeviceListUpdater struct { db DeviceListUpdaterDatabase api DeviceListUpdaterAPI producer KeyChangeProducer - fedClient fedsenderapi.FederationClient + fedClient fedsenderapi.KeyserverFederationAPI workerChans []chan gomatrixserverlib.ServerName // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will @@ -127,7 +127,7 @@ type KeyChangeProducer interface { // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. func NewDeviceListUpdater( db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, - fedClient fedsenderapi.FederationClient, numWorkers int, + fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, ) *DeviceListUpdater { return &DeviceListUpdater{ userIDToMutex: make(map[string]*sync.Mutex), diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index be71e5750..f8d0d69c3 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -37,7 +37,7 @@ import ( type KeyInternalAPI struct { DB storage.Database ThisServer gomatrixserverlib.ServerName - FedClient fedsenderapi.FederationClient + FedClient fedsenderapi.KeyserverFederationAPI UserAPI userapi.KeyserverUserAPI Producer *producers.KeyChange Updater *DeviceListUpdater diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 47d7f57f9..3ffd3ba1e 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -37,7 +37,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, + base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, ) api.KeyInternalAPI { js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index cbb4cebca..80e7aed64 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -183,6 +183,7 @@ type FederationRoomserverAPI interface { QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error // Query whether a server is allowed to see an event QueryServerAllowedToSeeEvent(ctx context.Context, req *QueryServerAllowedToSeeEventRequest, res *QueryServerAllowedToSeeEventResponse) error + QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error // Query a given amount (or less) of events prior to a given set of events. diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go index a95c13550..7c65f9eac 100644 --- a/roomserver/internal/input/input_test.go +++ b/roomserver/internal/input/input_test.go @@ -12,7 +12,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" ) @@ -22,7 +22,7 @@ var jc *nats.Conn func TestMain(m *testing.M) { var b *base.BaseDendrite - b, js, jc = test.Base(nil) + b, js, jc = testrig.Base(nil) code := m.Run() b.ShutdownDendrite() b.WaitForComponentsToFinish() diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index ba5bb9f55..03627ea97 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -19,8 +19,8 @@ import ( "encoding/json" "testing" - "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/roomserver/storage/tables/events_table_test.go b/roomserver/storage/tables/events_table_test.go index d5d699c4c..6f72a59b5 100644 --- a/roomserver/storage/tables/events_table_test.go +++ b/roomserver/storage/tables/events_table_test.go @@ -39,7 +39,7 @@ func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, fun } func Test_EventsTable(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { diff --git a/roomserver/storage/tables/previous_events_table_test.go b/roomserver/storage/tables/previous_events_table_test.go index 96d7bfed0..63d540696 100644 --- a/roomserver/storage/tables/previous_events_table_test.go +++ b/roomserver/storage/tables/previous_events_table_test.go @@ -38,7 +38,7 @@ func mustCreatePreviousEventsTable(t *testing.T, dbType test.DBType) (tab tables func TestPreviousEventsTable(t *testing.T) { ctx := context.Background() - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { tab, close := mustCreatePreviousEventsTable(t, dbType) diff --git a/roomserver/storage/tables/published_table_test.go b/roomserver/storage/tables/published_table_test.go index 87662ed4c..fff6dc186 100644 --- a/roomserver/storage/tables/published_table_test.go +++ b/roomserver/storage/tables/published_table_test.go @@ -38,7 +38,7 @@ func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Publ func TestPublishedTable(t *testing.T) { ctx := context.Background() - alice := test.NewUser() + alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { tab, close := mustCreatePublishedTable(t, dbType) diff --git a/roomserver/storage/tables/room_aliases_table_test.go b/roomserver/storage/tables/room_aliases_table_test.go index 8fb57d5a4..624d92ae6 100644 --- a/roomserver/storage/tables/room_aliases_table_test.go +++ b/roomserver/storage/tables/room_aliases_table_test.go @@ -36,7 +36,7 @@ func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.Ro } func TestRoomAliasesTable(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) ctx := context.Background() diff --git a/roomserver/storage/tables/rooms_table_test.go b/roomserver/storage/tables/rooms_table_test.go index 9872fb800..0a02369a1 100644 --- a/roomserver/storage/tables/rooms_table_test.go +++ b/roomserver/storage/tables/rooms_table_test.go @@ -38,7 +38,7 @@ func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, c } func TestRoomsTable(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 1150c2f3d..563c92e34 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -47,7 +47,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver func TestWriteEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - alice := test.NewUser() + alice := test.NewUser(t) r := test.NewRoom(t, alice) db, close := MustCreateDatabase(t, dbType) defer close() @@ -60,7 +60,7 @@ func TestRecentEventsPDU(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := MustCreateDatabase(t, dbType) defer close() - alice := test.NewUser() + alice := test.NewUser(t) // dummy room to make sure SQL queries are filtering on room ID MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) @@ -163,7 +163,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := MustCreateDatabase(t, dbType) defer close() - alice := test.NewUser() + alice := test.NewUser(t) r := test.NewRoom(t, alice) for i := 0; i < 10; i++ { r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)}) diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go index 8bbf879d4..69bbd04c9 100644 --- a/syncapi/storage/tables/output_room_events_test.go +++ b/syncapi/storage/tables/output_room_events_test.go @@ -45,7 +45,7 @@ func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, func TestOutputRoomEventsTable(t *testing.T) { ctx := context.Background() - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { tab, db, close := newOutputRoomEventsTable(t, dbType) diff --git a/syncapi/storage/tables/topology_test.go b/syncapi/storage/tables/topology_test.go index 2334aae2e..f4f75bdf3 100644 --- a/syncapi/storage/tables/topology_test.go +++ b/syncapi/storage/tables/topology_test.go @@ -40,7 +40,7 @@ func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.D func TestTopologyTable(t *testing.T) { ctx := context.Background() - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { tab, db, close := newTopologyTable(t, dbType) diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index d3d898394..5ecfd8772 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -15,6 +15,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -86,7 +87,7 @@ func TestSyncAPIAccessTokens(t *testing.T) { } func testSyncAccessTokens(t *testing.T, dbType test.DBType) { - user := test.NewUser() + user := test.NewUser(t) room := test.NewRoom(t, user) alice := userapi.Device{ ID: "ALICEID", @@ -96,14 +97,14 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, close := test.CreateBaseDendrite(t, dbType) + base, close := testrig.CreateBaseDendrite(t, dbType) defer close() jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) msgs := toNATSMsgs(t, base, room.Events()) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) - test.MustPublishMsgs(t, jsctx, msgs...) + testrig.MustPublishMsgs(t, jsctx, msgs...) testCases := []struct { name string @@ -173,7 +174,7 @@ func TestSyncAPICreateRoomSyncEarly(t *testing.T) { } func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { - user := test.NewUser() + user := test.NewUser(t) room := test.NewRoom(t, user) alice := userapi.Device{ ID: "ALICEID", @@ -183,7 +184,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { AccountType: userapi.AccountTypeUser, } - base, close := test.CreateBaseDendrite(t, dbType) + base, close := testrig.CreateBaseDendrite(t, dbType) defer close() jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) @@ -198,7 +199,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { sinceTokens := make([]string, len(msgs)) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) for i, msg := range msgs { - test.MustPublishMsgs(t, jsctx, msg) + testrig.MustPublishMsgs(t, jsctx, msg) time.Sleep(100 * time.Millisecond) w := httptest.NewRecorder() base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ @@ -262,7 +263,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli if ev.StateKey() != nil { addsStateIDs = append(addsStateIDs, ev.EventID()) } - result[i] = test.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{ + result[i] = testrig.NewOutputEventMsg(t, base, ev.RoomID(), api.OutputEvent{ Type: rsapi.OutputTypeNewRoomEvent, NewRoomEvent: &rsapi.OutputNewRoomEvent{ Event: ev, diff --git a/test/event.go b/test/event.go index 40cb8f0e1..73fc656bd 100644 --- a/test/event.go +++ b/test/event.go @@ -52,6 +52,24 @@ func WithUnsigned(unsigned interface{}) eventModifier { } } +func WithKeyID(keyID gomatrixserverlib.KeyID) eventModifier { + return func(e *eventMods) { + e.keyID = keyID + } +} + +func WithPrivateKey(pkey ed25519.PrivateKey) eventModifier { + return func(e *eventMods) { + e.privKey = pkey + } +} + +func WithOrigin(origin gomatrixserverlib.ServerName) eventModifier { + return func(e *eventMods) { + e.origin = origin + } +} + // Reverse a list of events func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) diff --git a/test/http.go b/test/http.go index a458a3385..37b3648f8 100644 --- a/test/http.go +++ b/test/http.go @@ -2,10 +2,15 @@ package test import ( "bytes" + "context" "encoding/json" + "fmt" "io" + "net" "net/http" "net/url" + "path/filepath" + "sync" "testing" ) @@ -43,3 +48,45 @@ func NewRequest(t *testing.T, method, path string, opts ...HTTPRequestOpt) *http } return req } + +// ListenAndServe will listen on a random high-numbered port and attach the given router. +// Returns the base URL to send requests to. Call `cancel` to shutdown the server, which will block until it has closed. +func ListenAndServe(t *testing.T, router http.Handler, withTLS bool) (apiURL string, cancel func()) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("failed to listen: %s", err) + } + port := listener.Addr().(*net.TCPAddr).Port + srv := http.Server{} + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + srv.Handler = router + var err error + if withTLS { + certFile := filepath.Join(t.TempDir(), "dendrite.cert") + keyFile := filepath.Join(t.TempDir(), "dendrite.key") + err = NewTLSKey(keyFile, certFile) + if err != nil { + t.Errorf("failed to make TLS key: %s", err) + return + } + err = srv.ServeTLS(listener, certFile, keyFile) + } else { + err = srv.Serve(listener) + } + if err != nil && err != http.ErrServerClosed { + t.Logf("Listen failed: %s", err) + } + }() + s := "" + if withTLS { + s = "s" + } + return fmt.Sprintf("http%s://localhost:%d", s, port), func() { + _ = srv.Shutdown(context.Background()) + wg.Wait() + } +} diff --git a/internal/test/keyring.go b/test/keyring.go similarity index 100% rename from internal/test/keyring.go rename to test/keyring.go diff --git a/internal/test/config.go b/test/keys.go similarity index 61% rename from internal/test/config.go rename to test/keys.go index d8e0c4531..75e3800e0 100644 --- a/internal/test/config.go +++ b/test/keys.go @@ -25,103 +25,19 @@ import ( "io/ioutil" "math/big" "os" - "path/filepath" "strings" "time" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" - "gopkg.in/yaml.v2" ) const ( - // ConfigFile is the name of the config file for a server. - ConfigFile = "dendrite.yaml" // ServerKeyFile is the name of the file holding the matrix server private key. ServerKeyFile = "server_key.pem" // TLSCertFile is the name of the file holding the TLS certificate used for federation. TLSCertFile = "tls_cert.pem" // TLSKeyFile is the name of the file holding the TLS key used for federation. TLSKeyFile = "tls_key.pem" - // MediaDir is the name of the directory used to store media. - MediaDir = "media" ) -// MakeConfig makes a config suitable for running integration tests. -// Generates new matrix and TLS keys for the server. -func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*config.Dendrite, int, error) { - var cfg config.Dendrite - cfg.Defaults(true) - - port := startPort - assignAddress := func() config.HTTPAddress { - result := config.HTTPAddress(fmt.Sprintf("http://%s:%d", host, port)) - port++ - return result - } - - serverKeyPath := filepath.Join(configDir, ServerKeyFile) - tlsCertPath := filepath.Join(configDir, TLSKeyFile) - tlsKeyPath := filepath.Join(configDir, TLSCertFile) - mediaBasePath := filepath.Join(configDir, MediaDir) - - if err := NewMatrixKey(serverKeyPath); err != nil { - return nil, 0, err - } - - if err := NewTLSKey(tlsKeyPath, tlsCertPath); err != nil { - return nil, 0, err - } - - cfg.Version = config.Version - - cfg.Global.ServerName = gomatrixserverlib.ServerName(assignAddress()) - cfg.Global.PrivateKeyPath = config.Path(serverKeyPath) - - cfg.MediaAPI.BasePath = config.Path(mediaBasePath) - - cfg.Global.JetStream.Addresses = []string{kafkaURI} - - // TODO: Use different databases for the different schemas. - // Using the same database for every schema currently works because - // the table names are globally unique. But we might not want to - // rely on that in the future. - cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(database) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(database) - cfg.KeyServer.Database.ConnectionString = config.DataSource(database) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(database) - cfg.RoomServer.Database.ConnectionString = config.DataSource(database) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(database) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database) - - cfg.AppServiceAPI.InternalAPI.Listen = assignAddress() - cfg.FederationAPI.InternalAPI.Listen = assignAddress() - cfg.KeyServer.InternalAPI.Listen = assignAddress() - cfg.MediaAPI.InternalAPI.Listen = assignAddress() - cfg.RoomServer.InternalAPI.Listen = assignAddress() - cfg.SyncAPI.InternalAPI.Listen = assignAddress() - cfg.UserAPI.InternalAPI.Listen = assignAddress() - - cfg.AppServiceAPI.InternalAPI.Connect = cfg.AppServiceAPI.InternalAPI.Listen - cfg.FederationAPI.InternalAPI.Connect = cfg.FederationAPI.InternalAPI.Listen - cfg.KeyServer.InternalAPI.Connect = cfg.KeyServer.InternalAPI.Listen - cfg.MediaAPI.InternalAPI.Connect = cfg.MediaAPI.InternalAPI.Listen - cfg.RoomServer.InternalAPI.Connect = cfg.RoomServer.InternalAPI.Listen - cfg.SyncAPI.InternalAPI.Connect = cfg.SyncAPI.InternalAPI.Listen - cfg.UserAPI.InternalAPI.Connect = cfg.UserAPI.InternalAPI.Listen - - return &cfg, port, nil -} - -// WriteConfig writes the config file to the directory. -func WriteConfig(cfg *config.Dendrite, configDir string) error { - data, err := yaml.Marshal(cfg) - if err != nil { - return err - } - return ioutil.WriteFile(filepath.Join(configDir, ConfigFile), data, 0666) -} - // NewMatrixKey generates a new ed25519 matrix server key and writes it to a file. func NewMatrixKey(matrixKeyPath string) (err error) { var data [35]byte diff --git a/test/room.go b/test/room.go index 619cb5c9a..6ae403b3f 100644 --- a/test/room.go +++ b/test/room.go @@ -15,7 +15,6 @@ package test import ( - "crypto/ed25519" "encoding/json" "fmt" "sync/atomic" @@ -35,12 +34,6 @@ var ( PresetTrustedPrivateChat Preset = 3 roomIDCounter = int64(0) - - testKeyID = gomatrixserverlib.KeyID("ed25519:test") - testPrivateKey = ed25519.NewKeyFromSeed([]byte{ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, - }) ) type Room struct { @@ -49,22 +42,25 @@ type Room struct { preset Preset creator *User - authEvents gomatrixserverlib.AuthEvents - events []*gomatrixserverlib.HeaderedEvent + authEvents gomatrixserverlib.AuthEvents + currentState map[string]*gomatrixserverlib.HeaderedEvent + events []*gomatrixserverlib.HeaderedEvent } // Create a new test room. Automatically creates the initial create events. func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { t.Helper() counter := atomic.AddInt64(&roomIDCounter, 1) - - // set defaults then let roomModifiers override + if creator.srvName == "" { + t.Fatalf("NewRoom: creator doesn't belong to a server: %+v", *creator) + } r := &Room{ - ID: fmt.Sprintf("!%d:localhost", counter), - creator: creator, - authEvents: gomatrixserverlib.NewAuthEvents(nil), - preset: PresetPublicChat, - Version: gomatrixserverlib.RoomVersionV9, + ID: fmt.Sprintf("!%d:%s", counter, creator.srvName), + creator: creator, + authEvents: gomatrixserverlib.NewAuthEvents(nil), + preset: PresetPublicChat, + Version: gomatrixserverlib.RoomVersionV9, + currentState: make(map[string]*gomatrixserverlib.HeaderedEvent), } for _, m := range modifiers { m(t, r) @@ -73,6 +69,24 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { return r } +func (r *Room) MustGetAuthEventRefsForEvent(t *testing.T, needed gomatrixserverlib.StateNeeded) []gomatrixserverlib.EventReference { + t.Helper() + a, err := needed.AuthEventReferences(&r.authEvents) + if err != nil { + t.Fatalf("MustGetAuthEvents: %v", err) + } + return a +} + +func (r *Room) ForwardExtremities() []string { + if len(r.events) == 0 { + return nil + } + return []string{ + r.events[len(r.events)-1].EventID(), + } +} + func (r *Room) insertCreateEvents(t *testing.T) { t.Helper() var joinRule gomatrixserverlib.JoinRuleContent @@ -88,6 +102,7 @@ func (r *Room) insertCreateEvents(t *testing.T) { joinRule.JoinRule = "public" hisVis.HistoryVisibility = "shared" } + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ "creator": r.creator.ID, "room_version": r.Version, @@ -112,16 +127,16 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten } if mod.privKey == nil { - mod.privKey = testPrivateKey + mod.privKey = creator.privKey } if mod.keyID == "" { - mod.keyID = testKeyID + mod.keyID = creator.keyID } if mod.originServerTS.IsZero() { mod.originServerTS = time.Now() } if mod.origin == "" { - mod.origin = gomatrixserverlib.ServerName("localhost") + mod.origin = creator.srvName } var unsigned gomatrixserverlib.RawJSON @@ -174,13 +189,14 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten // Add a new event to this room DAG. Not thread-safe. func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) { t.Helper() - // Add the event to the list of auth events + // Add the event to the list of auth/state events r.events = append(r.events, he) if he.StateKey() != nil { err := r.authEvents.AddEvent(he.Unwrap()) if err != nil { t.Fatalf("InsertEvent: failed to add event to auth events: %s", err) } + r.currentState[he.Type()+" "+*he.StateKey()] = he } } @@ -188,6 +204,16 @@ func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent { return r.events } +func (r *Room) CurrentState() []*gomatrixserverlib.HeaderedEvent { + events := make([]*gomatrixserverlib.HeaderedEvent, len(r.currentState)) + i := 0 + for _, e := range r.currentState { + events[i] = e + i++ + } + return events +} + func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent { t.Helper() he := r.CreateEvent(t, creator, eventType, content, mods...) diff --git a/internal/test/slice.go b/test/slice.go similarity index 100% rename from internal/test/slice.go rename to test/slice.go diff --git a/test/base.go b/test/testrig/base.go similarity index 92% rename from test/base.go rename to test/testrig/base.go index 664442c03..facb49f3e 100644 --- a/test/base.go +++ b/test/testrig/base.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package test +package testrig import ( "errors" @@ -24,22 +24,23 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" "github.com/nats-io/nats.go" ) -func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func()) { +func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) { var cfg config.Dendrite cfg.Defaults(false) cfg.Global.JetStream.InMemory = true switch dbType { - case DBTypePostgres: + case test.DBTypePostgres: cfg.Global.Defaults(true) // autogen a signing key cfg.MediaAPI.Defaults(true) // autogen a media path // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) - connStr, close := PrepareDBConnectionString(t, dbType) + connStr, close := test.PrepareDBConnectionString(t, dbType) cfg.Global.DatabaseOptions = config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), MaxOpenConnections: 10, @@ -47,7 +48,7 @@ func CreateBaseDendrite(t *testing.T, dbType DBType) (*base.BaseDendrite, func() ConnMaxLifetimeSeconds: 60, } return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close - case DBTypeSQLite: + case test.DBTypeSQLite: cfg.Defaults(true) // sets a sqlite db per component // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( diff --git a/test/jetstream.go b/test/testrig/jetstream.go similarity index 98% rename from test/jetstream.go rename to test/testrig/jetstream.go index 488c22beb..74cf95062 100644 --- a/test/jetstream.go +++ b/test/testrig/jetstream.go @@ -1,4 +1,4 @@ -package test +package testrig import ( "encoding/json" diff --git a/test/user.go b/test/user.go index 41a66e1c4..0020098a5 100644 --- a/test/user.go +++ b/test/user.go @@ -15,22 +15,64 @@ package test import ( + "crypto/ed25519" "fmt" "sync/atomic" + "testing" + + "github.com/matrix-org/gomatrixserverlib" ) var ( userIDCounter = int64(0) + + serverName = gomatrixserverlib.ServerName("test") + keyID = gomatrixserverlib.KeyID("ed25519:test") + privateKey = ed25519.NewKeyFromSeed([]byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + }) + + // private keys that tests can use + PrivateKeyA = ed25519.NewKeyFromSeed([]byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 77, + }) + PrivateKeyB = ed25519.NewKeyFromSeed([]byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 66, + }) ) type User struct { ID string + // key ID and private key of the server who has this user, if known. + keyID gomatrixserverlib.KeyID + privKey ed25519.PrivateKey + srvName gomatrixserverlib.ServerName } -func NewUser() *User { - counter := atomic.AddInt64(&userIDCounter, 1) - u := &User{ - ID: fmt.Sprintf("@%d:localhost", counter), +type UserOpt func(*User) + +func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, privKey ed25519.PrivateKey) UserOpt { + return func(u *User) { + u.keyID = keyID + u.privKey = privKey + u.srvName = srvName } - return u +} + +func NewUser(t *testing.T, opts ...UserOpt) *User { + counter := atomic.AddInt64(&userIDCounter, 1) + var u User + for _, opt := range opts { + opt(&u) + } + if u.keyID == "" || u.srvName == "" || u.privKey == nil { + t.Logf("NewUser: missing signing server credentials; using default.") + WithSigningServer(serverName, keyID, privateKey)(&u) + } + u.ID = fmt.Sprintf("@%d:%s", counter, u.srvName) + t.Logf("NewUser: created user %s", u.ID) + return &u } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 5683fe067..5bee880d3 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -43,7 +43,7 @@ func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - alice := test.NewUser() + alice := test.NewUser(t) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -74,7 +74,7 @@ func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -128,7 +128,7 @@ func Test_Accounts(t *testing.T) { } func Test_Devices(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) deviceID := util.RandomString(8) @@ -212,7 +212,7 @@ func Test_Devices(t *testing.T) { } func Test_KeyBackup(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -291,7 +291,7 @@ func Test_KeyBackup(t *testing.T) { } func Test_LoginToken(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() @@ -321,7 +321,7 @@ func Test_LoginToken(t *testing.T) { } func Test_OpenID(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -341,7 +341,7 @@ func Test_OpenID(t *testing.T) { } func Test_Profile(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -379,7 +379,7 @@ func Test_Profile(t *testing.T) { } func Test_Pusher(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -430,7 +430,7 @@ func Test_Pusher(t *testing.T) { } func Test_ThreePID(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -467,7 +467,7 @@ func Test_ThreePID(t *testing.T) { } func Test_Notification(t *testing.T) { - alice := test.NewUser() + alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index e614765a2..40e37c5d6 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -24,7 +24,6 @@ import ( "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" - internalTest "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/inthttp" @@ -135,7 +134,7 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() userapi.AddInternalRoutes(router, userAPI) - apiURL, cancel := internalTest.ListenAndServe(t, router, false) + apiURL, cancel := test.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) if err != nil { From ac92e047728efc3d50d6dddbe392ca44afd63a38 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 17 May 2022 13:31:48 +0100 Subject: [PATCH 4/7] Remove debug logging --- federationapi/consumers/keychange.go | 5 +---- federationapi/consumers/roomserver.go | 2 -- federationapi/storage/postgres/joined_hosts_table.go | 5 ----- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 95c9a7fdd..6d3cf0e46 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -120,7 +120,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { logger.WithError(err).Error("failed to calculate joined rooms for user") return true } - logrus.Infof("DEBUG: %v joined rooms for user %v", queryRes.RoomIDs, m.UserID) + // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { @@ -129,9 +129,6 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } if len(destinations) == 0 { - logger.WithField("num_rooms", len(queryRes.RoomIDs)).Debug("user is in no federated rooms") - destinations, err = t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, false) - logrus.Infof("GetJoinedHostsForRooms exclude self=false -> %v %v", destinations, err) return true } // Pack the EDU and marshal it diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 7a0816ff2..e50ec66ad 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/queue" @@ -166,7 +165,6 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew // expressed as a delta against the current state. // TODO(#290): handle EventIDMismatchError and recover the current state by // talking to the roomserver - logrus.Infof("room %s adds joined hosts: %v removes %v", ore.Event.RoomID(), addsJoinedHosts, ore.RemovesStateEventIDs) oldJoinedHosts, err := s.db.UpdateRoom( s.ctx, ore.Event.RoomID(), diff --git a/federationapi/storage/postgres/joined_hosts_table.go b/federationapi/storage/postgres/joined_hosts_table.go index bb6f6bfa3..5c95b72a8 100644 --- a/federationapi/storage/postgres/joined_hosts_table.go +++ b/federationapi/storage/postgres/joined_hosts_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" ) const joinedHostsSchema = ` @@ -112,7 +111,6 @@ func (s *joinedHostsStatements) InsertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - logrus.Debugf("FederationJoinedHosts: INSERT %v %v %v", roomID, eventID, serverName) stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) return err @@ -121,7 +119,6 @@ func (s *joinedHostsStatements) InsertJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { - logrus.Debugf("FederationJoinedHosts: DELETE WITH EVENTS %v", eventIDs) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) return err @@ -130,7 +127,6 @@ func (s *joinedHostsStatements) DeleteJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHostsForRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - logrus.Debugf("FederationJoinedHosts: DELETE ALL IN ROOM %v", roomID) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt) _, err := stmt.ExecContext(ctx, roomID) return err @@ -211,7 +207,6 @@ func joinedHostsFromStmt( ServerName: gomatrixserverlib.ServerName(serverName), }) } - logrus.Debugf("FederationJoinedHosts: SELECT %v => %+v", roomID, result) return result, rows.Err() } From b3162755a9053bbb30a83f00928ff0a0852ad32e Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 17 May 2022 15:53:08 +0100 Subject: [PATCH 5/7] bugfix: fix race condition when updating presence via /sync (#2470) * bugfix: fix race condition when updating presence via /sync Previously when presence is updated via /sync, we would send the presence update asyncly via NATS. This created a race condition: - If the presence update is processed quickly, the /sync which triggered the presence update would see an online presence. - If the presence update was processed slowly, the /sync which triggered the presence update would see an offline presence. This is the root cause behind the flakey sytest: 'User sees their own presence in a sync'. The fix is to ensure we update the database/advance the stream position synchronously for local users. * Bugfix for test --- syncapi/consumers/presence.go | 25 +++++++++------ syncapi/sync/requestpool.go | 15 ++++++++- syncapi/sync/requestpool_test.go | 8 +++++ syncapi/syncapi.go | 20 ++++++------ syncapi/syncapi_test.go | 55 ++++++++++++++++++++++++++++++++ 5 files changed, 103 insertions(+), 20 deletions(-) diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 388c08ff4..bfd72d604 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -138,9 +138,12 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { presence := msg.Header.Get("presence") timestamp := msg.Header.Get("last_active_ts") fromSync, _ := strconv.ParseBool(msg.Header.Get("from_sync")) - logrus.Debugf("syncAPI received presence event: %+v", msg.Header) + if fromSync { // do not process local presence changes; we already did this synchronously. + return true + } + ts, err := strconv.Atoi(timestamp) if err != nil { return true @@ -151,15 +154,19 @@ func (s *PresenceConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { newMsg := msg.Header.Get("status_msg") statusMsg = &newMsg } - // OK is already checked, so no need to do it again + // already checked, so no need to check error p, _ := types.PresenceFromString(presence) - pos, err := s.db.UpdatePresence(ctx, userID, p, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync) - if err != nil { - return true - } - - s.stream.Advance(pos) - s.notifier.OnNewPresence(types.StreamingToken{PresencePosition: pos}, userID) + s.EmitPresence(ctx, userID, p, statusMsg, ts, fromSync) return true } + +func (s *PresenceConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) { + pos, err := s.db.UpdatePresence(ctx, userID, presence, statusMsg, gomatrixserverlib.Timestamp(ts), fromSync) + if err != nil { + logrus.WithError(err).WithField("user", userID).WithField("presence", presence).Warn("failed to updated presence for user") + return + } + s.stream.Advance(pos) + s.notifier.OnNewPresence(types.StreamingToken{PresencePosition: pos}, userID) +} diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index fdf46cdde..ad151f70b 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -53,19 +53,24 @@ type RequestPool struct { streams *streams.Streams Notifier *notifier.Notifier producer PresencePublisher + consumer PresenceConsumer } type PresencePublisher interface { SendPresence(userID string, presence types.Presence, statusMsg *string) error } +type PresenceConsumer interface { + EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) +} + // NewRequestPool makes a new RequestPool func NewRequestPool( db storage.Database, cfg *config.SyncAPI, userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI, streams *streams.Streams, notifier *notifier.Notifier, - producer PresencePublisher, enableMetrics bool, + producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool, ) *RequestPool { if enableMetrics { prometheus.MustRegister( @@ -83,6 +88,7 @@ func NewRequestPool( streams: streams, Notifier: notifier, producer: producer, + consumer: consumer, } go rp.cleanLastSeen() go rp.cleanPresence(db, time.Minute*5) @@ -160,6 +166,13 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user logrus.WithError(err).Error("Unable to publish presence message from sync") return } + + // now synchronously update our view of the world. It's critical we do this before calculating + // the /sync response else we may not return presence: online immediately. + rp.consumer.EmitPresence( + context.Background(), userID, presenceID, newPresence.ClientFields.StatusMsg, + int(gomatrixserverlib.AsTimestamp(time.Now())), true, + ) } func (rp *RequestPool) updateLastSeen(req *http.Request, device *userapi.Device) { diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index 5e52bc7c9..0c7209521 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -38,6 +38,12 @@ func (d dummyDB) MaxStreamPositionForPresence(ctx context.Context) (types.Stream return 0, nil } +type dummyConsumer struct{} + +func (d dummyConsumer) EmitPresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, ts int, fromSync bool) { + +} + func TestRequestPool_updatePresence(t *testing.T) { type args struct { presence string @@ -45,6 +51,7 @@ func TestRequestPool_updatePresence(t *testing.T) { sleep time.Duration } publisher := &dummyPublisher{} + consumer := &dummyConsumer{} syncMap := sync.Map{} tests := []struct { @@ -101,6 +108,7 @@ func TestRequestPool_updatePresence(t *testing.T) { rp := &RequestPool{ presence: &syncMap, producer: publisher, + consumer: consumer, cfg: &config.SyncAPI{ Matrix: &config.Global{ JetStream: config.JetStream{ diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index d8bacb2da..92db18d56 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -64,8 +64,17 @@ func AddPublicRoutes( Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputPresenceEvent), JetStream: js, } + presenceConsumer := consumers.NewPresenceConsumer( + base.ProcessContext, cfg, js, natsClient, syncDB, + notifier, streams.PresenceStreamProvider, + userAPI, + ) - requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, base.EnableMetrics) + requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics) + + if err = presenceConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start presence consumer") + } userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{ JetStream: js, @@ -131,15 +140,6 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start receipts consumer") } - presenceConsumer := consumers.NewPresenceConsumer( - base.ProcessContext, cfg, js, natsClient, syncDB, - notifier, streams.PresenceStreamProvider, - userAPI, - ) - if err = presenceConsumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start presence consumer") - } - routing.Setup( base.PublicClientAPIMux, requestPool, syncDB, userAPI, rsAPI, cfg, base.Caches, diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 5ecfd8772..3ce7c64b7 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -19,6 +19,7 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/tidwall/gjson" ) type syncRoomserverAPI struct { @@ -256,6 +257,60 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { } } +// Test that if we hit /sync we get back presence: online, regardless of whether messages get delivered +// via NATS. Regression test for a flakey test "User sees their own presence in a sync" +func TestSyncAPIUpdatePresenceImmediately(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + testSyncAPIUpdatePresenceImmediately(t, dbType) + }) +} + +func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { + user := test.NewUser(t) + alice := userapi.Device{ + ID: "ALICEID", + UserID: user.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + base, close := testrig.CreateBaseDendrite(t, dbType) + base.Cfg.Global.Presence.EnableOutbound = true + base.Cfg.Global.Presence.EnableInbound = true + defer close() + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{}) + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{ + "access_token": alice.AccessToken, + "timeout": "0", + "set_presence": "online", + }))) + if w.Code != 200 { + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + var res types.Response + if err := json.NewDecoder(w.Body).Decode(&res); err != nil { + t.Errorf("failed to decode response body: %s", err) + } + if len(res.Presence.Events) != 1 { + t.Fatalf("expected 1 presence events, got: %+v", res.Presence.Events) + } + if res.Presence.Events[0].Sender != alice.UserID { + t.Errorf("sender: got %v want %v", res.Presence.Events[0].Sender, alice.UserID) + } + if res.Presence.Events[0].Type != "m.presence" { + t.Errorf("type: got %v want %v", res.Presence.Events[0].Type, "m.presence") + } + if gjson.ParseBytes(res.Presence.Events[0].Content).Get("presence").Str != "online" { + t.Errorf("content: not online, got %v", res.Presence.Events[0].Content) + } + +} + func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg { result := make([]*nats.Msg, len(input)) for i, ev := range input { From f321a7d55ea75e6a5276cd88eddcbbc82ceeaaeb Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 18 May 2022 15:17:23 +0200 Subject: [PATCH 6/7] Really SKIP_NODB (#2472) * Really SKIP_NODB * Use fatalError in createLocalDB * Check if createdb exists * Revert change * Remove !Quiet --- test/db.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/db.go b/test/db.go index a1754cd08..c7cb919f6 100644 --- a/test/db.go +++ b/test/db.go @@ -44,8 +44,9 @@ func fatalError(t *testing.T, format string, args ...interface{}) { } func createLocalDB(t *testing.T, dbName string) { - if !Quiet { - t.Log("Note: tests require a postgres install accessible to the current user") + if _, err := exec.LookPath("createdb"); err != nil { + fatalError(t, "Note: tests require a postgres install accessible to the current user") + return } createDB := exec.Command("createdb", dbName) if !Quiet { @@ -63,6 +64,9 @@ func createRemoteDB(t *testing.T, dbName, user, connStr string) { if err != nil { fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err) } + if err = db.Ping(); err != nil { + fatalError(t, "failed to open postgres conn with connstr=%s : %s", connStr, err) + } _, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName)) if err != nil { pqErr, ok := err.(*pq.Error) From 21dd5a7176e52d018b91854db273424e4430af7b Mon Sep 17 00:00:00 2001 From: kegsay Date: Thu, 19 May 2022 09:00:56 +0100 Subject: [PATCH 7/7] syncapi: don't return early for no-op incremental syncs (#2473) * syncapi: don't return early for no-op incremental syncs Comments explain why, but basically it's an inefficient use of bandwidth and some sytests rely on /sync to block. * Honour timeouts * Actually return a response with timeout=0 --- syncapi/sync/requestpool.go | 246 ++++++++++++++++++++---------------- syncapi/types/types.go | 13 ++ 2 files changed, 149 insertions(+), 110 deletions(-) diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index ad151f70b..7b9526b53 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -251,125 +251,151 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. waitingSyncRequests.Inc() defer waitingSyncRequests.Dec() - currentPos := rp.Notifier.CurrentPosition() + // loop until we get some data + for { + startTime := time.Now() + currentPos := rp.Notifier.CurrentPosition() - if !rp.shouldReturnImmediately(syncReq, currentPos) { - timer := time.NewTimer(syncReq.Timeout) // case of timeout=0 is handled above - defer timer.Stop() + // if the since token matches the current positions, wait via the notifier + if !rp.shouldReturnImmediately(syncReq, currentPos) { + timer := time.NewTimer(syncReq.Timeout) // case of timeout=0 is handled above + defer timer.Stop() - userStreamListener := rp.Notifier.GetListener(*syncReq) - defer userStreamListener.Close() + userStreamListener := rp.Notifier.GetListener(*syncReq) + defer userStreamListener.Close() - giveup := func() util.JSONResponse { - syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached") - syncReq.Response.NextBatch = syncReq.Since - // We should always try to include OTKs in sync responses, otherwise clients might upload keys - // even if that's not required. See also: - // https://github.com/matrix-org/synapse/blob/29f06704b8871a44926f7c99e73cf4a978fb8e81/synapse/rest/client/sync.py#L276-L281 - // Only try to get OTKs if the context isn't already done. - if syncReq.Context.Err() == nil { - err = internal.DeviceOTKCounts(syncReq.Context, rp.keyAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) - if err != nil && err != context.Canceled { - syncReq.Log.WithError(err).Warn("failed to get OTK counts") + giveup := func() util.JSONResponse { + syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached") + syncReq.Response.NextBatch = syncReq.Since + // We should always try to include OTKs in sync responses, otherwise clients might upload keys + // even if that's not required. See also: + // https://github.com/matrix-org/synapse/blob/29f06704b8871a44926f7c99e73cf4a978fb8e81/synapse/rest/client/sync.py#L276-L281 + // Only try to get OTKs if the context isn't already done. + if syncReq.Context.Err() == nil { + err = internal.DeviceOTKCounts(syncReq.Context, rp.keyAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response) + if err != nil && err != context.Canceled { + syncReq.Log.WithError(err).Warn("failed to get OTK counts") + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: syncReq.Response, } } - return util.JSONResponse{ - Code: http.StatusOK, - JSON: syncReq.Response, + + select { + case <-syncReq.Context.Done(): // Caller gave up + return giveup() + + case <-timer.C: // Timeout reached + return giveup() + + case <-userStreamListener.GetNotifyChannel(syncReq.Since): + syncReq.Log.Debugln("Responding to sync after wake-up") + currentPos.ApplyUpdates(userStreamListener.GetSyncPosition()) + } + } else { + syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") + } + + if syncReq.Since.IsEmpty() { + // Complete sync + syncReq.Response.NextBatch = types.StreamingToken{ + PDUPosition: rp.streams.PDUStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), + TypingPosition: rp.streams.TypingStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), + ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), + InvitePosition: rp.streams.InviteStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), + SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), + AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), + NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), + DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), + PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync( + syncReq.Context, syncReq, + ), + } + } else { + // Incremental sync + syncReq.Response.NextBatch = types.StreamingToken{ + PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.PDUPosition, currentPos.PDUPosition, + ), + TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.TypingPosition, currentPos.TypingPosition, + ), + ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition, + ), + InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.InvitePosition, currentPos.InvitePosition, + ), + SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition, + ), + AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, + ), + NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition, + ), + DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, + ), + PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync( + syncReq.Context, syncReq, + syncReq.Since.PresencePosition, currentPos.PresencePosition, + ), + } + // it's possible for there to be no updates for this user even though since < current pos, + // e.g busy servers with a quiet user. In this scenario, we don't want to return a no-op + // response immediately, so let's try this again but pretend they bumped their since token. + // If the incremental sync was processed very quickly then we expect the next loop to block + // with a notifier, but if things are slow it's entirely possible that currentPos is no + // longer the current position so we will hit this code path again. We need to do this and + // not return a no-op response because: + // - It's an inefficient use of bandwidth. + // - Some sytests which test 'waking up' sync rely on some sync requests to block, which + // they weren't always doing, resulting in flakey tests. + if !syncReq.Response.HasUpdates() { + syncReq.Since = currentPos + // do not loop again if the ?timeout= is 0 as that means "return immediately" + if syncReq.Timeout > 0 { + syncReq.Timeout = syncReq.Timeout - time.Since(startTime) + if syncReq.Timeout < 0 { + syncReq.Timeout = 0 + } + continue + } } } - select { - case <-syncReq.Context.Done(): // Caller gave up - return giveup() - - case <-timer.C: // Timeout reached - return giveup() - - case <-userStreamListener.GetNotifyChannel(syncReq.Since): - syncReq.Log.Debugln("Responding to sync after wake-up") - currentPos.ApplyUpdates(userStreamListener.GetSyncPosition()) + return util.JSONResponse{ + Code: http.StatusOK, + JSON: syncReq.Response, } - } else { - syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") - } - - if syncReq.Since.IsEmpty() { - // Complete sync - syncReq.Response.NextBatch = types.StreamingToken{ - PDUPosition: rp.streams.PDUStreamProvider.CompleteSync( - syncReq.Context, syncReq, - ), - TypingPosition: rp.streams.TypingStreamProvider.CompleteSync( - syncReq.Context, syncReq, - ), - ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync( - syncReq.Context, syncReq, - ), - InvitePosition: rp.streams.InviteStreamProvider.CompleteSync( - syncReq.Context, syncReq, - ), - SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync( - syncReq.Context, syncReq, - ), - AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync( - syncReq.Context, syncReq, - ), - NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync( - syncReq.Context, syncReq, - ), - DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync( - syncReq.Context, syncReq, - ), - PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync( - syncReq.Context, syncReq, - ), - } - } else { - // Incremental sync - syncReq.Response.NextBatch = types.StreamingToken{ - PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.PDUPosition, currentPos.PDUPosition, - ), - TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.TypingPosition, currentPos.TypingPosition, - ), - ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition, - ), - InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.InvitePosition, currentPos.InvitePosition, - ), - SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition, - ), - AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, - ), - NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition, - ), - DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, - ), - PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync( - syncReq.Context, syncReq, - syncReq.Since.PresencePosition, currentPos.PresencePosition, - ), - } - } - - return util.JSONResponse{ - Code: http.StatusOK, - JSON: syncReq.Response, } } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index ba6b4f8cd..159fa08b6 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -350,6 +350,19 @@ type Response struct { DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` } +func (r *Response) HasUpdates() bool { + // purposefully exclude DeviceListsOTKCount as we always include them + return (len(r.AccountData.Events) > 0 || + len(r.Presence.Events) > 0 || + len(r.Rooms.Invite) > 0 || + len(r.Rooms.Join) > 0 || + len(r.Rooms.Leave) > 0 || + len(r.Rooms.Peek) > 0 || + len(r.ToDevice.Events) > 0 || + len(r.DeviceLists.Changed) > 0 || + len(r.DeviceLists.Left) > 0) +} + // NewResponse creates an empty response with initialised maps. func NewResponse() *Response { res := Response{}