diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 5173d3ab2..ae28ebefa 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -405,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro if len(extraEventIDs) == 0 { return nil, nil } - extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs) + extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c136f039a..b3c23a3b7 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -108,6 +108,9 @@ const updateEventStateSQL = "" + const selectEventSentToOutputSQL = "" + "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" +const bulkSelectEventFilteredBySentToOutputSQL = "" + + "SELECT event_nid FROM roomserver_events WHERE event_nid = ANY($1) AND sent_to_output = $2" + const updateEventSentToOutputSQL = "" + "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" @@ -134,21 +137,22 @@ const selectRoomNIDsForEventNIDsSQL = "" + "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid = ANY($1)" type eventStatements struct { - insertEventStmt *sql.Stmt - selectEventStmt *sql.Stmt - bulkSelectStateEventByIDStmt *sql.Stmt - bulkSelectStateEventByNIDStmt *sql.Stmt - bulkSelectStateAtEventByIDStmt *sql.Stmt - updateEventStateStmt *sql.Stmt - selectEventSentToOutputStmt *sql.Stmt - updateEventSentToOutputStmt *sql.Stmt - selectEventIDStmt *sql.Stmt - bulkSelectStateAtEventAndReferenceStmt *sql.Stmt - bulkSelectEventReferenceStmt *sql.Stmt - bulkSelectEventIDStmt *sql.Stmt - bulkSelectEventNIDStmt *sql.Stmt - selectMaxEventDepthStmt *sql.Stmt - selectRoomNIDsForEventNIDsStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventStmt *sql.Stmt + bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateEventByNIDStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt + selectEventSentToOutputStmt *sql.Stmt + bulkSelectEventFilteredBySentToOutputStmt *sql.Stmt + updateEventSentToOutputStmt *sql.Stmt + selectEventIDStmt *sql.Stmt + bulkSelectStateAtEventAndReferenceStmt *sql.Stmt + bulkSelectEventReferenceStmt *sql.Stmt + bulkSelectEventIDStmt *sql.Stmt + bulkSelectEventNIDStmt *sql.Stmt + selectMaxEventDepthStmt *sql.Stmt + selectRoomNIDsForEventNIDsStmt *sql.Stmt } func createEventsTable(db *sql.DB) error { @@ -168,6 +172,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, + {&s.bulkSelectEventFilteredBySentToOutputStmt, bulkSelectEventFilteredBySentToOutputSQL}, {&s.selectEventIDStmt, selectEventIDSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, @@ -342,6 +347,26 @@ func (s *eventStatements) UpdateEventState( return err } +func (s *eventStatements) BulkSelectEventsFilteredBySentToOutput( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, sent bool, +) (results []types.EventNID, err error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventFilteredBySentToOutputStmt) + rows, err := stmt.QueryContext(ctx, pq.Array(eventNIDs), sent) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventFilteredBySentToOutputStmt: rows.close() failed") + results = make([]types.EventNID, 0, len(eventNIDs)) + for i := 0; rows.Next(); i++ { + var eventNID types.EventNID + if err = rows.Scan(&eventNID); err != nil { + return nil, err + } + results = append(results, eventNID) + } + return +} + func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 89b878b9d..d4484382e 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -137,7 +137,7 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent func (u *RoomUpdater) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - return u.d.events(ctx, u.txn, eventNIDs) + return u.d.events(ctx, u.txn, eventNIDs, false) } func (u *RoomUpdater) SnapshotNIDFromEventID( @@ -231,7 +231,11 @@ func (u *RoomUpdater) StateEntriesForEventIDs( } func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs) + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) +} + +func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) } func (u *RoomUpdater) GetMembershipEventNIDsForRoom( diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e96c77afa..aea38e3ab 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -281,10 +281,10 @@ func (d *Database) EventIDs( } func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return d.eventsFromIDs(ctx, nil, eventIDs) + return d.eventsFromIDs(ctx, nil, eventIDs, false) } -func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) { +func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) ([]types.Event, error) { nidMap, err := d.eventNIDs(ctx, txn, eventIDs) if err != nil { return nil, err @@ -295,7 +295,7 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []st nids = append(nids, nid) } - return d.events(ctx, txn, nids) + return d.events(ctx, txn, nids, onlyUnsent) } func (d *Database) LatestEventIDs( @@ -437,12 +437,21 @@ func (d *Database) GetInvitesForUser( func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - return d.events(ctx, nil, eventNIDs) + return d.events(ctx, nil, eventNIDs, false) } func (d *Database) events( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, onlyUnsent bool, ) ([]types.Event, error) { + if onlyUnsent { + // Reduce the list down to event NIDs that haven't already been sent to + // output before, so that we don't send duplicates again. + var err error + eventNIDs, err = d.EventsTable.BulkSelectEventsFilteredBySentToOutput(ctx, txn, eventNIDs, false) + if err != nil { + return nil, err + } + } eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index cef09fe60..b1f9304a2 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -80,6 +80,9 @@ const updateEventStateSQL = "" + const selectEventSentToOutputSQL = "" + "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" +const bulkSelectEventFilteredBySentToOutputSQL = "" + + "SELECT event_nid FROM roomserver_events WHERE sent_to_output = $1 AND event_nid IN ($2)" + const updateEventSentToOutputSQL = "" + "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" @@ -119,7 +122,8 @@ type eventStatements struct { bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt - //selectRoomNIDsForEventNIDsStmt *sql.Stmt + //selectRoomNIDsForEventNIDsStmt *sql.Stmt + //bulkSelectEventFilteredBySentToOutputStmt *sql.Stmt } func createEventsTable(db *sql.DB) error { @@ -146,6 +150,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, + //{&s.bulkSelectEventFilteredBySentToOutputStmt, bulkSelectEventFilteredBySentToOutputSQL}, }.Prepare(db) } @@ -358,6 +363,36 @@ func (s *eventStatements) SelectEventSentToOutput( return } +func (s *eventStatements) BulkSelectEventsFilteredBySentToOutput( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, sent bool, +) (results []types.EventNID, err error) { + params := make([]interface{}, 0, 1+len(eventNIDs)) + params = append(params, sent) + for _, v := range eventNIDs { + params = append(params, v) + } + selectOrig := strings.Replace(bulkSelectEventFilteredBySentToOutputSQL, "($2)", sqlutil.QueryVariadic(len(eventNIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + stmt := sqlutil.TxStmt(txn, selectStmt) + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventFilteredBySentToOutputStmt: rows.close() failed") + results = make([]types.EventNID, 0, len(eventNIDs)) + for i := 0; rows.Next(); i++ { + var eventNID types.EventNID + if err = rows.Scan(&eventNID); err != nil { + return nil, err + } + results = append(results, eventNID) + } + return +} + func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := updateStmt.ExecContext(ctx, int64(eventNID)) diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index fed39b944..9f2673179 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -50,6 +50,7 @@ type Events interface { BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error) UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) + BulkSelectEventsFilteredBySentToOutput(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, sent bool) (results []types.EventNID, err error) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error SelectEventID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error)