mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-06 05:33:10 -06:00
Add State Snapshot tests
Some optimization
This commit is contained in:
parent
34d1888c09
commit
454dc43e56
|
|
@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
|
||||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createStateSnapshotTable(db *sql.DB) error {
|
func CreateStateSnapshotTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(stateSnapshotSchema)
|
_, err := db.Exec(stateSnapshotSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||||
s := &stateSnapshotStatements{}
|
s := &stateSnapshotStatements{}
|
||||||
|
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
|
|
@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
|
||||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
|
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
nids = nids[:util.SortAndUnique(nids)]
|
nids = nids[:util.SortAndUnique(nids)]
|
||||||
var id int64
|
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
|
||||||
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
stateNID = types.StateSnapshotNID(id)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||||
defer rows.Close() // nolint: errcheck
|
defer rows.Close() // nolint: errcheck
|
||||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||||
i := 0
|
i := 0
|
||||||
|
var stateBlockNIDs pq.Int64Array
|
||||||
for ; rows.Next(); i++ {
|
for ; rows.Next(); i++ {
|
||||||
result := &results[i]
|
result := &results[i]
|
||||||
var stateBlockNIDs pq.Int64Array
|
|
||||||
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
|
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ func (d *Database) create(db *sql.DB) error {
|
||||||
if err := CreateStateBlockTable(db); err != nil {
|
if err := CreateStateBlockTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createStateSnapshotTable(db); err != nil {
|
if err := CreateStateSnapshotTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createPrevEventsTable(db); err != nil {
|
if err := createPrevEventsTable(db); err != nil {
|
||||||
|
|
@ -136,7 +136,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stateSnapshot, err := prepareStateSnapshotTable(db)
|
stateSnapshot, err := PrepareStateSnapshotTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -68,12 +68,12 @@ type stateSnapshotStatements struct {
|
||||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createStateSnapshotTable(db *sql.DB) error {
|
func CreateStateSnapshotTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(stateSnapshotSchema)
|
_, err := db.Exec(stateSnapshotSchema)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||||
s := &stateSnapshotStatements{
|
s := &stateSnapshotStatements{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
|
|
@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||||
var id int64
|
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID)
|
||||||
err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
stateNID = types.StateSnapshotNID(id)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
||||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||||
i := 0
|
i := 0
|
||||||
|
var stateBlockNIDsJSON string
|
||||||
for ; rows.Next(); i++ {
|
for ; rows.Next(); i++ {
|
||||||
result := &results[i]
|
result := &results[i]
|
||||||
var stateBlockNIDsJSON string
|
|
||||||
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,7 @@ func (d *Database) create(db *sql.DB) error {
|
||||||
if err := CreateStateBlockTable(db); err != nil {
|
if err := CreateStateBlockTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createStateSnapshotTable(db); err != nil {
|
if err := CreateStateSnapshotTable(db); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := createPrevEventsTable(db); err != nil {
|
if err := createPrevEventsTable(db); err != nil {
|
||||||
|
|
@ -145,7 +145,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stateSnapshot, err := prepareStateSnapshotTable(db)
|
stateSnapshot, err := PrepareStateSnapshotTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
79
roomserver/storage/tables/state_snapshot_table_test.go
Normal file
79
roomserver/storage/tables/state_snapshot_table_test.go
Normal file
|
|
@ -0,0 +1,79 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) {
|
||||||
|
t.Helper()
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
err = postgres.CreateStateSnapshotTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = postgres.PrepareStateSnapshotTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
err = sqlite3.CreateStateSnapshotTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
tab, err = sqlite3.PrepareStateSnapshotTable(db)
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateSnapshotTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, close := mustCreateStateSnapshotTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
// generate some dummy data
|
||||||
|
var stateBlockNIDs types.StateBlockNIDs
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i))
|
||||||
|
}
|
||||||
|
stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
|
||||||
|
|
||||||
|
// verify ON CONFLICT; Note: this updates the sequence!
|
||||||
|
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, types.StateSnapshotNID(1), stateNID)
|
||||||
|
|
||||||
|
// create a second snapshot
|
||||||
|
var stateBlockNIDs2 types.StateBlockNIDs
|
||||||
|
for i := 100; i < 150; i++ {
|
||||||
|
stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence
|
||||||
|
assert.Equal(t, types.StateSnapshotNID(3), stateNID)
|
||||||
|
|
||||||
|
nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs))
|
||||||
|
assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs))
|
||||||
|
|
||||||
|
// check we get an error if the state snapshot does not exist
|
||||||
|
_, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue