diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 527b714a6..26ebcf856 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -77,22 +77,21 @@ const bulkSelectStateBlockNIDsSQL = "" + // helpers.CheckServerAllowedToSeeEvent function. // TODO: There's a sequence scan here because of the hash join strategy, which is // probably O(n) on state key entries, so there must be a way to avoid that somehow. -const bulkSelectStateForHistoryVisibilitySQL = ` -SELECT event_nid FROM ( - SELECT event_nid, event_type_nid, event_state_key_nid FROM roomserver_events - WHERE (event_type_nid = 5 OR event_type_nid = 7) - AND event_nid = ANY( - SELECT UNNEST(event_nids) FROM roomserver_state_block - WHERE state_block_nid = ANY( - SELECT UNNEST(state_block_nids) FROM roomserver_state_snapshots - WHERE state_snapshot_nid = $1 - ) - ) -) AS roomserver_events -INNER JOIN roomserver_event_state_keys - ON roomserver_events.event_state_key_nid = roomserver_event_state_keys.event_state_key_nid - AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2); -` +const bulkSelectStateForHistoryVisibilitySQL = "" + + "SELECT event_nid FROM (" + + " SELECT event_nid, event_type_nid, event_state_key_nid FROM roomserver_events" + + " WHERE (event_type_nid = 5 OR event_type_nid = 7)" + + " AND event_nid = ANY(" + + " SELECT UNNEST(event_nids) FROM roomserver_state_block" + + " WHERE state_block_nid = ANY(" + + " SELECT UNNEST(state_block_nids) FROM roomserver_state_snapshots" + + " WHERE state_snapshot_nid = $1" + + " )" + + " )" + + ") AS roomserver_events" + + " INNER JOIN roomserver_event_state_keys" + + " ON roomserver_events.event_state_key_nid = roomserver_event_state_keys.event_state_key_nid" + + " AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2);" type stateSnapshotStatements struct { insertStateStmt *sql.Stmt diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index d2b2875ae..6cc783bfc 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -62,10 +62,32 @@ const bulkSelectStateBlockNIDsSQL = "" + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" +// Looks up both the history visibility event and relevant membership events from +// a given domain name from a given state snapshot. This is used to optimise the +// helpers.CheckServerAllowedToSeeEvent function. +// TODO: There's a sequence scan here because of the hash join strategy, which is +// probably O(n) on state key entries, so there must be a way to avoid that somehow. +const bulkSelectStateForHistoryVisibilitySQL = "" + + "SELECT event_nid FROM (" + + " SELECT event_nid, event_type_nid, event_state_key_nid FROM roomserver_events" + + " WHERE (event_type_nid = 5 OR event_type_nid = 7)" + + " AND event_nid = ANY(" + + " SELECT UNNEST(event_nids) FROM roomserver_state_block" + + " WHERE state_block_nid = ANY(" + + " SELECT UNNEST(state_block_nids) FROM roomserver_state_snapshots" + + " WHERE state_snapshot_nid = $1" + + " )" + + " )" + + ") AS roomserver_events" + + " INNER JOIN roomserver_event_state_keys" + + " ON roomserver_events.event_state_key_nid = roomserver_event_state_keys.event_state_key_nid" + + " AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2);" + type stateSnapshotStatements struct { - db *sql.DB - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt + db *sql.DB + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt + bulkSelectStateForHistoryVisibilityStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -81,6 +103,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL}, }.Prepare(db) } @@ -144,5 +167,22 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string, ) ([]types.EventNID, error) { - return nil, nil + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateForHistoryVisibilityStmt) + rows, err := stmt.QueryContext(ctx, stateSnapshotNID, domain) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.EventNID, 0, 16) + for i := 0; rows.Next(); i++ { + var eventNID types.EventNID + if err = rows.Scan(&eventNID); err != nil { + return nil, err + } + results = append(results, eventNID) + } + if err = rows.Err(); err != nil { + return nil, err + } + return results, nil }