From e5d407d948ba23f473333d738a7903394a233edf Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Feb 2022 12:46:35 +0000 Subject: [PATCH] Filter on event NIDs instead, hopefully bring joy to SQLite --- roomserver/storage/postgres/events_table.go | 69 ++++++++----------- .../storage/shared/membership_updater.go | 6 +- roomserver/storage/shared/room_updater.go | 10 ++- roomserver/storage/shared/storage.go | 23 ++----- roomserver/storage/sqlite3/events_table.go | 55 ++++----------- roomserver/storage/tables/interface.go | 3 +- 6 files changed, 61 insertions(+), 105 deletions(-) diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index b3c23a3b7..662dbbf00 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -108,9 +108,6 @@ const updateEventStateSQL = "" + const selectEventSentToOutputSQL = "" + "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" -const bulkSelectEventFilteredBySentToOutputSQL = "" + - "SELECT event_nid FROM roomserver_events WHERE event_nid = ANY($1) AND sent_to_output = $2" - const updateEventSentToOutputSQL = "" + "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" @@ -130,6 +127,9 @@ const bulkSelectEventIDSQL = "" + const bulkSelectEventNIDSQL = "" + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)" +const bulkSelectUnsentEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" @@ -137,22 +137,22 @@ const selectRoomNIDsForEventNIDsSQL = "" + "SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid = ANY($1)" type eventStatements struct { - insertEventStmt *sql.Stmt - selectEventStmt *sql.Stmt - bulkSelectStateEventByIDStmt *sql.Stmt - bulkSelectStateEventByNIDStmt *sql.Stmt - bulkSelectStateAtEventByIDStmt *sql.Stmt - updateEventStateStmt *sql.Stmt - selectEventSentToOutputStmt *sql.Stmt - bulkSelectEventFilteredBySentToOutputStmt *sql.Stmt - updateEventSentToOutputStmt *sql.Stmt - selectEventIDStmt *sql.Stmt - bulkSelectStateAtEventAndReferenceStmt *sql.Stmt - bulkSelectEventReferenceStmt *sql.Stmt - bulkSelectEventIDStmt *sql.Stmt - bulkSelectEventNIDStmt *sql.Stmt - selectMaxEventDepthStmt *sql.Stmt - selectRoomNIDsForEventNIDsStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventStmt *sql.Stmt + bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateEventByNIDStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt + selectEventSentToOutputStmt *sql.Stmt + updateEventSentToOutputStmt *sql.Stmt + selectEventIDStmt *sql.Stmt + bulkSelectStateAtEventAndReferenceStmt *sql.Stmt + bulkSelectEventReferenceStmt *sql.Stmt + bulkSelectEventIDStmt *sql.Stmt + bulkSelectEventNIDStmt *sql.Stmt + bulkSelectUnsentEventNIDStmt *sql.Stmt + selectMaxEventDepthStmt *sql.Stmt + selectRoomNIDsForEventNIDsStmt *sql.Stmt } func createEventsTable(db *sql.DB) error { @@ -172,12 +172,12 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, - {&s.bulkSelectEventFilteredBySentToOutputStmt, bulkSelectEventFilteredBySentToOutputSQL}, {&s.selectEventIDStmt, selectEventIDSQL}, {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + {&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL}, }.Prepare(db) @@ -347,26 +347,6 @@ func (s *eventStatements) UpdateEventState( return err } -func (s *eventStatements) BulkSelectEventsFilteredBySentToOutput( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, sent bool, -) (results []types.EventNID, err error) { - stmt := sqlutil.TxStmt(txn, s.bulkSelectEventFilteredBySentToOutputStmt) - rows, err := stmt.QueryContext(ctx, pq.Array(eventNIDs), sent) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventFilteredBySentToOutputStmt: rows.close() failed") - results = make([]types.EventNID, 0, len(eventNIDs)) - for i := 0; rows.Next(); i++ { - var eventNID types.EventNID - if err = rows.Scan(&eventNID); err != nil { - return nil, err - } - results = append(results, eventNID) - } - return -} - func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { @@ -485,8 +465,13 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { - stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { + var stmt *sql.Stmt + if onlyUnsent { + stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt) + } else { + stmt = sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) + } rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 66ac2f5b6..8f3f3d631 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } // Look up the NID of the new join event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } // Look up the NID of the new leave event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er } if u.membership != tables.MembershipStateKnock { // Look up the NID of the new knock event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index d4484382e..b148f8acd 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -137,7 +137,7 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent func (u *RoomUpdater) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - return u.d.events(ctx, u.txn, eventNIDs, false) + return u.d.events(ctx, u.txn, eventNIDs) } func (u *RoomUpdater) SnapshotNIDFromEventID( @@ -215,7 +215,13 @@ func (u *RoomUpdater) EventIDs( func (u *RoomUpdater) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return u.d.eventNIDs(ctx, u.txn, eventIDs) + return u.d.eventNIDs(ctx, u.txn, eventIDs, false) +} + +func (u *RoomUpdater) UnsentEventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return u.d.eventNIDs(ctx, u.txn, eventIDs, true) } func (u *RoomUpdater) StateAtEventIDs( diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index aea38e3ab..256aa47f0 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -238,13 +238,13 @@ func (d *Database) addState( func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return d.eventNIDs(ctx, nil, eventIDs) + return d.eventNIDs(ctx, nil, eventIDs, false) } func (d *Database) eventNIDs( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool, ) (map[string]types.EventNID, error) { - return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) + return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs, onlyUnsent) } func (d *Database) SetState( @@ -285,7 +285,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type } func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) ([]types.Event, error) { - nidMap, err := d.eventNIDs(ctx, txn, eventIDs) + nidMap, err := d.eventNIDs(ctx, txn, eventIDs, onlyUnsent) if err != nil { return nil, err } @@ -295,7 +295,7 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []st nids = append(nids, nid) } - return d.events(ctx, txn, nids, onlyUnsent) + return d.events(ctx, txn, nids) } func (d *Database) LatestEventIDs( @@ -437,21 +437,12 @@ func (d *Database) GetInvitesForUser( func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - return d.events(ctx, nil, eventNIDs, false) + return d.events(ctx, nil, eventNIDs) } func (d *Database) events( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, onlyUnsent bool, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.Event, error) { - if onlyUnsent { - // Reduce the list down to event NIDs that haven't already been sent to - // output before, so that we don't send duplicates again. - var err error - eventNIDs, err = d.EventsTable.BulkSelectEventsFilteredBySentToOutput(ctx, txn, eventNIDs, false) - if err != nil { - return nil, err - } - } eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index b1f9304a2..e49b23739 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -80,9 +80,6 @@ const updateEventStateSQL = "" + const selectEventSentToOutputSQL = "" + "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" -const bulkSelectEventFilteredBySentToOutputSQL = "" + - "SELECT event_nid FROM roomserver_events WHERE sent_to_output = $1 AND event_nid IN ($2)" - const updateEventSentToOutputSQL = "" + "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" @@ -102,6 +99,9 @@ const bulkSelectEventIDSQL = "" + const bulkSelectEventNIDSQL = "" + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" +const bulkSelectUnsentEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" @@ -121,9 +121,9 @@ type eventStatements struct { bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt - bulkSelectEventNIDStmt *sql.Stmt - //selectRoomNIDsForEventNIDsStmt *sql.Stmt - //bulkSelectEventFilteredBySentToOutputStmt *sql.Stmt + //bulkSelectEventNIDStmt *sql.Stmt + //bulkSelectUnsentEventNIDStmt *sql.Stmt + //selectRoomNIDsForEventNIDsStmt *sql.Stmt } func createEventsTable(db *sql.DB) error { @@ -148,9 +148,9 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, - {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + //{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + //{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, - //{&s.bulkSelectEventFilteredBySentToOutputStmt, bulkSelectEventFilteredBySentToOutputSQL}, }.Prepare(db) } @@ -363,36 +363,6 @@ func (s *eventStatements) SelectEventSentToOutput( return } -func (s *eventStatements) BulkSelectEventsFilteredBySentToOutput( - ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, sent bool, -) (results []types.EventNID, err error) { - params := make([]interface{}, 0, 1+len(eventNIDs)) - params = append(params, sent) - for _, v := range eventNIDs { - params = append(params, v) - } - selectOrig := strings.Replace(bulkSelectEventFilteredBySentToOutputSQL, "($2)", sqlutil.QueryVariadic(len(eventNIDs)), 1) - selectStmt, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - stmt := sqlutil.TxStmt(txn, selectStmt) - rows, err := stmt.QueryContext(ctx, params...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventFilteredBySentToOutputStmt: rows.close() failed") - results = make([]types.EventNID, 0, len(eventNIDs)) - for i := 0; rows.Next(); i++ { - var eventNID types.EventNID - if err = rows.Scan(&eventNID); err != nil { - return nil, err - } - results = append(results, eventNID) - } - return -} - func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := updateStmt.ExecContext(ctx, int64(eventNID)) @@ -531,13 +501,18 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + var selectOrig string + if onlyUnsent { + selectOrig = strings.Replace(bulkSelectUnsentEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + } else { + selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + } selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 9f2673179..24fef2036 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -50,7 +50,6 @@ type Events interface { BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error) UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) - BulkSelectEventsFilteredBySentToOutput(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, sent bool) (results []types.EventNID, err error) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error SelectEventID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) @@ -59,7 +58,7 @@ type Events interface { BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. - BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) }