syncapi: add more tests; fix more bugs (#2338)

* syncapi: add more tests; fix more bugs

bugfixes:
 - The postgres impl of TopologyTable.SelectEventIDsInRange did not use the provided txn
 - The postgres impl of EventsTable.SelectEvents did not preserve the ordering of the input event IDs in the output events slice
 - The sqlite impl of EventsTable.SelectEvents did not use a bulk `IN ($1)` query.

Added tests:
 - `TestGetEventsInRangeWithTopologyToken`
 - `TestOutputRoomEventsTable`
 - `TestTopologyTable`

* -p 1 for now
This commit is contained in:
kegsay 2022-04-08 17:53:24 +01:00 committed by GitHub
parent 986d27a128
commit 6d25bd6ca5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 388 additions and 197 deletions

View file

@ -111,7 +111,7 @@ jobs:
key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test- ${{ runner.os }}-go${{ matrix.go }}-test-
- run: go test ./... - run: go test -p 1 ./...
env: env:
POSTGRES_HOST: localhost POSTGRES_HOST: localhost
POSTGRES_USER: postgres POSTGRES_USER: postgres

View file

@ -104,7 +104,7 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user // DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event. // EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)

View file

@ -427,7 +427,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is // selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted. // missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents( func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
@ -435,7 +435,25 @@ func (s *outputRoomEventsStatements) SelectEvents(
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
return rowsToStreamEvents(rows) streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
var returnEvents []types.StreamEvent
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
}
return returnEvents, nil
}
return streamEvents, nil
} }
func (s *outputRoomEventsStatements) DeleteEventsForRoom( func (s *outputRoomEventsStatements) DeleteEventsForRoom(

View file

@ -148,9 +148,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
// is requested or not. // is requested or not.
var stmt *sql.Stmt var stmt *sql.Stmt
if chronologicalOrder { if chronologicalOrder {
stmt = s.selectEventIDsInRangeASCStmt stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
} else { } else {
stmt = s.selectEventIDsInRangeDESCStmt stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
} }
// Query the event IDs. // Query the event IDs.

View file

@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
// Returns an error if there was a problem talking with the database. // Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events. // Does not include any transaction IDs in the returned events.
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs) streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
// Check if we have all of the event's previous events. If an event is // Check if we have all of the event's previous events. If an event is
// missing, add it to the room's backward extremities. // missing, add it to the room's backward extremities.
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs()) prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false)
if err != nil { if err != nil {
return err return err
} }
@ -457,7 +457,7 @@ func (d *Database) GetEventsInTopologicalRange(
} }
// Retrieve the events' contents using their IDs. // Retrieve the events' contents using their IDs.
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs) events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true)
return return
} }
@ -619,7 +619,7 @@ func (d *Database) fetchMissingStateEvents(
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the // Fetch from the events table first so we pick up the stream ID for the
// event. // event.
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs) events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -51,13 +51,13 @@ const selectMaxAccountDataIDSQL = "" +
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt
selectAccountDataInRangeStmt *sql.Stmt selectAccountDataInRangeStmt *sql.Stmt
} }
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{ s := &accountDataStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,

View file

@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
deleteRoomStateForRoomStmt *sql.Stmt deleteRoomStateForRoomStmt *sql.Stmt
@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
} }
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{ s := &currentRoomStateStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,

View file

@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct { type inviteEventsStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
insertInviteEventStmt *sql.Stmt insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt
} }
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{ s := &inviteEventsStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,

View file

@ -58,7 +58,7 @@ const insertEventSQL = "" +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)" "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
const selectEventsSQL = "" + const selectEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1" "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt selectMaxEventIDStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt deleteEventsForRoomStmt *sql.Stmt
@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
selectContextAfterEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt
} }
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{ s := &outputRoomEventsStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,
@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
} }
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.updateEventJSONStmt, updateEventJSONSQL}, {&s.updateEventJSONStmt, updateEventJSONSQL},
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL}, {&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
@ -421,21 +419,43 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is // selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted. // missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents( func (s *outputRoomEventsStatements) SelectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
var returnEvents []types.StreamEvent iEventIDs := make([]interface{}, len(eventIDs))
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt) for i := range eventIDs {
for _, eventID := range eventIDs { iEventIDs[i] = eventIDs[i]
rows, err := stmt.QueryContext(ctx, eventID)
if err != nil {
return nil, err
}
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
returnEvents = append(returnEvents, streamEvents...)
}
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
} }
return returnEvents, nil selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...)
} else {
rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...)
}
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
streamEvents, err := rowsToStreamEvents(rows)
if err != nil {
return nil, err
}
if preserveOrder {
var returnEvents []types.StreamEvent
eventMap := make(map[string]types.StreamEvent)
for _, ev := range streamEvents {
eventMap[ev.EventID()] = ev
}
for _, eventID := range eventIDs {
ev, ok := eventMap[eventID]
if ok {
returnEvents = append(returnEvents, ev)
}
}
return returnEvents, nil
}
return streamEvents, nil
} }
func (s *outputRoomEventsStatements) DeleteEventsForRoom( func (s *outputRoomEventsStatements) DeleteEventsForRoom(

View file

@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
type peekStatements struct { type peekStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
insertPeekStmt *sql.Stmt insertPeekStmt *sql.Stmt
deletePeekStmt *sql.Stmt deletePeekStmt *sql.Stmt
deletePeeksStmt *sql.Stmt deletePeeksStmt *sql.Stmt
@ -75,7 +75,7 @@ type peekStatements struct {
selectMaxPeekIDStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt
} }
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
_, err := db.Exec(peeksSchema) _, err := db.Exec(peeksSchema)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
type presenceStatements struct { type presenceStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
upsertPresenceStmt *sql.Stmt upsertPresenceStmt *sql.Stmt
upsertPresenceFromSyncStmt *sql.Stmt upsertPresenceFromSyncStmt *sql.Stmt
selectPresenceForUsersStmt *sql.Stmt selectPresenceForUsersStmt *sql.Stmt
@ -83,7 +83,7 @@ type presenceStatements struct {
selectPresenceAfterStmt *sql.Stmt selectPresenceAfterStmt *sql.Stmt
} }
func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) { func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
_, err := db.Exec(presenceSchema) _, err := db.Exec(presenceSchema)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
type receiptStatements struct { type receiptStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *StreamIDStatements
upsertReceipt *sql.Stmt upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt selectMaxReceiptID *sql.Stmt
} }
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) { func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
_, err := db.Exec(receiptsSchema) _, err := db.Exec(receiptsSchema)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" + "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
" RETURNING stream_id" " RETURNING stream_id"
type streamIDStatements struct { type StreamIDStatements struct {
db *sql.DB db *sql.DB
increaseStreamIDStmt *sql.Stmt increaseStreamIDStmt *sql.Stmt
} }
func (s *streamIDStatements) prepare(db *sql.DB) (err error) { func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
s.db = db s.db = db
_, err = db.Exec(streamIDTableSchema) _, err = db.Exec(streamIDTableSchema)
if err != nil { if err != nil {
@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
return return
} }
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return return
} }
func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
return return

View file

@ -30,7 +30,7 @@ type SyncServerDatasource struct {
shared.Database shared.Database
db *sql.DB db *sql.DB
writer sqlutil.Writer writer sqlutil.Writer
streamID streamIDStatements streamID StreamIDStatements
} }
// NewDatabase creates a new sync server database // NewDatabase creates a new sync server database
@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
} }
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
if err = d.streamID.prepare(d.db); err != nil { if err = d.streamID.Prepare(d.db); err != nil {
return err return err
} }
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)

View file

@ -3,6 +3,7 @@ package storage_test
import ( import (
"context" "context"
"fmt" "fmt"
"reflect"
"testing" "testing"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
if err != nil { if err != nil {
t.Fatalf("WriteEvent failed: %s", err) t.Fatalf("WriteEvent failed: %s", err)
} }
fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth()) t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
positions = append(positions, pos) positions = append(positions, pos)
} }
return return
@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
func TestWriteEvents(t *testing.T) { func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
t.Parallel()
alice := test.NewUser() alice := test.NewUser()
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
@ -61,84 +61,84 @@ func TestRecentEventsPDU(t *testing.T) {
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser() alice := test.NewUser()
var filter gomatrixserverlib.RoomEventFilter // dummy room to make sure SQL queries are filtering on room ID
filter.Limit = 100 MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
// actual test room
r := test.NewRoom(t, alice) r := test.NewRoom(t, alice)
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
events := r.Events() events := r.Events()
positions := MustWriteEvents(t, db, events) positions := MustWriteEvents(t, db, events)
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
latest, err := db.MaxStreamPositionForPDUs(ctx) latest, err := db.MaxStreamPositionForPDUs(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err) t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
} }
testCases := []struct { testCases := []struct {
Name string Name string
From types.StreamPosition From types.StreamPosition
To types.StreamPosition To types.StreamPosition
WantEvents []*gomatrixserverlib.HeaderedEvent Limit int
WantLimited bool ReverseOrder bool
WantEvents []*gomatrixserverlib.HeaderedEvent
WantLimited bool
}{ }{
// The purpose of this test is to make sure that incremental syncs are including up to the latest events. // The purpose of this test is to make sure that incremental syncs are including up to the latest events.
// It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event. // It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
// It makes sure the response includes the final event. // It makes sure the response includes the final event.
{ {
Name: "IncrementalSync penultimate", Name: "penultimate",
From: positions[len(positions)-2], // pretend we are at the penultimate event From: positions[len(positions)-2], // pretend we are at the penultimate event
To: latest, To: latest,
Limit: 100,
WantEvents: events[len(events)-1:], WantEvents: events[len(events)-1:],
WantLimited: false, WantLimited: false,
}, },
/* // The purpose of this test is to check that limits can be applied and work.
// The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the // This is critical for big rooms hence the test here.
// number of returned events. This is critical for big rooms hence the test here. {
{ Name: "limited",
Name: "IncrementalSync limited", From: 0,
DoSync: func() (*types.Response, error) { To: latest,
from := types.StreamingToken{ // pretend we are 10 events behind Limit: 1,
PDUPosition: positions[len(positions)-11], WantEvents: events[len(events)-1:],
} WantLimited: true,
res := types.NewResponse() },
// limit is set to 5 // The purpose of this test is to check that we can return every event with a high
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) // enough limit
}, {
// want the last 5 events, NOT the last 10. Name: "large limited",
WantTimeline: events[len(events)-5:], From: 0,
}, To: latest,
// The purpose of this test is to check that CompleteSync returns all the current state as well as Limit: 100,
// honouring the `numRecentEventsPerRoom` value WantEvents: events,
{ WantLimited: false,
Name: "CompleteSync limited", },
DoSync: func() (*types.Response, error) { // The purpose of this test is to check that we can return events in reverse order
res := types.NewResponse() {
// limit set to 5 Name: "reverse",
return db.CompleteSync(ctx, res, testUserDeviceA, 5) From: positions[len(positions)-3], // 2 events back
}, To: latest,
// want the last 5 events Limit: 100,
WantTimeline: events[len(events)-5:], ReverseOrder: true,
// want all state for the room WantEvents: test.Reversed(events[len(events)-2:]),
WantState: state, WantLimited: false,
}, },
// The purpose of this test is to check that CompleteSync can return everything with a high enough
// `numRecentEventsPerRoom`.
{
Name: "CompleteSync",
DoSync: func() (*types.Response, error) {
res := types.NewResponse()
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
},
WantTimeline: events,
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
// and the START of the timeline.
}, */
} }
for _, tc := range testCases { for i := range testCases {
tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) { t.Run(tc.Name, func(st *testing.T) {
var filter gomatrixserverlib.RoomEventFilter
filter.Limit = tc.Limit
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{ gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
From: tc.From, From: tc.From,
To: tc.To, To: tc.To,
}, &filter, true, true) }, &filter, !tc.ReverseOrder, true)
if err != nil { if err != nil {
st.Fatalf("failed to do sync: %s", err) st.Fatalf("failed to do sync: %s", err)
} }
@ -148,100 +148,48 @@ func TestRecentEventsPDU(t *testing.T) {
if len(gotEvents) != len(tc.WantEvents) { if len(gotEvents) != len(tc.WantEvents) {
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
} }
for j := range gotEvents {
if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
}
}
}) })
} }
}) })
} }
/*
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
positions := MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
from := types.StreamingToken{
PDUPosition: positions[len(positions)-2],
}
res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
if err != nil {
t.Fatalf("failed to IncrementalSync with latest token")
}
roomRes, ok := res.Rooms.Join[testRoomID]
if !ok {
t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
}
// returns the last event "Message 10"
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
prev := roomRes.Timeline.PrevBatch.String()
if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token")
}
prevBatchToken, err := types.NewTopologyTokenFromString(prev)
if err != nil {
t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
}
// backpaginate 5 messages starting at the latest position.
// head towards the beginning of time
to := types.TopologyToken{}
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
}
// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
MustWriteEvents(t, db, events)
latest, err := db.SyncPosition(ctx)
if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err)
}
// head towards the beginning of time
to := types.StreamingToken{}
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err)
}
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
}
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
func TestGetEventsInRangeWithTopologyToken(t *testing.T) { func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
t.Parallel() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db := MustCreateDatabase(t) db, close := MustCreateDatabase(t, dbType)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) defer close()
MustWriteEvents(t, db, events) alice := test.NewUser()
from, err := db.MaxTopologicalPosition(ctx, testRoomID) r := test.NewRoom(t, alice)
if err != nil { for i := 0; i < 10; i++ {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
} }
// head towards the beginning of time events := r.Events()
to := types.TopologyToken{} _ = MustWriteEvents(t, db, events)
// backpaginate 5 messages starting at the latest position. from, err := db.MaxTopologicalPosition(ctx, r.ID)
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true) if err != nil {
if err != nil { t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
t.Fatalf("GetEventsInRange returned an error: %s", err) }
} t.Logf("max topo pos = %+v", from)
gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) // head towards the beginning of time
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) to := types.TopologyToken{}
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true)
if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
}
gots := db.StreamEventsToEvents(nil, paginatedEvents)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
})
} }
/*
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
// will appear FIRST when going backwards. This test creates a DAG like: // will appear FIRST when going backwards. This test creates a DAG like:
@ -651,12 +599,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
tok.Decrement() tok.Decrement()
return &tok return &tok
} }
func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[len(in)-i-1]
}
return out
}
*/ */

View file

@ -59,7 +59,7 @@ type Events interface {
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room. // SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)

View file

@ -0,0 +1,82 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/test"
)
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Events
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresEventsTable(db)
case test.DBTypeSQLite:
var stream sqlite3.StreamIDStatements
if err = stream.Prepare(db); err != nil {
t.Fatalf("failed to prepare stream stmts: %s", err)
}
tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestOutputRoomEventsTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newOutputRoomEventsTable(t, dbType)
defer close()
events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
for _, ev := range events {
_, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
if err != nil {
return fmt.Errorf("failed to InsertEvent: %s", err)
}
}
// order = 2,0,3,1
wantEventIDs := []string{
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
}
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true)
if err != nil {
return fmt.Errorf("failed to SelectEvents: %s", err)
}
gotEventIDs := make([]string, len(gotEvents))
for i := range gotEvents {
gotEventIDs[i] = gotEvents[i].EventID()
}
if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
}
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

View file

@ -0,0 +1,91 @@
package tables_test
import (
"context"
"database/sql"
"fmt"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Topology
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresTopologyTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteTopologyTable(db)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func TestTopologyTable(t *testing.T) {
ctx := context.Background()
alice := test.NewUser()
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, db, close := newTopologyTable(t, dbType)
defer close()
events := room.Events()
err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
var highestPos types.StreamPosition
for i, ev := range events {
topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
if err != nil {
return fmt.Errorf("failed to InsertEventInTopology: %s", err)
}
// topo pos = depth, depth starts at 1, hence 1+i
if topoPos != types.StreamPosition(1+i) {
return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
}
highestPos = topoPos + 1
}
// check ordering works without limit
eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, events[:])
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
// check ordering works with limit
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, events[:3])
eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
if err != nil {
return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
}
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
return nil
})
if err != nil {
t.Fatalf("err: %s", err)
}
})
}

View file

@ -121,6 +121,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
for dbName, dbType := range dbs { for dbName, dbType := range dbs {
dbt := dbType dbt := dbType
t.Run(dbName, func(tt *testing.T) { t.Run(dbName, func(tt *testing.T) {
tt.Parallel()
testFn(tt, dbt) testFn(tt, dbt)
}) })
} }

View file

@ -15,7 +15,9 @@
package test package test
import ( import (
"bytes"
"crypto/ed25519" "crypto/ed25519"
"testing"
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier {
e.unsigned = unsigned e.unsigned = unsigned
} }
} }
// Reverse a list of events
func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[len(in)-i-1]
}
return out
}
func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
t.Helper()
if len(gotEventIDs) != len(wants) {
t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
}
for i := range wants {
w := wants[i].EventID()
g := gotEventIDs[i]
if w != g {
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
}
}
}
func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
t.Helper()
if len(gots) != len(wants) {
t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
}
for i := range wants {
w := wants[i].JSON()
g := gots[i].JSON()
if !bytes.Equal(w, g) {
t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
}
}
}