From 34d1888c096287cedd4a5c1c37f1f1b68306bdb0 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Tue, 10 May 2022 12:20:59 +0200 Subject: [PATCH] Add StateBlock tests Some optimizations --- .../storage/postgres/state_block_table.go | 20 ++--- roomserver/storage/postgres/storage.go | 4 +- .../storage/sqlite3/state_block_table.go | 16 ++-- roomserver/storage/sqlite3/storage.go | 4 +- .../storage/tables/state_block_table_test.go | 78 +++++++++++++++++++ 5 files changed, 100 insertions(+), 22 deletions(-) create mode 100644 roomserver/storage/tables/state_block_table_test.go diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 73b04c92c..5af48f031 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -70,12 +70,12 @@ type stateBlockStatements struct { bulkSelectStateBlockEntriesStmt *sql.Stmt } -func createStateBlockTable(db *sql.DB) error { +func CreateStateBlockTable(db *sql.DB) error { _, err := db.Exec(stateDataSchema) return err } -func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{} return s, sqlutil.StatementList{ @@ -89,9 +89,9 @@ func (s *stateBlockStatements) BulkInsertStateData( entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] - var nids types.EventNIDs - for _, e := range entries { - nids = append(nids, e.EventNID) + nids := make(types.EventNIDs, entries.Len()) + for i := range entries { + nids[i] = entries[i].EventNID } stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) err = stmt.QueryRowContext( @@ -112,15 +112,15 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 + var stateBlockNID types.StateBlockNID + var result pq.Int64Array for ; rows.Next(); i++ { - var stateBlockNID types.StateBlockNID - var result pq.Int64Array if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - r := []types.EventNID{} - for _, e := range result { - r = append(r, types.EventNID(e)) + r := make([]types.EventNID, len(result)) + for x := range result { + r[x] = types.EventNID(result[x]) } results[i] = r } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 93758dd78..2b83967dd 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -83,7 +83,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateRoomsTable(db); err != nil { return err } - if err := createStateBlockTable(db); err != nil { + if err := CreateStateBlockTable(db); err != nil { return err } if err := createStateSnapshotTable(db); err != nil { @@ -132,7 +132,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - stateBlock, err := prepareStateBlockTable(db) + stateBlock, err := PrepareStateBlockTable(db) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 5f9b646bd..caf239a21 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -63,12 +63,12 @@ type stateBlockStatements struct { bulkSelectStateBlockEntriesStmt *sql.Stmt } -func createStateBlockTable(db *sql.DB) error { +func CreateStateBlockTable(db *sql.DB) error { _, err := db.Exec(stateDataSchema) return err } -func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{ db: db, } @@ -84,9 +84,9 @@ func (s *stateBlockStatements) BulkInsertStateData( entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] - nids := types.EventNIDs{} // zero slice to not store 'null' in the DB - for _, e := range entries { - nids = append(nids, e.EventNID) + nids := make(types.EventNIDs, entries.Len()) + for i := range entries { + nids[i] = entries[i].EventNID } js, err := json.Marshal(nids) if err != nil { @@ -121,13 +121,13 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 + var stateBlockNID types.StateBlockNID + var result json.RawMessage + var r []types.EventNID for ; rows.Next(); i++ { - var stateBlockNID types.StateBlockNID - var result json.RawMessage if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - r := []types.EventNID{} if err = json.Unmarshal(result, &r); err != nil { return nil, fmt.Errorf("json.Unmarshal: %w", err) } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 0c63e8772..0c071ac12 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -92,7 +92,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateRoomsTable(db); err != nil { return err } - if err := createStateBlockTable(db); err != nil { + if err := CreateStateBlockTable(db); err != nil { return err } if err := createStateSnapshotTable(db); err != nil { @@ -141,7 +141,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - stateBlock, err := prepareStateBlockTable(db) + stateBlock, err := PrepareStateBlockTable(db) if err != nil { return err } diff --git a/roomserver/storage/tables/state_block_table_test.go b/roomserver/storage/tables/state_block_table_test.go new file mode 100644 index 000000000..8e6f1956e --- /dev/null +++ b/roomserver/storage/tables/state_block_table_test.go @@ -0,0 +1,78 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreateStateBlockTable(t *testing.T, dbType test.DBType) (tab tables.StateBlock, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateStateBlockTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareStateBlockTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateStateBlockTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareStateBlockTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func TestStateBlockTable(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateStateBlockTable(t, dbType) + defer close() + + // generate some dummy data + var entries types.StateEntries + for i := 0; i < 100; i++ { + entry := types.StateEntry{ + EventNID: types.EventNID(i), + } + entries = append(entries, entry) + } + stateBlockNID, err := tab.BulkInsertStateData(ctx, nil, entries) + assert.NoError(t, err) + assert.Equal(t, types.StateBlockNID(1), stateBlockNID) + + // generate a different hash, to get a new StateBlockNID + var entries2 types.StateEntries + for i := 100; i < 300; i++ { + entry := types.StateEntry{ + EventNID: types.EventNID(i), + } + entries2 = append(entries2, entry) + } + stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2) + assert.NoError(t, err) + assert.Equal(t, types.StateBlockNID(2), stateBlockNID) + + eventNIDs, err := tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 2}) + assert.NoError(t, err) + assert.Equal(t, len(entries), len(eventNIDs[0])) + assert.Equal(t, len(entries2), len(eventNIDs[1])) + + // try to get a StateBlockNID which does not exist + _, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{5}) + assert.Error(t, err) + }) +}