diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index 1d5de5822..15ab7fd8e 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -99,12 +99,12 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func createEventTypesTable(db *sql.DB) error { +func CreateEventTypesTable(db *sql.DB) error { _, err := db.Exec(eventTypesSchema) return err } -func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { +func PrepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{} return s, sqlutil.StatementList{ @@ -143,9 +143,9 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") result := make(map[string]types.EventTypeNID, len(eventTypes)) + var eventType string + var eventTypeNID int64 for rows.Next() { - var eventType string - var eventTypeNID int64 if err := rows.Scan(&eventType, &eventTypeNID); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 0d8236171..4956767fe 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -71,7 +71,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateEventStateKeysTable(db); err != nil { return err } - if err := createEventTypesTable(db); err != nil { + if err := CreateEventTypesTable(db); err != nil { return err } if err := CreateEventJSONTable(db); err != nil { @@ -116,7 +116,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - eventTypes, err := prepareEventTypesTable(db) + eventTypes, err := PrepareEventTypesTable(db) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index c49cc509a..0581ec194 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -79,12 +79,12 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func createEventTypesTable(db *sql.DB) error { +func CreateEventTypesTable(db *sql.DB) error { _, err := db.Exec(eventTypesSchema) return err } -func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { +func PrepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{ db: db, } @@ -139,9 +139,9 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed") result := make(map[string]types.EventTypeNID, len(eventTypes)) + var eventType string + var eventTypeNID int64 for rows.Next() { - var eventType string - var eventTypeNID int64 if err := rows.Scan(&eventType, &eventTypeNID); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 17e110056..6d69bf862 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -80,7 +80,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateEventStateKeysTable(db); err != nil { return err } - if err := createEventTypesTable(db); err != nil { + if err := CreateEventTypesTable(db); err != nil { return err } if err := CreateEventJSONTable(db); err != nil { @@ -125,7 +125,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - eventTypes, err := prepareEventTypesTable(db) + eventTypes, err := PrepareEventTypesTable(db) if err != nil { return err } diff --git a/roomserver/storage/tables/event_state_keys_test.go b/roomserver/storage/tables/event_state_keys_table_test.go similarity index 100% rename from roomserver/storage/tables/event_state_keys_test.go rename to roomserver/storage/tables/event_state_keys_table_test.go diff --git a/roomserver/storage/tables/event_types_table_test.go b/roomserver/storage/tables/event_types_table_test.go new file mode 100644 index 000000000..88cf93d82 --- /dev/null +++ b/roomserver/storage/tables/event_types_table_test.go @@ -0,0 +1,88 @@ +package tables_test + +import ( + "context" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" +) + +func mustCreateEventTypesTable(t *testing.T, dbType test.DBType) (tables.EventTypes, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + var tab tables.EventTypes + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventTypesTable(db) + if err != nil { + t.Fatalf("failed to create EventJSON table: %s", err) + } + tab, err = postgres.PrepareEventTypesTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventTypesTable(db) + if err != nil { + t.Fatalf("failed to create EventJSON table: %s", err) + } + tab, err = sqlite3.PrepareEventTypesTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + + return tab, close +} + +func Test_EventTypesTable(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateEventTypesTable(t, dbType) + defer close() + ctx := context.Background() + var eventTypeNID, gotEventTypeNID types.EventTypeNID + var err error + // create some dummy data + eventTypeMap := make(map[string]types.EventTypeNID) + for i := 0; i < 10; i++ { + eventType := fmt.Sprintf("dummyEventType%d", i) + if eventTypeNID, err = tab.InsertEventTypeNID( + ctx, nil, eventType, + ); err != nil { + t.Fatalf("unable to insert eventJSON: %s", err) + } + eventTypeMap[eventType] = eventTypeNID + gotEventTypeNID, err = tab.SelectEventTypeNID(ctx, nil, eventType) + if err != nil { + t.Fatalf("failed to get EventTypeNID: %s", err) + } + if eventTypeNID != gotEventTypeNID { + t.Fatalf("expected eventTypeNID %d, but got %d", eventTypeNID, gotEventTypeNID) + } + } + eventTypeNIDs, err := tab.BulkSelectEventTypeNID(ctx, nil, []string{"dummyEventType0", "dummyEventType3"}) + if err != nil { + t.Fatalf("failed to get EventStateKeyNIDs: %s", err) + } + // verify that BulkSelectEventTypeNID and InsertEventTypeNID return the same values + for eventType, nid := range eventTypeNIDs { + if v, ok := eventTypeMap[eventType]; ok { + if v != nid { + t.Fatalf("EventTypeNID does not match: %d != %d", nid, v) + } + } else { + t.Fatalf("unable to find %d in result set", nid) + } + } + }) +}