From 2c2b1ef1d62f490fffe0337cb1450df10d780247 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 26 Jul 2022 13:43:38 +0100 Subject: [PATCH] Try optimising checking if server is allowed to see event --- roomserver/internal/helpers/helpers.go | 91 ++++++++++--------- roomserver/state/state.go | 23 +++++ roomserver/storage/interface.go | 2 + .../storage/postgres/state_snapshot_table.go | 51 ++++++++++- roomserver/storage/shared/storage.go | 28 ++++++ .../storage/sqlite3/state_snapshot_table.go | 6 ++ roomserver/storage/tables/interface.go | 1 + 7 files changed, 156 insertions(+), 46 deletions(-) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 2653027e9..4a10376bd 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -2,10 +2,7 @@ package helpers import ( "context" - "database/sql" - "errors" "fmt" - "strings" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" @@ -236,55 +233,61 @@ func LoadStateEvents( func CheckServerAllowedToSeeEvent( ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { - roomState := state.NewStateResolution(db, info) - stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { + /* + roomState := state.NewStateResolution(db, info) + stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err) + } + + // Extract all of the event state key NIDs from the room state. + var stateKeyNIDs []types.EventStateKeyNID + for _, entry := range stateEntries { + stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID) + } + + // Then request those state key NIDs from the database. + stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs) + if err != nil { + return false, fmt.Errorf("db.EventStateKeys: %w", err) + } + + // If the event state key doesn't match the given servername + // then we'll filter it out. This does preserve state keys that + // are "" since these will contain history visibility etc. + for nid, key := range stateKeys { + if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { + delete(stateKeys, nid) + } + } + + // Now filter through all of the state events for the room. + // If the state key NID appears in the list of valid state + // keys then we'll add it to the list of filtered entries. + var filteredEntries []types.StateEntry + for _, entry := range stateEntries { + if _, ok := stateKeys[entry.EventStateKeyNID]; ok { + filteredEntries = append(filteredEntries, entry) + } + } + + if len(filteredEntries) == 0 { return false, nil } - return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err) - } - // Extract all of the event state key NIDs from the room state. - var stateKeyNIDs []types.EventStateKeyNID - for _, entry := range stateEntries { - stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID) - } - - // Then request those state key NIDs from the database. - stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs) - if err != nil { - return false, fmt.Errorf("db.EventStateKeys: %w", err) - } - - // If the event state key doesn't match the given servername - // then we'll filter it out. This does preserve state keys that - // are "" since these will contain history visibility etc. - for nid, key := range stateKeys { - if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { - delete(stateKeys, nid) + stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries) + if err != nil { + return false, err } - } + */ - // Now filter through all of the state events for the room. - // If the state key NID appears in the list of valid state - // keys then we'll add it to the list of filtered entries. - var filteredEntries []types.StateEntry - for _, entry := range stateEntries { - if _, ok := stateKeys[entry.EventStateKeyNID]; ok { - filteredEntries = append(filteredEntries, entry) - } - } - - if len(filteredEntries) == 0 { - return false, nil - } - - stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries) + stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) if err != nil { return false, err } - return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index d1d24b099..ca0c69f27 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -124,6 +124,29 @@ func (v *StateResolution) LoadStateAtEvent( return stateEntries, nil } +// LoadStateAtEvent loads the full state of a room before a particular event. +func (v *StateResolution) LoadStateAtEventForHistoryVisibility( + ctx context.Context, eventID string, +) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent") + defer span.Finish() + + snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) + if err != nil { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err) + } + if snapshotNID == 0 { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) + } + + stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID) + if err != nil { + return nil, err + } + + return stateEntries, nil +} + // LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events // and combines those snapshots together into a single list. At this point it is // possible to run into duplicate (type, state key) tuples. diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index a98fda073..374776b49 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -166,4 +166,6 @@ type Database interface { GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error + + lityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) } diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index a24b7f3f0..f865ef0f6 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -72,9 +72,32 @@ const bulkSelectStateBlockNIDsSQL = "" + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + " WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC" +// Looks up both the history visibility event and relevant membership events from +// a given domain name from a given state snapshot. This is used to optimise the +// helpers.CheckServerAllowedToSeeEvent function. +// TODO: There's a sequence scan here because of the hash join strategy, which is +// probably O(n) on state key entries, so there must be a way to avoid that somehow. +const bulkSelectStateForHistoryVisibilitySQL = ` +SELECT event_nid FROM ( + SELECT event_nid, event_type_nid, event_state_key_nid FROM roomserver_events + WHERE (event_type_nid = 5 OR event_type_nid = 7) + AND event_nid = ANY( + SELECT UNNEST(event_nids) FROM roomserver_state_block + WHERE state_block_nid = ANY( + SELECT UNNEST(state_block_nids) FROM roomserver_state_snapshots + WHERE state_snapshot_nid = $1 + ) + ) +) AS roomserver_events +INNER JOIN roomserver_event_state_keys + ON roomserver_events.event_state_key_nid = roomserver_event_state_keys.event_state_key_nid + AND (event_type_nid = 7 OR event_state_key LIKE '%:$2'); +` + type stateSnapshotStatements struct { - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt + bulkSelectStateForHistoryVisibilityStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -88,6 +111,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL}, }.Prepare(db) } @@ -136,3 +160,26 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( } return results, nil } + +func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( + ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string, +) ([]types.EventNID, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateForHistoryVisibilityStmt) + rows, err := stmt.QueryContext(ctx, stateSnapshotNID, domain) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.EventNID, 0, 16) + for i := 0; rows.Next(); i++ { + var eventNID types.EventNID + if err = rows.Scan(&eventNID); err != nil { + return nil, err + } + results = append(results, eventNID) + } + if err = rows.Err(); err != nil { + return nil, err + } + return results, nil +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 2f9932ff8..e3d09f3a8 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -988,6 +988,34 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { return &evs[0] } +func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) { + eventStates, err := d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, []string{eventID}) + if err != nil { + return nil, err + } + eventNIDs, err := d.StateSnapshotTable.BulkSelectStateForHistoryVisibility(ctx, nil, eventStates[0].BeforeStateSnapshotNID, domain) + if err != nil { + return nil, err + } + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) + if err != nil { + eventIDs = map[types.EventNID]string{} + } + events := make([]*gomatrixserverlib.Event, 0, len(eventNIDs)) + for _, eventNID := range eventNIDs { + data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{eventNID}) + if err != nil { + return nil, err + } + ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[eventNID], data[0].EventJSON, false, roomInfo.RoomVersion) + if err != nil { + return nil, err + } + events = append(events, ev) + } + return events, nil +} + // GetStateEvent returns the current state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index b8136b758..d2b2875ae 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -140,3 +140,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( } return results, nil } + +func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( + ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string, +) ([]types.EventNID, error) { + return nil, nil +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index d228257b7..a726af243 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -80,6 +80,7 @@ type Rooms interface { type StateSnapshot interface { InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error) } type StateBlock interface {