From c2afa5ca6b77907325e224ff8e9f880c766225b0 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 1 Aug 2022 12:21:33 +0100 Subject: [PATCH] Not supported on SQLite --- roomserver/internal/helpers/helpers.go | 66 ++++++++++++++++++- .../storage/sqlite3/state_snapshot_table.go | 48 ++------------ roomserver/storage/tables/interface.go | 3 + 3 files changed, 72 insertions(+), 45 deletions(-) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 511ac513d..16a6f615d 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -2,7 +2,10 @@ package helpers import ( "context" + "database/sql" + "errors" "fmt" + "strings" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" @@ -234,12 +237,73 @@ func CheckServerAllowedToSeeEvent( ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) - if err != nil { + switch err { + case nil: + // No error, so continue normally + case tables.OptimisationNotSupportedError: + // The database engine didn't support this optimisation, so fall back to using + // the old and slow method + stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName) + if err != nil { + return false, err + } + default: + // Something else went wrong return false, err } return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil } +func slowGetHistoryVisibilityState( + ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, +) ([]*gomatrixserverlib.Event, error) { + roomState := state.NewStateResolution(db, info) + stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("roomState.LoadStateAtEvent: %w", err) + } + + // Extract all of the event state key NIDs from the room state. + var stateKeyNIDs []types.EventStateKeyNID + for _, entry := range stateEntries { + stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID) + } + + // Then request those state key NIDs from the database. + stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs) + if err != nil { + return nil, fmt.Errorf("db.EventStateKeys: %w", err) + } + + // If the event state key doesn't match the given servername + // then we'll filter it out. This does preserve state keys that + // are "" since these will contain history visibility etc. + for nid, key := range stateKeys { + if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { + delete(stateKeys, nid) + } + } + + // Now filter through all of the state events for the room. + // If the state key NID appears in the list of valid state + // keys then we'll add it to the list of filtered entries. + var filteredEntries []types.StateEntry + for _, entry := range stateEntries { + if _, ok := stateKeys[entry.EventStateKeyNID]; ok { + filteredEntries = append(filteredEntries, entry) + } + } + + if len(filteredEntries) == 0 { + return nil, nil + } + + return LoadStateEvents(ctx, db, filteredEntries) +} + // TODO: Remove this when we have tests to assert correctness of this function func ScanEventTree( ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 6cc783bfc..73827522c 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -62,32 +62,10 @@ const bulkSelectStateBlockNIDsSQL = "" + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" -// Looks up both the history visibility event and relevant membership events from -// a given domain name from a given state snapshot. This is used to optimise the -// helpers.CheckServerAllowedToSeeEvent function. -// TODO: There's a sequence scan here because of the hash join strategy, which is -// probably O(n) on state key entries, so there must be a way to avoid that somehow. -const bulkSelectStateForHistoryVisibilitySQL = "" + - "SELECT event_nid FROM (" + - " SELECT event_nid, event_type_nid, event_state_key_nid FROM roomserver_events" + - " WHERE (event_type_nid = 5 OR event_type_nid = 7)" + - " AND event_nid = ANY(" + - " SELECT UNNEST(event_nids) FROM roomserver_state_block" + - " WHERE state_block_nid = ANY(" + - " SELECT UNNEST(state_block_nids) FROM roomserver_state_snapshots" + - " WHERE state_snapshot_nid = $1" + - " )" + - " )" + - ") AS roomserver_events" + - " INNER JOIN roomserver_event_state_keys" + - " ON roomserver_events.event_state_key_nid = roomserver_event_state_keys.event_state_key_nid" + - " AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2);" - type stateSnapshotStatements struct { - db *sql.DB - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt - bulkSelectStateForHistoryVisibilityStmt *sql.Stmt + db *sql.DB + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -103,7 +81,6 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL}, }.Prepare(db) } @@ -167,22 +144,5 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string, ) ([]types.EventNID, error) { - stmt := sqlutil.TxStmt(txn, s.bulkSelectStateForHistoryVisibilityStmt) - rows, err := stmt.QueryContext(ctx, stateSnapshotNID, domain) - if err != nil { - return nil, err - } - defer rows.Close() // nolint: errcheck - results := make([]types.EventNID, 0, 16) - for i := 0; rows.Next(); i++ { - var eventNID types.EventNID - if err = rows.Scan(&eventNID); err != nil { - return nil, err - } - results = append(results, eventNID) - } - if err = rows.Err(); err != nil { - return nil, err - } - return results, nil + return nil, tables.OptimisationNotSupportedError } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index a726af243..cf64ca28a 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -3,12 +3,15 @@ package tables import ( "context" "database/sql" + "errors" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" ) +var OptimisationNotSupportedError = errors.New("optimisation not supported") + type EventJSONPair struct { EventNID types.EventNID EventJSON []byte