diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index bd4e853eb..26999a290 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -64,12 +64,12 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func createPrevEventsTable(db *sql.DB) error { +func CreatePrevEventsTable(db *sql.DB) error { _, err := db.Exec(previousEventSchema) return err } -func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { +func PreparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { s := &previousEventStatements{} return s, sqlutil.StatementList{ diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 615a10d39..0aa3f6c19 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -89,7 +89,7 @@ func (d *Database) create(db *sql.DB) error { if err := createStateSnapshotTable(db); err != nil { return err } - if err := createPrevEventsTable(db); err != nil { + if err := CreatePrevEventsTable(db); err != nil { return err } if err := createRoomAliasesTable(db); err != nil { @@ -140,7 +140,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - prevEvents, err := preparePrevEventsTable(db) + prevEvents, err := PreparePrevEventsTable(db) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 7304bf0d5..2a146ef64 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -70,12 +70,12 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func createPrevEventsTable(db *sql.DB) error { +func CreatePrevEventsTable(db *sql.DB) error { _, err := db.Exec(previousEventSchema) return err } -func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { +func PreparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { s := &previousEventStatements{ db: db, } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 2ec03e355..2928589f2 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -98,7 +98,7 @@ func (d *Database) create(db *sql.DB) error { if err := createStateSnapshotTable(db); err != nil { return err } - if err := createPrevEventsTable(db); err != nil { + if err := CreatePrevEventsTable(db); err != nil { return err } if err := createRoomAliasesTable(db); err != nil { @@ -149,7 +149,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - prevEvents, err := preparePrevEventsTable(db) + prevEvents, err := PreparePrevEventsTable(db) if err != nil { return err } diff --git a/roomserver/storage/tables/previous_events_table_test.go b/roomserver/storage/tables/previous_events_table_test.go new file mode 100644 index 000000000..96d7bfed0 --- /dev/null +++ b/roomserver/storage/tables/previous_events_table_test.go @@ -0,0 +1,61 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" +) + +func mustCreatePreviousEventsTable(t *testing.T, dbType test.DBType) (tab tables.PreviousEvents, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreatePrevEventsTable(db) + assert.NoError(t, err) + tab, err = postgres.PreparePrevEventsTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreatePrevEventsTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PreparePrevEventsTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func TestPreviousEventsTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + room := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreatePreviousEventsTable(t, dbType) + defer close() + + for _, x := range room.Events() { + for _, prevEvent := range x.PrevEvents() { + err := tab.InsertPreviousEvent(ctx, nil, prevEvent.EventID, prevEvent.EventSHA256, 1) + assert.NoError(t, err) + + err = tab.SelectPreviousEventExists(ctx, nil, prevEvent.EventID, prevEvent.EventSHA256) + assert.NoError(t, err) + } + } + + // RandomString with a correct EventSHA256 should fail and return sql.ErrNoRows + err := tab.SelectPreviousEventExists(ctx, nil, util.RandomString(16), room.Events()[0].EventReference().EventSHA256) + assert.Error(t, err) + }) +}