diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 8012174a0..1e33c59ce 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -155,12 +155,12 @@ type eventStatements struct { selectRoomNIDsForEventNIDsStmt *sql.Stmt } -func createEventsTable(db *sql.DB) error { +func CreateEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) return err } -func prepareEventsTable(db *sql.DB) (tables.Events, error) { +func PrepareEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{} return s, sqlutil.StatementList{ @@ -380,15 +380,15 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") results := make([]types.StateAtEventAndReference, len(eventNIDs)) i := 0 + var ( + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + stateSnapshotNID int64 + eventID string + eventSHA256 []byte + ) for ; rows.Next(); i++ { - var ( - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - stateSnapshotNID int64 - eventID string - eventSHA256 []byte - ) if err = rows.Scan( &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, ); err != nil { @@ -446,9 +446,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") results := make(map[types.EventNID]string, len(eventNIDs)) i := 0 + var eventNID int64 + var eventID string!!¹23456789!!"§$%" for ; rows.Next(); i++ { - var eventNID int64 - var eventID string if err = rows.Scan(&eventNID, &eventID); err != nil { return nil, err } @@ -491,9 +491,9 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") results := make(map[string]types.EventNID, len(eventIDs)) + var eventID string + var eventNID int64 for rows.Next() { - var eventID string - var eventNID int64 if err = rows.Scan(&eventID, &eventNID); err != nil { return nil, err } @@ -522,9 +522,9 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( } defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed") result := make(map[types.EventNID]types.RoomNID) + var eventNID types.EventNID + var roomNID types.RoomNID for rows.Next() { - var eventNID types.EventNID - var roomNID types.RoomNID if err = rows.Scan(&eventNID, &roomNID); err != nil { return nil, err } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 4956767fe..34e891490 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -77,7 +77,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateEventJSONTable(db); err != nil { return err } - if err := createEventsTable(db); err != nil { + if err := CreateEventsTable(db); err != nil { return err } if err := createRoomsTable(db); err != nil { @@ -124,7 +124,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - events, err := prepareEventsTable(db) + events, err := PrepareEventsTable(db) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 45b49e5cb..feb06150a 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -68,7 +68,8 @@ const bulkSelectStateEventByIDSQL = "" + const bulkSelectStateEventByNIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + " WHERE event_nid IN ($1)" - // Rest of query is built by BulkSelectStateEventByNID + +// Rest of query is built by BulkSelectStateEventByNID const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + @@ -126,12 +127,12 @@ type eventStatements struct { //selectRoomNIDsForEventNIDsStmt *sql.Stmt } -func createEventsTable(db *sql.DB) error { +func CreateEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) return err } -func prepareEventsTable(db *sql.DB) (tables.Events, error) { +func PrepareEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{ db: db, } @@ -404,15 +405,15 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") results := make([]types.StateAtEventAndReference, len(eventNIDs)) i := 0 + var ( + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + stateSnapshotNID int64 + eventID string + eventSHA256 []byte + ) for ; rows.Next(); i++ { - var ( - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - stateSnapshotNID int64 - eventID string - eventSHA256 []byte - ) if err = rows.Scan( &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, ); err != nil { @@ -491,9 +492,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventID: rows.close() failed") results := make(map[types.EventNID]string, len(eventNIDs)) i := 0 + var eventNID int64 + var eventID string for ; rows.Next(); i++ { - var eventNID int64 - var eventID string if err = rows.Scan(&eventNID, &eventID); err != nil { return nil, err } @@ -545,9 +546,9 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed") results := make(map[string]types.EventNID, len(eventIDs)) + var eventID string + var eventNID int64 for rows.Next() { - var eventID string - var eventNID int64 if err = rows.Scan(&eventID, &eventNID); err != nil { return nil, err } @@ -595,9 +596,9 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs( } defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed") result := make(map[types.EventNID]types.RoomNID) + var eventNID types.EventNID + var roomNID types.RoomNID for rows.Next() { - var eventNID types.EventNID - var roomNID types.RoomNID if err = rows.Scan(&eventNID, &roomNID); err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 6d69bf862..9522d3058 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -86,7 +86,7 @@ func (d *Database) create(db *sql.DB) error { if err := CreateEventJSONTable(db); err != nil { return err } - if err := createEventsTable(db); err != nil { + if err := CreateEventsTable(db); err != nil { return err } if err := createRoomsTable(db); err != nil { @@ -133,7 +133,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } - events, err := prepareEventsTable(db) + events, err := PrepareEventsTable(db) if err != nil { return err } diff --git a/roomserver/storage/tables/event_json_table_test.go b/roomserver/storage/tables/event_json_table_test.go index 53a168286..cb00aac00 100644 --- a/roomserver/storage/tables/event_json_table_test.go +++ b/roomserver/storage/tables/event_json_table_test.go @@ -3,7 +3,6 @@ package tables_test import ( "context" "fmt" - "reflect" "testing" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -13,6 +12,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" ) func mustCreateEventJSONTable(t *testing.T, dbType test.DBType) (tables.EventJSON, func()) { @@ -21,27 +21,19 @@ func mustCreateEventJSONTable(t *testing.T, dbType test.DBType) (tables.EventJSO db, err := sqlutil.Open(&config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, sqlutil.NewExclusiveWriter()) - if err != nil { - t.Fatalf("failed to open db: %s", err) - } + assert.NoError(t, err) var tab tables.EventJSON switch dbType { case test.DBTypePostgres: err = postgres.CreateEventJSONTable(db) - if err != nil { - t.Fatalf("failed to create EventJSON table: %s", err) - } + assert.NoError(t, err) tab, err = postgres.PrepareEventJSONTable(db) case test.DBTypeSQLite: err = sqlite3.CreateEventJSONTable(db) - if err != nil { - t.Fatalf("failed to create EventJSON table: %s", err) - } + assert.NoError(t, err) tab, err = sqlite3.PrepareEventJSONTable(db) } - if err != nil { - t.Fatalf("failed to create table: %s", err) - } + assert.NoError(t, err) return tab, close } @@ -52,29 +44,19 @@ func Test_EventJSONTable(t *testing.T) { defer close() // create some dummy data for i := 0; i < 10; i++ { - if err := tab.InsertEventJSON( + err := tab.InsertEventJSON( context.Background(), nil, types.EventNID(i), []byte(fmt.Sprintf(`{"value":%d"}`, i)), - ); err != nil { - t.Fatalf("unable to insert eventJSON: %s", err) - } + ) + assert.NoError(t, err) } // select a subset of the data values, err := tab.BulkSelectEventJSON(context.Background(), nil, []types.EventNID{1, 2, 3, 4, 5}) - if err != nil { - t.Fatalf("unable to query eventJSON: %s", err) - } - if len(values) != 5 { - t.Fatalf("expected 5 events, got %d", len(values)) - } + assert.NoError(t, err) + assert.Equal(t, 5, len(values)) for i, v := range values { - if v.EventNID != types.EventNID(i+1) { - t.Fatalf("expected eventNID %d, got %d", i+1, v.EventNID) - } - wantValue := []byte(fmt.Sprintf(`{"value":%d"}`, i+1)) - if !reflect.DeepEqual(wantValue, v.EventJSON) { - t.Fatalf("expected JSON to be %s, got %s", string(wantValue), string(v.EventJSON)) - } + assert.Equal(t, v.EventNID, types.EventNID(i+1)) + assert.Equal(t, []byte(fmt.Sprintf(`{"value":%d"}`, i+1)), v.EventJSON) } }) } diff --git a/roomserver/storage/tables/event_state_keys_table_test.go b/roomserver/storage/tables/event_state_keys_table_test.go index c3cea5947..c432ec788 100644 --- a/roomserver/storage/tables/event_state_keys_table_test.go +++ b/roomserver/storage/tables/event_state_keys_table_test.go @@ -12,6 +12,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" ) func mustCreateEventStateKeysTable(t *testing.T, dbType test.DBType) (tables.EventStateKeys, func()) { @@ -20,27 +21,19 @@ func mustCreateEventStateKeysTable(t *testing.T, dbType test.DBType) (tables.Eve db, err := sqlutil.Open(&config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, sqlutil.NewExclusiveWriter()) - if err != nil { - t.Fatalf("failed to open db: %s", err) - } + assert.NoError(t, err) var tab tables.EventStateKeys switch dbType { case test.DBTypePostgres: err = postgres.CreateEventStateKeysTable(db) - if err != nil { - t.Fatalf("failed to create EventJSON table: %s", err) - } + assert.NoError(t, err) tab, err = postgres.PrepareEventStateKeysTable(db) case test.DBTypeSQLite: err = sqlite3.CreateEventStateKeysTable(db) - if err != nil { - t.Fatalf("failed to create EventJSON table: %s", err) - } + assert.NoError(t, err) tab, err = sqlite3.PrepareEventStateKeysTable(db) } - if err != nil { - t.Fatalf("failed to create table: %s", err) - } + assert.NoError(t, err) return tab, close } @@ -55,37 +48,26 @@ func Test_EventStateKeysTable(t *testing.T) { // create some dummy data for i := 0; i < 10; i++ { stateKey := fmt.Sprintf("@user%d:localhost", i) - if stateKeyNID, err = tab.InsertEventStateKeyNID( + stateKeyNID, err = tab.InsertEventStateKeyNID( ctx, nil, stateKey, - ); err != nil { - t.Fatalf("unable to insert eventJSON: %s", err) - } + ) + assert.NoError(t, err) gotEventStateKey, err = tab.SelectEventStateKeyNID(ctx, nil, stateKey) - if err != nil { - t.Fatalf("failed to get eventStateKeyNID: %s", err) - } - if stateKeyNID != gotEventStateKey { - t.Fatalf("expected eventStateKey %d, but got %d", stateKeyNID, gotEventStateKey) - } + assert.NoError(t, err) + assert.Equal(t, stateKeyNID, gotEventStateKey) } stateKeyNIDsMap, err := tab.BulkSelectEventStateKeyNID(ctx, nil, []string{"@user0:localhost", "@user1:localhost"}) - if err != nil { - t.Fatalf("failed to get EventStateKeyNIDs: %s", err) - } + assert.NoError(t, err) wantStateKeyNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDsMap)) for _, nid := range stateKeyNIDsMap { wantStateKeyNIDs = append(wantStateKeyNIDs, nid) } stateKeyNIDs, err := tab.BulkSelectEventStateKey(ctx, nil, wantStateKeyNIDs) - if err != nil { - t.Fatalf("failed to get EventStateKeyNIDs: %s", err) - } + assert.NoError(t, err) // verify that BulkSelectEventStateKeyNID and BulkSelectEventStateKey return the same values for userID, nid := range stateKeyNIDsMap { if v, ok := stateKeyNIDs[nid]; ok { - if v != userID { - t.Fatalf("userID does not match: %s != %s", userID, v) - } + assert.Equal(t, v, userID) } else { t.Fatalf("unable to find %d in result set", nid) } diff --git a/roomserver/storage/tables/event_types_table_test.go b/roomserver/storage/tables/event_types_table_test.go index 88cf93d82..8ad41c14e 100644 --- a/roomserver/storage/tables/event_types_table_test.go +++ b/roomserver/storage/tables/event_types_table_test.go @@ -12,6 +12,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" ) func mustCreateEventTypesTable(t *testing.T, dbType test.DBType) (tables.EventTypes, func()) { @@ -20,27 +21,19 @@ func mustCreateEventTypesTable(t *testing.T, dbType test.DBType) (tables.EventTy db, err := sqlutil.Open(&config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, sqlutil.NewExclusiveWriter()) - if err != nil { - t.Fatalf("failed to open db: %s", err) - } + assert.NoError(t, err) var tab tables.EventTypes switch dbType { case test.DBTypePostgres: err = postgres.CreateEventTypesTable(db) - if err != nil { - t.Fatalf("failed to create EventJSON table: %s", err) - } + assert.NoError(t, err) tab, err = postgres.PrepareEventTypesTable(db) case test.DBTypeSQLite: err = sqlite3.CreateEventTypesTable(db) - if err != nil { - t.Fatalf("failed to create EventJSON table: %s", err) - } + assert.NoError(t, err) tab, err = sqlite3.PrepareEventTypesTable(db) } - if err != nil { - t.Fatalf("failed to create table: %s", err) - } + assert.NoError(t, err) return tab, close } @@ -63,23 +56,15 @@ func Test_EventTypesTable(t *testing.T) { } eventTypeMap[eventType] = eventTypeNID gotEventTypeNID, err = tab.SelectEventTypeNID(ctx, nil, eventType) - if err != nil { - t.Fatalf("failed to get EventTypeNID: %s", err) - } - if eventTypeNID != gotEventTypeNID { - t.Fatalf("expected eventTypeNID %d, but got %d", eventTypeNID, gotEventTypeNID) - } + assert.NoError(t, err) + assert.Equal(t, eventTypeNID, gotEventTypeNID) } eventTypeNIDs, err := tab.BulkSelectEventTypeNID(ctx, nil, []string{"dummyEventType0", "dummyEventType3"}) - if err != nil { - t.Fatalf("failed to get EventStateKeyNIDs: %s", err) - } + assert.NoError(t, err) // verify that BulkSelectEventTypeNID and InsertEventTypeNID return the same values for eventType, nid := range eventTypeNIDs { if v, ok := eventTypeMap[eventType]; ok { - if v != nid { - t.Fatalf("EventTypeNID does not match: %d != %d", nid, v) - } + assert.Equal(t, v, nid) } else { t.Fatalf("unable to find %d in result set", nid) } diff --git a/roomserver/storage/tables/events_table_test.go b/roomserver/storage/tables/events_table_test.go new file mode 100644 index 000000000..99db39938 --- /dev/null +++ b/roomserver/storage/tables/events_table_test.go @@ -0,0 +1,100 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" +) + +func mustCreateEventsTable(t *testing.T, dbType test.DBType) (tables.Events, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.Events + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventsTable(db) + assert.NoError(t, err) + tab, err = postgres.PrepareEventsTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventsTable(db) + assert.NoError(t, err) + tab, err = sqlite3.PrepareEventsTable(db) + } + assert.NoError(t, err) + + return tab, close +} + +func Test_EventsTable(t *testing.T) { + alice := test.NewUser() + room := test.NewRoom(t, alice) + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, close := mustCreateEventsTable(t, dbType) + defer close() + // create some dummy data + eventIDs := make([]string, 0, len(room.Events())) + wantStateAtEvent := make([]types.StateAtEvent, 0, len(room.Events())) + for _, ev := range room.Events() { + eventIDs = append(eventIDs, ev.EventID()) + + eventNID, snapNID, err := tab.InsertEvent(ctx, nil, 1, 1, 1, ev.EventID(), []byte(""), nil, 0, false) + assert.NoError(t, err) + gotEventNID, gotSnapNID, err := tab.SelectEvent(ctx, nil, ev.EventID()) + assert.NoError(t, err) + assert.Equal(t, eventNID, gotEventNID) + assert.Equal(t, snapNID, gotSnapNID) + eventID, err := tab.SelectEventID(ctx, nil, eventNID) + assert.NoError(t, err) + assert.Equal(t, eventID, ev.EventID()) + + wantStateAtEvent = append(wantStateAtEvent, types.StateAtEvent{ + Overwrite: false, + BeforeStateSnapshotNID: 0, + IsRejected: false, + StateEntry: types.StateEntry{ + EventNID: eventNID, + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: 1, + EventStateKeyNID: 1, + }, + }, + }) + } + + stateEvents, err := tab.BulkSelectStateEventByID(ctx, nil, eventIDs) + assert.NoError(t, err) + assert.Equal(t, len(stateEvents), len(eventIDs)) + nids := make([]types.EventNID, 0, len(stateEvents)) + for _, ev := range stateEvents { + nids = append(nids, ev.EventNID) + } + stateEvents2, err := tab.BulkSelectStateEventByNID(ctx, nil, nids, nil) + assert.NoError(t, err) + // somehow SQLite doesn't return the values ordered as requested by the query + assert.ElementsMatch(t, stateEvents, stateEvents2) + + stateAtEvent, err := tab.BulkSelectStateAtEventByID(ctx, nil, eventIDs) + assert.NoError(t, err) + assert.Equal(t, len(eventIDs), len(stateAtEvent)) + + assert.ElementsMatch(t, wantStateAtEvent, stateAtEvent) + + evendNIDMap, err := tab.BulkSelectEventID(ctx, nil, nids) + assert.NoError(t, err) + t.Logf("%+v", evendNIDMap) + assert.Equal(t, len(evendNIDMap), len(nids)) + }) +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index b05647626..95609787a 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -35,7 +35,8 @@ type EventStateKeys interface { type Events interface { InsertEvent( - ctx context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, + eventStateKeyNID types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)