mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 11:13:12 -06:00
Add State Snapshot tests
Some optimization
This commit is contained in:
parent
34d1888c09
commit
454dc43e56
|
|
@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
|
|||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createStateSnapshotTable(db *sql.DB) error {
|
||||
func CreateStateSnapshotTable(db *sql.DB) error {
|
||||
_, err := db.Exec(stateSnapshotSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
s := &stateSnapshotStatements{}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
|
|
@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
|
|||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
|
||||
) (stateNID types.StateSnapshotNID, err error) {
|
||||
nids = nids[:util.SortAndUnique(nids)]
|
||||
var id int64
|
||||
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
|
||||
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(id)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
|||
defer rows.Close() // nolint: errcheck
|
||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||
i := 0
|
||||
var stateBlockNIDs pq.Int64Array
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
var stateBlockNIDs pq.Int64Array
|
||||
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ func (d *Database) create(db *sql.DB) error {
|
|||
if err := CreateStateBlockTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createStateSnapshotTable(db); err != nil {
|
||||
if err := CreateStateSnapshotTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createPrevEventsTable(db); err != nil {
|
||||
|
|
@ -136,7 +136,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateSnapshot, err := prepareStateSnapshotTable(db)
|
||||
stateSnapshot, err := PrepareStateSnapshotTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,12 +68,12 @@ type stateSnapshotStatements struct {
|
|||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func createStateSnapshotTable(db *sql.DB) error {
|
||||
func CreateStateSnapshotTable(db *sql.DB) error {
|
||||
_, err := db.Exec(stateSnapshotSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
s := &stateSnapshotStatements{
|
||||
db: db,
|
||||
}
|
||||
|
|
@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState(
|
|||
return
|
||||
}
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||
var id int64
|
||||
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id)
|
||||
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(id)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
|||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||
i := 0
|
||||
var stateBlockNIDsJSON string
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
var stateBlockNIDsJSON string
|
||||
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ func (d *Database) create(db *sql.DB) error {
|
|||
if err := CreateStateBlockTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createStateSnapshotTable(db); err != nil {
|
||||
if err := CreateStateSnapshotTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := createPrevEventsTable(db); err != nil {
|
||||
|
|
@ -145,7 +145,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateSnapshot, err := prepareStateSnapshotTable(db)
|
||||
stateSnapshot, err := PrepareStateSnapshotTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
79
roomserver/storage/tables/state_snapshot_table_test.go
Normal file
79
roomserver/storage/tables/state_snapshot_table_test.go
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
err = postgres.CreateStateSnapshotTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = postgres.PrepareStateSnapshotTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
err = sqlite3.CreateStateSnapshotTable(db)
|
||||
assert.NoError(t, err)
|
||||
tab, err = sqlite3.PrepareStateSnapshotTable(db)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestStateSnapshotTable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, close := mustCreateStateSnapshotTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
// generate some dummy data
|
||||
var stateBlockNIDs types.StateBlockNIDs
|
||||
for i := 0; i < 100; i++ {
|
||||
stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i))
|
||||
}
|
||||
stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
|
||||
|
||||
// verify ON CONFLICT; Note: this updates the sequence!
|
||||
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
|
||||
|
||||
// create a second snapshot
|
||||
var stateBlockNIDs2 types.StateBlockNIDs
|
||||
for i := 100; i < 150; i++ {
|
||||
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
|
||||
}
|
||||
|
||||
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
|
||||
assert.NoError(t, err)
|
||||
// StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence
|
||||
assert.Equal(t, types.StateSnapshotNID(3), stateNID)
|
||||
|
||||
nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs))
|
||||
assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs))
|
||||
|
||||
// check we get an error if the state snapshot does not exist
|
||||
_, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
Loading…
Reference in a new issue