diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index 26999a290..1b4b94a1d 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -57,7 +57,7 @@ const insertPreviousEventSQL = "" + // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. const selectPreviousEventExistsSQL = "" + "SELECT 1 FROM roomserver_previous_events" + - " WHERE previous_event_id = $1 AND previous_reference_sha256 = $2" + " WHERE previous_event_id = $1" type previousEventStatements struct { insertPreviousEventStmt *sql.Stmt @@ -82,12 +82,11 @@ func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, - previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err := stmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ctx, previousEventID, []byte(""), int64(eventNID), ) return err } @@ -95,9 +94,9 @@ func (s *previousEventStatements) InsertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. func (s *previousEventStatements) SelectPreviousEventExists( - ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, + ctx context.Context, txn *sql.Tx, eventID string, ) error { var ok int64 stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) + return stmt.QueryRowContext(ctx, eventID).Scan(&ok) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 5a20c67b3..8669943cb 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -108,7 +108,7 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, eventNID); err != nil { return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) } } @@ -204,7 +204,7 @@ func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInf // IsReferenced implements types.RoomRecentEventsUpdater func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID) if err == nil { return true, nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 60e46c478..2d10a0412 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -762,7 +762,7 @@ func (d *EventDatabase) StoreEvent( return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } - if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { + if prevEvents := event.PrevEventIDs(); len(prevEvents) > 0 { // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This // function only does SELECTs though so the created txn (at this point) is just a read txn like @@ -770,8 +770,8 @@ func (d *EventDatabase) StoreEvent( // to do writes however then this will need to go inside `Writer.Do`. // The following is a copy of RoomUpdater.StorePreviousEvents - for _, ref := range prevEvents { - if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + for _, eventID := range prevEvents { + if err = d.PrevEventsTable.InsertPreviousEvent(ctx, txn, eventID, eventNID); err != nil { return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) } } diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 2a146ef64..f8e11257e 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -47,20 +47,20 @@ const previousEventSchema = ` // The lock is necessary to avoid data races when checking whether an event is already referenced by another event. const insertPreviousEventSQL = ` INSERT OR REPLACE INTO roomserver_previous_events - (previous_event_id, previous_reference_sha256, event_nids) - VALUES ($1, $2, $3) + (previous_event_id, event_nids) + VALUES ($1, $2) ` const selectPreviousEventNIDsSQL = ` SELECT event_nids FROM roomserver_previous_events - WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 + WHERE previous_event_id = $1 ` // Check if the event is referenced by another event in the table. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. const selectPreviousEventExistsSQL = ` SELECT 1 FROM roomserver_previous_events - WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 + WHERE previous_event_id = $1 ` type previousEventStatements struct { @@ -91,13 +91,12 @@ func (s *previousEventStatements) InsertPreviousEvent( ctx context.Context, txn *sql.Tx, previousEventID string, - previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { var eventNIDs string eventNIDAsString := fmt.Sprintf("%d", eventNID) selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs) + err := selectStmt.QueryRowContext(ctx, previousEventID).Scan(&eventNIDs) if err != nil && err != sql.ErrNoRows { return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err) } @@ -115,7 +114,7 @@ func (s *previousEventStatements) InsertPreviousEvent( } insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) _, err = insertStmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, eventNIDs, + ctx, previousEventID, eventNIDs, ) return err } @@ -123,9 +122,9 @@ func (s *previousEventStatements) InsertPreviousEvent( // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. func (s *previousEventStatements) SelectPreviousEventExists( - ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, + ctx context.Context, txn *sql.Tx, eventID string, ) error { var ok int64 stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) - return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) + return stmt.QueryRowContext(ctx, eventID).Scan(&ok) } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index a9c5f8b11..f0f553918 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -113,10 +113,10 @@ type RoomAliases interface { } type PreviousEvents interface { - InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error + InsertPreviousEvent(ctx context.Context, txn *sql.Tx, previousEventID string, eventNID types.EventNID) error // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. - SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error + SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string) error } type Invites interface { diff --git a/roomserver/storage/tables/previous_events_table_test.go b/roomserver/storage/tables/previous_events_table_test.go index 63d540696..9d41e90be 100644 --- a/roomserver/storage/tables/previous_events_table_test.go +++ b/roomserver/storage/tables/previous_events_table_test.go @@ -45,17 +45,17 @@ func TestPreviousEventsTable(t *testing.T) { defer close() for _, x := range room.Events() { - for _, prevEvent := range x.PrevEvents() { - err := tab.InsertPreviousEvent(ctx, nil, prevEvent.EventID, prevEvent.EventSHA256, 1) + for _, eventID := range x.PrevEventIDs() { + err := tab.InsertPreviousEvent(ctx, nil, eventID, 1) assert.NoError(t, err) - err = tab.SelectPreviousEventExists(ctx, nil, prevEvent.EventID, prevEvent.EventSHA256) + err = tab.SelectPreviousEventExists(ctx, nil, eventID) assert.NoError(t, err) } } - // RandomString with a correct EventSHA256 should fail and return sql.ErrNoRows - err := tab.SelectPreviousEventExists(ctx, nil, util.RandomString(16), room.Events()[0].EventReference().EventSHA256) + // RandomString should fail and return sql.ErrNoRows + err := tab.SelectPreviousEventExists(ctx, nil, util.RandomString(16)) assert.Error(t, err) }) }