diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 15985fcd6..56fa02f7b 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -49,12 +49,12 @@ type publishedStatements struct { selectPublishedStmt *sql.Stmt } -func createPublishedTable(db *sql.DB) error { +func CreatePublishedTable(db *sql.DB) error { _, err := db.Exec(publishedSchema) return err } -func preparePublishedTable(db *sql.DB) (tables.Published, error) { +func PreparePublishedTable(db *sql.DB) (tables.Published, error) { s := &publishedStatements{} return s, sqlutil.StatementList{ @@ -94,8 +94,8 @@ func (s *publishedStatements) SelectAllPublishedRooms( defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed") var roomIDs []string + var roomID string for rows.Next() { - var roomID string if err = rows.Scan(&roomID); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 0aa3f6c19..76ee6b489 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -101,7 +101,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateMembershipTable(db); err != nil { return err } - if err := createPublishedTable(db); err != nil { + if err := CreatePublishedTable(db); err != nil { return err } if err := createRedactionsTable(db); err != nil { @@ -156,7 +156,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - published, err := preparePublishedTable(db) + published, err := PreparePublishedTable(db) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index 9e416ace3..50dfa5492 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -49,12 +49,12 @@ type publishedStatements struct { selectPublishedStmt *sql.Stmt } -func createPublishedTable(db *sql.DB) error { +func CreatePublishedTable(db *sql.DB) error { _, err := db.Exec(publishedSchema) return err } -func preparePublishedTable(db *sql.DB) (tables.Published, error) { +func PreparePublishedTable(db *sql.DB) (tables.Published, error) { s := &publishedStatements{ db: db, } @@ -96,8 +96,8 @@ func (s *publishedStatements) SelectAllPublishedRooms( defer internal.CloseAndLogIfError(ctx, rows, "selectAllPublishedStmt: rows.close() failed") var roomIDs []string + var roomID string for rows.Next() { - var roomID string if err = rows.Scan(&roomID); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 2928589f2..e041921b6 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -110,7 +110,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateMembershipTable(db); err != nil { return err } - if err := createPublishedTable(db); err != nil { + if err := CreatePublishedTable(db); err != nil { return err } if err := createRedactionsTable(db); err != nil { @@ -165,7 +165,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - published, err := preparePublishedTable(db) + published, err := PreparePublishedTable(db) if err != nil { return err } diff --git a/roomserver/storage/tables/published_table_test.go b/roomserver/storage/tables/published_table_test.go new file mode 100644 index 000000000..7d53242dc --- /dev/null +++ b/roomserver/storage/tables/published_table_test.go @@ -0,0 +1,68 @@ +package tables_test + +import ( + "context" + "sort" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreatePublishedTable(t *testing.T, dbType test.DBType) (tab tables.Published, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + switch dbType { + case test.DBTypePostgres: + err = postgres.CreatePublishedTable(db) + assert.NoError(t, err) + tab, err = postgres.PreparePublishedTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreatePublishedTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PreparePublishedTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func TestPublishedTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreatePublishedTable(t, dbType) + defer close() + + // Publish some rooms + publishedRooms := []string{} + for i := 0; i < 10; i++ { + room := test.NewRoom(t, alice) + published := i%2 == 0 + err := tab.UpsertRoomPublished(ctx, nil, room.ID, published) + assert.NoError(t, err) + if published { + publishedRooms = append(publishedRooms, room.ID) + } + publishedRes, err := tab.SelectPublishedFromRoomID(ctx, nil, room.ID) + assert.NoError(t, err) + assert.Equal(t, published, publishedRes) + } + sort.Strings(publishedRooms) + + // check that we get the expected published rooms + roomIDs, err := tab.SelectAllPublishedRooms(ctx, nil, true) + assert.NoError(t, err) + assert.Equal(t, publishedRooms, roomIDs) + }) +}