From e2751781e7c09faae535e59c8e4249c31dc319de Mon Sep 17 00:00:00 2001
From: Till Faelligen <tfaelligen@gmail.com>
Date: Fri, 6 May 2022 12:36:58 +0200
Subject: [PATCH] Add EventTypesTable tests

---
 .../storage/postgres/event_types_table.go     |  8 +-
 roomserver/storage/postgres/storage.go        |  4 +-
 .../storage/sqlite3/event_types_table.go      |  8 +-
 roomserver/storage/sqlite3/storage.go         |  4 +-
 ...test.go => event_state_keys_table_test.go} |  0
 .../storage/tables/event_types_table_test.go  | 88 +++++++++++++++++++
 6 files changed, 100 insertions(+), 12 deletions(-)
 rename roomserver/storage/tables/{event_state_keys_test.go => event_state_keys_table_test.go} (100%)
 create mode 100644 roomserver/storage/tables/event_types_table_test.go

diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go
index 1d5de5822..15ab7fd8e 100644
--- a/roomserver/storage/postgres/event_types_table.go
+++ b/roomserver/storage/postgres/event_types_table.go
@@ -99,12 +99,12 @@ type eventTypeStatements struct {
 	bulkSelectEventTypeNIDStmt *sql.Stmt
 }
 
-func createEventTypesTable(db *sql.DB) error {
+func CreateEventTypesTable(db *sql.DB) error {
 	_, err := db.Exec(eventTypesSchema)
 	return err
 }
 
-func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
+func PrepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
 	s := &eventTypeStatements{}
 
 	return s, sqlutil.StatementList{
@@ -143,9 +143,9 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
 	defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed")
 
 	result := make(map[string]types.EventTypeNID, len(eventTypes))
+	var eventType string
+	var eventTypeNID int64
 	for rows.Next() {
-		var eventType string
-		var eventTypeNID int64
 		if err := rows.Scan(&eventType, &eventTypeNID); err != nil {
 			return nil, err
 		}
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 0d8236171..4956767fe 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -71,7 +71,7 @@ func (d *Database) create(db *sql.DB) error {
 	if err := CreateEventStateKeysTable(db); err != nil {
 		return err
 	}
-	if err := createEventTypesTable(db); err != nil {
+	if err := CreateEventTypesTable(db); err != nil {
 		return err
 	}
 	if err := CreateEventJSONTable(db); err != nil {
@@ -116,7 +116,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
 	if err != nil {
 		return err
 	}
-	eventTypes, err := prepareEventTypesTable(db)
+	eventTypes, err := PrepareEventTypesTable(db)
 	if err != nil {
 		return err
 	}
diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go
index c49cc509a..0581ec194 100644
--- a/roomserver/storage/sqlite3/event_types_table.go
+++ b/roomserver/storage/sqlite3/event_types_table.go
@@ -79,12 +79,12 @@ type eventTypeStatements struct {
 	bulkSelectEventTypeNIDStmt *sql.Stmt
 }
 
-func createEventTypesTable(db *sql.DB) error {
+func CreateEventTypesTable(db *sql.DB) error {
 	_, err := db.Exec(eventTypesSchema)
 	return err
 }
 
-func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
+func PrepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
 	s := &eventTypeStatements{
 		db: db,
 	}
@@ -139,9 +139,9 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
 	defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventTypeNID: rows.close() failed")
 
 	result := make(map[string]types.EventTypeNID, len(eventTypes))
+	var eventType string
+	var eventTypeNID int64
 	for rows.Next() {
-		var eventType string
-		var eventTypeNID int64
 		if err := rows.Scan(&eventType, &eventTypeNID); err != nil {
 			return nil, err
 		}
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index 17e110056..6d69bf862 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -80,7 +80,7 @@ func (d *Database) create(db *sql.DB) error {
 	if err := CreateEventStateKeysTable(db); err != nil {
 		return err
 	}
-	if err := createEventTypesTable(db); err != nil {
+	if err := CreateEventTypesTable(db); err != nil {
 		return err
 	}
 	if err := CreateEventJSONTable(db); err != nil {
@@ -125,7 +125,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
 	if err != nil {
 		return err
 	}
-	eventTypes, err := prepareEventTypesTable(db)
+	eventTypes, err := PrepareEventTypesTable(db)
 	if err != nil {
 		return err
 	}
diff --git a/roomserver/storage/tables/event_state_keys_test.go b/roomserver/storage/tables/event_state_keys_table_test.go
similarity index 100%
rename from roomserver/storage/tables/event_state_keys_test.go
rename to roomserver/storage/tables/event_state_keys_table_test.go
diff --git a/roomserver/storage/tables/event_types_table_test.go b/roomserver/storage/tables/event_types_table_test.go
new file mode 100644
index 000000000..88cf93d82
--- /dev/null
+++ b/roomserver/storage/tables/event_types_table_test.go
@@ -0,0 +1,88 @@
+package tables_test
+
+import (
+	"context"
+	"fmt"
+	"testing"
+
+	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/roomserver/storage/postgres"
+	"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
+	"github.com/matrix-org/dendrite/roomserver/storage/tables"
+	"github.com/matrix-org/dendrite/roomserver/types"
+	"github.com/matrix-org/dendrite/setup/config"
+	"github.com/matrix-org/dendrite/test"
+)
+
+func mustCreateEventTypesTable(t *testing.T, dbType test.DBType) (tables.EventTypes, func()) {
+	t.Helper()
+	connStr, close := test.PrepareDBConnectionString(t, dbType)
+	db, err := sqlutil.Open(&config.DatabaseOptions{
+		ConnectionString: config.DataSource(connStr),
+	}, sqlutil.NewExclusiveWriter())
+	if err != nil {
+		t.Fatalf("failed to open db: %s", err)
+	}
+	var tab tables.EventTypes
+	switch dbType {
+	case test.DBTypePostgres:
+		err = postgres.CreateEventTypesTable(db)
+		if err != nil {
+			t.Fatalf("failed to create EventJSON table: %s", err)
+		}
+		tab, err = postgres.PrepareEventTypesTable(db)
+	case test.DBTypeSQLite:
+		err = sqlite3.CreateEventTypesTable(db)
+		if err != nil {
+			t.Fatalf("failed to create EventJSON table: %s", err)
+		}
+		tab, err = sqlite3.PrepareEventTypesTable(db)
+	}
+	if err != nil {
+		t.Fatalf("failed to create table: %s", err)
+	}
+
+	return tab, close
+}
+
+func Test_EventTypesTable(t *testing.T) {
+	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+		tab, close := mustCreateEventTypesTable(t, dbType)
+		defer close()
+		ctx := context.Background()
+		var eventTypeNID, gotEventTypeNID types.EventTypeNID
+		var err error
+		// create some dummy data
+		eventTypeMap := make(map[string]types.EventTypeNID)
+		for i := 0; i < 10; i++ {
+			eventType := fmt.Sprintf("dummyEventType%d", i)
+			if eventTypeNID, err = tab.InsertEventTypeNID(
+				ctx, nil, eventType,
+			); err != nil {
+				t.Fatalf("unable to insert eventJSON: %s", err)
+			}
+			eventTypeMap[eventType] = eventTypeNID
+			gotEventTypeNID, err = tab.SelectEventTypeNID(ctx, nil, eventType)
+			if err != nil {
+				t.Fatalf("failed to get EventTypeNID: %s", err)
+			}
+			if eventTypeNID != gotEventTypeNID {
+				t.Fatalf("expected eventTypeNID %d, but got %d", eventTypeNID, gotEventTypeNID)
+			}
+		}
+		eventTypeNIDs, err := tab.BulkSelectEventTypeNID(ctx, nil, []string{"dummyEventType0", "dummyEventType3"})
+		if err != nil {
+			t.Fatalf("failed to get EventStateKeyNIDs: %s", err)
+		}
+		// verify that BulkSelectEventTypeNID and InsertEventTypeNID return the same values
+		for eventType, nid := range eventTypeNIDs {
+			if v, ok := eventTypeMap[eventType]; ok {
+				if v != nid {
+					t.Fatalf("EventTypeNID does not match: %d != %d", nid, v)
+				}
+			} else {
+				t.Fatalf("unable to find %d in result set", nid)
+			}
+		}
+	})
+}