Add State Snapshot tests

Some optimization
This commit is contained in:
Till Faelligen 2022-05-10 13:17:57 +02:00
parent 34d1888c09
commit 454dc43e56
5 changed files with 91 additions and 16 deletions

View file

@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func createStateSnapshotTable(db *sql.DB) error {
func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema)
return err
}
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{}
return s, sqlutil.StatementList{
@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
) (stateNID types.StateSnapshotNID, err error) {
nids = nids[:util.SortAndUnique(nids)]
var id int64
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(id)
return
}
@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer rows.Close() // nolint: errcheck
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
var stateBlockNIDs pq.Int64Array
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDs pq.Int64Array
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err
}

View file

@ -86,7 +86,7 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateStateBlockTable(db); err != nil {
return err
}
if err := createStateSnapshotTable(db); err != nil {
if err := CreateStateSnapshotTable(db); err != nil {
return err
}
if err := createPrevEventsTable(db); err != nil {
@ -136,7 +136,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
stateSnapshot, err := prepareStateSnapshotTable(db)
stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil {
return err
}

View file

@ -68,12 +68,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func createStateSnapshotTable(db *sql.DB) error {
func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema)
return err
}
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{
db: db,
}
@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState(
return
}
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
var id int64
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id)
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID)
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(id)
return
}
@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
var stateBlockNIDsJSON string
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDsJSON string
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
return nil, err
}

View file

@ -95,7 +95,7 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateStateBlockTable(db); err != nil {
return err
}
if err := createStateSnapshotTable(db); err != nil {
if err := CreateStateSnapshotTable(db); err != nil {
return err
}
if err := createPrevEventsTable(db); err != nil {
@ -145,7 +145,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
stateSnapshot, err := prepareStateSnapshotTable(db)
stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil {
return err
}

View file

@ -0,0 +1,79 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/stretchr/testify/assert"
)
func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
err = postgres.CreateStateSnapshotTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareStateSnapshotTable(db)
case test.DBTypeSQLite:
err = sqlite3.CreateStateSnapshotTable(db)
assert.NoError(t, err)
tab, err = sqlite3.PrepareStateSnapshotTable(db)
}
assert.NoError(t, err)
return tab, close
}
func TestStateSnapshotTable(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, close := mustCreateStateSnapshotTable(t, dbType)
defer close()
// generate some dummy data
var stateBlockNIDs types.StateBlockNIDs
for i := 0; i < 100; i++ {
stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i))
}
stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs)
assert.NoError(t, err)
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
// verify ON CONFLICT; Note: this updates the sequence!
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs)
assert.NoError(t, err)
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
// create a second snapshot
var stateBlockNIDs2 types.StateBlockNIDs
for i := 100; i < 150; i++ {
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
}
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
assert.NoError(t, err)
// StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence
assert.Equal(t, types.StateSnapshotNID(3), stateNID)
nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3})
assert.NoError(t, err)
assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs))
assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs))
// check we get an error if the state snapshot does not exist
_, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2})
assert.Error(t, err)
})
}