From 454dc43e5679e971e843d9d97d976437815db9d8 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Tue, 10 May 2022 13:17:57 +0200 Subject: [PATCH] Add State Snapshot tests Some optimization --- .../storage/postgres/state_snapshot_table.go | 10 +-- roomserver/storage/postgres/storage.go | 4 +- .../storage/sqlite3/state_snapshot_table.go | 10 +-- roomserver/storage/sqlite3/storage.go | 4 +- .../tables/state_snapshot_table_test.go | 79 +++++++++++++++++++ 5 files changed, 91 insertions(+), 16 deletions(-) create mode 100644 roomserver/storage/tables/state_snapshot_table_test.go diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 8ed886030..a24b7f3f0 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -77,12 +77,12 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func createStateSnapshotTable(db *sql.DB) error { +func CreateStateSnapshotTable(db *sql.DB) error { _, err := db.Exec(stateSnapshotSchema) return err } -func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{} return s, sqlutil.StatementList{ @@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { nids = nids[:util.SortAndUnique(nids)] - var id int64 - err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id) + err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID) if err != nil { return 0, err } - stateNID = types.StateSnapshotNID(id) return } @@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( defer rows.Close() // nolint: errcheck results := make([]types.StateBlockNIDList, len(stateNIDs)) i := 0 + var stateBlockNIDs pq.Int64Array for ; rows.Next(); i++ { result := &results[i] - var stateBlockNIDs pq.Int64Array if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 2b83967dd..9eb7b8040 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -86,7 +86,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateStateBlockTable(db); err != nil { return err } - if err := createStateSnapshotTable(db); err != nil { + if err := CreateStateSnapshotTable(db); err != nil { return err } if err := createPrevEventsTable(db); err != nil { @@ -136,7 +136,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - stateSnapshot, err := prepareStateSnapshotTable(db) + stateSnapshot, err := PrepareStateSnapshotTable(db) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 1f5e9ee3b..b8136b758 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -68,12 +68,12 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func createStateSnapshotTable(db *sql.DB) error { +func CreateStateSnapshotTable(db *sql.DB) error { _, err := db.Exec(stateSnapshotSchema) return err } -func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{ db: db, } @@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState( return } insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) - var id int64 - err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id) + err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID) if err != nil { return 0, err } - stateNID = types.StateSnapshotNID(id) return } @@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed") results := make([]types.StateBlockNIDList, len(stateNIDs)) i := 0 + var stateBlockNIDsJSON string for ; rows.Next(); i++ { result := &results[i] - var stateBlockNIDsJSON string if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 0c071ac12..bed566367 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -95,7 +95,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateStateBlockTable(db); err != nil { return err } - if err := createStateSnapshotTable(db); err != nil { + if err := CreateStateSnapshotTable(db); err != nil { return err } if err := createPrevEventsTable(db); err != nil { @@ -145,7 +145,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - stateSnapshot, err := prepareStateSnapshotTable(db) + stateSnapshot, err := PrepareStateSnapshotTable(db) if err != nil { return err } diff --git a/roomserver/storage/tables/state_snapshot_table_test.go b/roomserver/storage/tables/state_snapshot_table_test.go new file mode 100644 index 000000000..a05fd6d87 --- /dev/null +++ b/roomserver/storage/tables/state_snapshot_table_test.go @@ -0,0 +1,79 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateStateSnapshotTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareStateSnapshotTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateStateSnapshotTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareStateSnapshotTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func TestStateSnapshotTable(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateStateSnapshotTable(t, dbType) + defer close() + + // generate some dummy data + var stateBlockNIDs types.StateBlockNIDs + for i := 0; i < 100; i++ { + stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i)) + } + stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs) + assert.NoError(t, err) + assert.Equal(t, types.StateSnapshotNID(1), stateNID) + + // verify ON CONFLICT; Note: this updates the sequence! + stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs) + assert.NoError(t, err) + assert.Equal(t, types.StateSnapshotNID(1), stateNID) + + // create a second snapshot + var stateBlockNIDs2 types.StateBlockNIDs + for i := 100; i < 150; i++ { + stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i)) + } + + stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2) + assert.NoError(t, err) + // StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence + assert.Equal(t, types.StateSnapshotNID(3), stateNID) + + nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3}) + assert.NoError(t, err) + assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs)) + assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs)) + + // check we get an error if the state snapshot does not exist + _, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2}) + assert.Error(t, err) + }) +}