diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 4f337a866..8221bff96 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -111,7 +111,7 @@ jobs: key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go${{ matrix.go }}-test- - - run: go test ./... + - run: go test -p 1 ./... env: POSTGRES_HOST: localhost POSTGRES_USER: postgres diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 841f67261..cf3fd5532 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -104,7 +104,7 @@ type Database interface { // DeletePeek deletes all peeks for a given room by a given user // Returns an error if there was a problem communicating with the database. DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) - // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. + // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) // EventPositionInTopology returns the depth and stream position of the given event. EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 14af6a949..a30e220ba 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -427,7 +427,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, ) ([]types.StreamEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) @@ -435,7 +435,25 @@ func (s *outputRoomEventsStatements) SelectEvents( return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") - return rowsToStreamEvents(rows) + streamEvents, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if preserveOrder { + eventMap := make(map[string]types.StreamEvent) + for _, ev := range streamEvents { + eventMap[ev.EventID()] = ev + } + var returnEvents []types.StreamEvent + for _, eventID := range eventIDs { + ev, ok := eventMap[eventID] + if ok { + returnEvents = append(returnEvents, ev) + } + } + return returnEvents, nil + } + return streamEvents, nil } func (s *outputRoomEventsStatements) DeleteEventsForRoom( diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 626386ba0..90b3b0083 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -148,9 +148,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( // is requested or not. var stmt *sql.Stmt if chronologicalOrder { - stmt = s.selectEventIDsInRangeASCStmt + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt) } else { - stmt = s.selectEventIDsInRangeDESCStmt + stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt) } // Query the event IDs. diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 1c45d5d9a..14db5795c 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs) + streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false) if err != nil { return nil, err } @@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e // Check if we have all of the event's previous events. If an event is // missing, add it to the room's backward extremities. - prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs()) + prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false) if err != nil { return err } @@ -457,7 +457,7 @@ func (d *Database) GetEventsInTopologicalRange( } // Retrieve the events' contents using their IDs. - events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs) + events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true) return } @@ -619,7 +619,7 @@ func (d *Database) fetchMissingStateEvents( ) ([]types.StreamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. - events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs) + events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 24c442240..5b2287e6d 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -51,13 +51,13 @@ const selectMaxAccountDataIDSQL = "" + type accountDataStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt } -func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { +func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ db: db, streamIDStatements: streamID, diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 473aa49b0..464f32e04 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateForRoomStmt *sql.Stmt @@ -100,7 +100,7 @@ type currentRoomStateStatements struct { selectStateEventStmt *sql.Stmt } -func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { +func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ db: db, streamIDStatements: streamID, diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 0a6823cc0..58ab8461e 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" + type inviteEventsStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { +func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ db: db, streamIDStatements: streamID, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index acd959696..9da9d776e 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -58,7 +58,7 @@ const insertEventSQL = "" + "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" const selectEventsSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)" const selectRecentEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + @@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" + type outputRoomEventsStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt updateEventJSONStmt *sql.Stmt deleteEventsForRoomStmt *sql.Stmt @@ -122,7 +121,7 @@ type outputRoomEventsStatements struct { selectContextAfterEventStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { +func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) { s := &outputRoomEventsStatements{ db: db, streamIDStatements: streamID, @@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even } return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, - {&s.selectEventsStmt, selectEventsSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL}, {&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL}, @@ -421,21 +419,43 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. func (s *outputRoomEventsStatements) SelectEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool, ) ([]types.StreamEvent, error) { - var returnEvents []types.StreamEvent - stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) - for _, eventID := range eventIDs { - rows, err := stmt.QueryContext(ctx, eventID) - if err != nil { - return nil, err - } - if streamEvents, err := rowsToStreamEvents(rows); err == nil { - returnEvents = append(returnEvents, streamEvents...) - } - internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") + iEventIDs := make([]interface{}, len(eventIDs)) + for i := range eventIDs { + iEventIDs[i] = eventIDs[i] } - return returnEvents, nil + selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...) + } else { + rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...) + } + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") + streamEvents, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if preserveOrder { + var returnEvents []types.StreamEvent + eventMap := make(map[string]types.StreamEvent) + for _, ev := range streamEvents { + eventMap[ev.EventID()] = ev + } + for _, eventID := range eventIDs { + ev, ok := eventMap[eventID] + if ok { + returnEvents = append(returnEvents, ev) + } + } + return returnEvents, nil + } + return streamEvents, nil } func (s *outputRoomEventsStatements) DeleteEventsForRoom( diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index c93c82051..5ee86448c 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" + type peekStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements insertPeekStmt *sql.Stmt deletePeekStmt *sql.Stmt deletePeeksStmt *sql.Stmt @@ -75,7 +75,7 @@ type peekStatements struct { selectMaxPeekIDStmt *sql.Stmt } -func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { +func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) { _, err := db.Exec(peeksSchema) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index e7b78a705..00b16458d 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -75,7 +75,7 @@ const selectPresenceAfter = "" + type presenceStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements upsertPresenceStmt *sql.Stmt upsertPresenceFromSyncStmt *sql.Stmt selectPresenceForUsersStmt *sql.Stmt @@ -83,7 +83,7 @@ type presenceStatements struct { selectPresenceAfterStmt *sql.Stmt } -func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) { +func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) { _, err := db.Exec(presenceSchema) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index dea057719..bd778bf3c 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" + type receiptStatements struct { db *sql.DB - streamIDStatements *streamIDStatements + streamIDStatements *StreamIDStatements upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt } -func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { +func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) { _, err := db.Exec(receiptsSchema) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index faa2c41fe..71980b806 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" + "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" + " RETURNING stream_id" -type streamIDStatements struct { +type StreamIDStatements struct { db *sql.DB increaseStreamIDStmt *sql.Stmt } -func (s *streamIDStatements) prepare(db *sql.DB) (err error) { +func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) { s.db = db _, err = db.Exec(streamIDTableSchema) if err != nil { @@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { return } -func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos) return } -func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos) return } -func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos) return } -func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) return } -func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { +func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos) return diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 9d9d35988..dfc289482 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -30,7 +30,7 @@ type SyncServerDatasource struct { shared.Database db *sql.DB writer sqlutil.Writer - streamID streamIDStatements + streamID StreamIDStatements } // NewDatabase creates a new sync server database @@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e } func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { - if err = d.streamID.prepare(d.db); err != nil { + if err = d.streamID.Prepare(d.db); err != nil { return err } accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 403b50eaa..4e1634ece 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -3,6 +3,7 @@ package storage_test import ( "context" "fmt" + "reflect" "testing" "github.com/matrix-org/dendrite/setup/config" @@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver if err != nil { t.Fatalf("WriteEvent failed: %s", err) } - fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth()) + t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth()) positions = append(positions, pos) } return @@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver func TestWriteEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - t.Parallel() alice := test.NewUser() r := test.NewRoom(t, alice) db, close := MustCreateDatabase(t, dbType) @@ -61,84 +61,84 @@ func TestRecentEventsPDU(t *testing.T) { db, close := MustCreateDatabase(t, dbType) defer close() alice := test.NewUser() - var filter gomatrixserverlib.RoomEventFilter - filter.Limit = 100 + // dummy room to make sure SQL queries are filtering on room ID + MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) + + // actual test room r := test.NewRoom(t, alice) r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) events := r.Events() positions := MustWriteEvents(t, db, events) + + // dummy room to make sure SQL queries are filtering on room ID + MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) + latest, err := db.MaxStreamPositionForPDUs(ctx) if err != nil { t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err) } testCases := []struct { - Name string - From types.StreamPosition - To types.StreamPosition - WantEvents []*gomatrixserverlib.HeaderedEvent - WantLimited bool + Name string + From types.StreamPosition + To types.StreamPosition + Limit int + ReverseOrder bool + WantEvents []*gomatrixserverlib.HeaderedEvent + WantLimited bool }{ // The purpose of this test is to make sure that incremental syncs are including up to the latest events. - // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. + // It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event. // It makes sure the response includes the final event. { - Name: "IncrementalSync penultimate", + Name: "penultimate", From: positions[len(positions)-2], // pretend we are at the penultimate event To: latest, + Limit: 100, WantEvents: events[len(events)-1:], WantLimited: false, }, - /* - // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the - // number of returned events. This is critical for big rooms hence the test here. - { - Name: "IncrementalSync limited", - DoSync: func() (*types.Response, error) { - from := types.StreamingToken{ // pretend we are 10 events behind - PDUPosition: positions[len(positions)-11], - } - res := types.NewResponse() - // limit is set to 5 - return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) - }, - // want the last 5 events, NOT the last 10. - WantTimeline: events[len(events)-5:], - }, - // The purpose of this test is to check that CompleteSync returns all the current state as well as - // honouring the `numRecentEventsPerRoom` value - { - Name: "CompleteSync limited", - DoSync: func() (*types.Response, error) { - res := types.NewResponse() - // limit set to 5 - return db.CompleteSync(ctx, res, testUserDeviceA, 5) - }, - // want the last 5 events - WantTimeline: events[len(events)-5:], - // want all state for the room - WantState: state, - }, - // The purpose of this test is to check that CompleteSync can return everything with a high enough - // `numRecentEventsPerRoom`. - { - Name: "CompleteSync", - DoSync: func() (*types.Response, error) { - res := types.NewResponse() - return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1) - }, - WantTimeline: events, - // We want no state at all as that field in /sync is the delta between the token (beginning of time) - // and the START of the timeline. - }, */ + // The purpose of this test is to check that limits can be applied and work. + // This is critical for big rooms hence the test here. + { + Name: "limited", + From: 0, + To: latest, + Limit: 1, + WantEvents: events[len(events)-1:], + WantLimited: true, + }, + // The purpose of this test is to check that we can return every event with a high + // enough limit + { + Name: "large limited", + From: 0, + To: latest, + Limit: 100, + WantEvents: events, + WantLimited: false, + }, + // The purpose of this test is to check that we can return events in reverse order + { + Name: "reverse", + From: positions[len(positions)-3], // 2 events back + To: latest, + Limit: 100, + ReverseOrder: true, + WantEvents: test.Reversed(events[len(events)-2:]), + WantLimited: false, + }, } - for _, tc := range testCases { + for i := range testCases { + tc := testCases[i] t.Run(tc.Name, func(st *testing.T) { + var filter gomatrixserverlib.RoomEventFilter + filter.Limit = tc.Limit gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{ From: tc.From, To: tc.To, - }, &filter, true, true) + }, &filter, !tc.ReverseOrder, true) if err != nil { st.Fatalf("failed to do sync: %s", err) } @@ -148,100 +148,48 @@ func TestRecentEventsPDU(t *testing.T) { if len(gotEvents) != len(tc.WantEvents) { st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) } + for j := range gotEvents { + if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) { + st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON())) + } + } }) } }) } -/* -func TestGetEventsInRangeWithPrevBatch(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - positions := MustWriteEvents(t, db, events) - latest, err := db.SyncPosition(ctx) - if err != nil { - t.Fatalf("failed to get SyncPosition: %s", err) - } - from := types.StreamingToken{ - PDUPosition: positions[len(positions)-2], - } - - res := types.NewResponse() - res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) - if err != nil { - t.Fatalf("failed to IncrementalSync with latest token") - } - roomRes, ok := res.Rooms.Join[testRoomID] - if !ok { - t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res) - } - // returns the last event "Message 10" - assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:])) - - prev := roomRes.Timeline.PrevBatch.String() - if prev == "" { - t.Fatalf("IncrementalSync expected prev_batch token") - } - prevBatchToken, err := types.NewTopologyTokenFromString(prev) - if err != nil { - t.Fatalf("failed to NewTopologyTokenFromString : %s", err) - } - // backpaginate 5 messages starting at the latest position. - // head towards the beginning of time - to := types.TopologyToken{} - paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true) - if err != nil { - t.Fatalf("GetEventsInRange returned an error: %s", err) - } - gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) - assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1])) -} - -// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token. -func TestGetEventsInRangeWithStreamToken(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - MustWriteEvents(t, db, events) - latest, err := db.SyncPosition(ctx) - if err != nil { - t.Fatalf("failed to get SyncPosition: %s", err) - } - // head towards the beginning of time - to := types.StreamingToken{} - - // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) - if err != nil { - t.Fatalf("GetEventsInRange returned an error: %s", err) - } - gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) - assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) -} - // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token func TestGetEventsInRangeWithTopologyToken(t *testing.T) { - t.Parallel() - db := MustCreateDatabase(t) - events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) - MustWriteEvents(t, db, events) - from, err := db.MaxTopologicalPosition(ctx, testRoomID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } - // head towards the beginning of time - to := types.TopologyToken{} + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := MustCreateDatabase(t, dbType) + defer close() + alice := test.NewUser() + r := test.NewRoom(t, alice) + for i := 0; i < 10; i++ { + r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)}) + } + events := r.Events() + _ = MustWriteEvents(t, db, events) - // backpaginate 5 messages starting at the latest position. - paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) - if err != nil { - t.Fatalf("GetEventsInRange returned an error: %s", err) - } - gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) - assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) + from, err := db.MaxTopologicalPosition(ctx, r.ID) + if err != nil { + t.Fatalf("failed to get MaxTopologicalPosition: %s", err) + } + t.Logf("max topo pos = %+v", from) + // head towards the beginning of time + to := types.TopologyToken{} + + // backpaginate 5 messages starting at the latest position. + paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true) + if err != nil { + t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) + } + gots := db.StreamEventsToEvents(nil, paginatedEvents) + test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) + }) } +/* // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent // will appear FIRST when going backwards. This test creates a DAG like: @@ -651,12 +599,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ tok.Decrement() return &tok } - -func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { - out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[len(in)-i-1] - } - return out -} */ diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 8d368eec1..3cbeb0462 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -59,7 +59,7 @@ type Events interface { SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) - SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) + SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go new file mode 100644 index 000000000..7a81ffcd2 --- /dev/null +++ b/syncapi/storage/tables/output_room_events_test.go @@ -0,0 +1,82 @@ +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/test" +) + +func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Events + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresEventsTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + tab, err = sqlite3.NewSqliteEventsTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestOutputRoomEventsTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + room := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := newOutputRoomEventsTable(t, dbType) + defer close() + events := room.Events() + err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + for _, ev := range events { + _, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false) + if err != nil { + return fmt.Errorf("failed to InsertEvent: %s", err) + } + } + // order = 2,0,3,1 + wantEventIDs := []string{ + events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(), + } + gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true) + if err != nil { + return fmt.Errorf("failed to SelectEvents: %s", err) + } + gotEventIDs := make([]string, len(gotEvents)) + for i := range gotEvents { + gotEventIDs[i] = gotEvents[i].EventID() + } + if !reflect.DeepEqual(gotEventIDs, wantEventIDs) { + return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs) + } + + return nil + }) + if err != nil { + t.Fatalf("err: %s", err) + } + }) +} diff --git a/syncapi/storage/tables/topology_test.go b/syncapi/storage/tables/topology_test.go new file mode 100644 index 000000000..b6ece0b0d --- /dev/null +++ b/syncapi/storage/tables/topology_test.go @@ -0,0 +1,91 @@ +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Topology + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresTopologyTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteTopologyTable(db) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, db, close +} + +func TestTopologyTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + room := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, db, close := newTopologyTable(t, dbType) + defer close() + events := room.Events() + err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error { + var highestPos types.StreamPosition + for i, ev := range events { + topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i)) + if err != nil { + return fmt.Errorf("failed to InsertEventInTopology: %s", err) + } + // topo pos = depth, depth starts at 1, hence 1+i + if topoPos != types.StreamPosition(1+i) { + return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i) + } + highestPos = topoPos + 1 + } + // check ordering works without limit + eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, events[:]) + eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:])) + // check ordering works with limit + eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, events[:3]) + eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false) + if err != nil { + return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) + } + test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:])) + + return nil + }) + if err != nil { + t.Fatalf("err: %s", err) + } + }) +} diff --git a/test/db.go b/test/db.go index 9deec0a89..674fdf5c3 100644 --- a/test/db.go +++ b/test/db.go @@ -121,6 +121,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) { for dbName, dbType := range dbs { dbt := dbType t.Run(dbName, func(tt *testing.T) { + tt.Parallel() testFn(tt, dbt) }) } diff --git a/test/event.go b/test/event.go index 487b09364..b2e2805ba 100644 --- a/test/event.go +++ b/test/event.go @@ -15,7 +15,9 @@ package test import ( + "bytes" "crypto/ed25519" + "testing" "time" "github.com/matrix-org/gomatrixserverlib" @@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier { e.unsigned = unsigned } } + +// Reverse a list of events +func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { + out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[len(in)-i-1] + } + return out +} + +func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) { + t.Helper() + if len(gotEventIDs) != len(wants) { + t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants)) + } + for i := range wants { + w := wants[i].EventID() + g := gotEventIDs[i] + if w != g { + t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w)) + } + } +} + +func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) { + t.Helper() + if len(gots) != len(wants) { + t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants)) + } + for i := range wants { + w := wants[i].JSON() + g := gots[i].JSON() + if !bytes.Equal(w, g) { + t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w)) + } + } +}