Optimise checking other servers allowed to see events (#2596)

* Try optimising checking if server is allowed to see event

* Fix error

* Handle case where snapshot NID is 0

* Fix query

* Update SQL

* Clean up `CheckServerAllowedToSeeEvent`

* Not supported on SQLite

* Maybe placate the unit tests

* Review comments
This commit is contained in:
Neil Alexander 2022-08-01 14:11:00 +01:00 committed by GitHub
parent c7f7aec4d0
commit 05c83923e3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 154 additions and 12 deletions

View file

@ -236,13 +236,34 @@ func LoadStateEvents(
// CheckServerAllowedToSeeEvent returns the result of auth.IsServerAllowed for
// serverName, evaluated against the history visibility state at the event with
// the given ID.
//
// The state is first requested from db.GetHistoryVisibilityState. If that call
// reports tables.OptimisationNotSupportedError, the state is loaded with
// slowGetHistoryVisibilityState instead. Any other error, or an error from the
// fallback, is returned together with a false result.
func CheckServerAllowedToSeeEvent(
	ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) {
	stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
	switch {
	case err == nil:
		// No error, so continue normally
	case errors.Is(err, tables.OptimisationNotSupportedError):
		// The database engine didn't support this optimisation, so fall back to using
		// the old and slow method. errors.Is is used rather than == so that the
		// fallback still happens if an intermediate layer wraps the sentinel.
		stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName)
		if err != nil {
			return false, err
		}
	default:
		// Something else went wrong
		return false, err
	}
	return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
}
func slowGetHistoryVisibilityState(
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName,
) ([]*gomatrixserverlib.Event, error) {
roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false, nil
return nil, nil
}
return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err)
return nil, fmt.Errorf("roomState.LoadStateAtEvent: %w", err)
}
// Extract all of the event state key NIDs from the room state.
@ -254,7 +275,7 @@ func CheckServerAllowedToSeeEvent(
// Then request those state key NIDs from the database.
stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs)
if err != nil {
return false, fmt.Errorf("db.EventStateKeys: %w", err)
return nil, fmt.Errorf("db.EventStateKeys: %w", err)
}
// If the event state key doesn't match the given servername
@ -277,15 +298,10 @@ func CheckServerAllowedToSeeEvent(
}
if len(filteredEntries) == 0 {
return false, nil
return nil, nil
}
stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries)
if err != nil {
return false, err
}
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
return LoadStateEvents(ctx, db, filteredEntries)
}
// TODO: Remove this when we have tests to assert correctness of this function

View file

@ -124,6 +124,29 @@ func (v *StateResolution) LoadStateAtEvent(
return stateEntries, nil
}
// LoadStateAtEventForHistoryVisibility loads the state entries of the state
// snapshot that the database associates with the given event ID.
//
// It returns an error if the snapshot NID lookup fails, if the lookup yields a
// zero snapshot NID, or if loading the snapshot itself fails.
func (v *StateResolution) LoadStateAtEventForHistoryVisibility(
	ctx context.Context, eventID string,
) ([]types.StateEntry, error) {
	// The span and error prefixes carry this function's own name so that traces
	// and logs can be told apart from those of LoadStateAtEvent.
	span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEventForHistoryVisibility")
	defer span.Finish()

	snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
	if err != nil {
		return nil, fmt.Errorf("LoadStateAtEventForHistoryVisibility.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
	}
	if snapshotNID == 0 {
		return nil, fmt.Errorf("LoadStateAtEventForHistoryVisibility.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
	}

	stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
	if err != nil {
		return nil, err
	}
	return stateEntries, nil
}
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
// and combines those snapshots together into a single list. At this point it is
// possible to run into duplicate (type, state key) tuples.

View file

@ -166,4 +166,6 @@ type Database interface {
GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
}

View file

@ -72,9 +72,35 @@ const bulkSelectStateBlockNIDsSQL = "" +
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
" WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC"
// bulkSelectStateForHistoryVisibilitySQL looks up, within a given state
// snapshot ($1), the NIDs of both the history visibility event and the
// membership events whose state key ends in ":" followed by the given domain
// name ($2). This is used to optimise the helpers.CheckServerAllowedToSeeEvent
// function.
// TODO: There's a sequence scan here because of the hash join strategy, which is
// probably O(n) on state key entries, so there must be a way to avoid that somehow.
// NOTE(review): $2 is concatenated into a LIKE pattern, so a "%" or "_" in the
// domain would act as a wildcard. Presumably server names are validated before
// they reach this query; confirm against the callers.
// Event type NIDs are:
// - 5: m.room.member as per https://github.com/matrix-org/dendrite/blob/c7f7aec4d07d59120d37d5b16a900f6d608a75c4/roomserver/storage/postgres/event_types_table.go#L40
// - 7: m.room.history_visibility as per https://github.com/matrix-org/dendrite/blob/c7f7aec4d07d59120d37d5b16a900f6d608a75c4/roomserver/storage/postgres/event_types_table.go#L42
const bulkSelectStateForHistoryVisibilitySQL = `
SELECT event_nid FROM (
SELECT event_nid, event_type_nid, event_state_key_nid FROM roomserver_events
WHERE (event_type_nid = 5 OR event_type_nid = 7)
AND event_nid = ANY(
SELECT UNNEST(event_nids) FROM roomserver_state_block
WHERE state_block_nid = ANY(
SELECT UNNEST(state_block_nids) FROM roomserver_state_snapshots
WHERE state_snapshot_nid = $1
)
)
) AS roomserver_events
INNER JOIN roomserver_event_state_keys
ON roomserver_events.event_state_key_nid = roomserver_event_state_keys.event_state_key_nid
AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2);
`
// stateSnapshotStatements holds the prepared statements used by the state
// snapshot table. Each field is declared exactly once; duplicate field names
// in a struct do not compile.
type stateSnapshotStatements struct {
	// insertStateStmt is prepared from insertStateSQL.
	insertStateStmt *sql.Stmt
	// bulkSelectStateBlockNIDsStmt is prepared from bulkSelectStateBlockNIDsSQL.
	bulkSelectStateBlockNIDsStmt *sql.Stmt
	// bulkSelectStateForHistoryVisibilityStmt is prepared from
	// bulkSelectStateForHistoryVisibilitySQL.
	bulkSelectStateForHistoryVisibilityStmt *sql.Stmt
}
func CreateStateSnapshotTable(db *sql.DB) error {
@ -88,6 +114,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
return s, sqlutil.StatementList{
{&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
{&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL},
}.Prepare(db)
}
@ -136,3 +163,23 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
}
return results, nil
}
// BulkSelectStateForHistoryVisibility runs the prepared history visibility
// query for the given state snapshot NID and domain and collects the event
// NIDs it yields. The statement is obtained through sqlutil.TxStmt with txn.
//
// A query or scan failure returns a nil slice and the error; otherwise the
// collected NIDs are returned together with rows.Err().
func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
	ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string,
) ([]types.EventNID, error) {
	rows, err := sqlutil.TxStmt(txn, s.bulkSelectStateForHistoryVisibilityStmt).
		QueryContext(ctx, stateSnapshotNID, domain)
	if err != nil {
		return nil, err
	}
	defer rows.Close() // nolint: errcheck

	eventNIDs := make([]types.EventNID, 0, 16)
	for rows.Next() {
		var nid types.EventNID
		if scanErr := rows.Scan(&nid); scanErr != nil {
			return nil, scanErr
		}
		eventNIDs = append(eventNIDs, nid)
	}
	return eventNIDs, rows.Err()
}

View file

@ -988,6 +988,38 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
return &evs[0]
}
// GetHistoryVisibilityState returns the events selected by
// StateSnapshotTable.BulkSelectStateForHistoryVisibility for the given domain,
// using the state snapshot recorded before the event with the given ID.
//
// It returns (nil, nil) when no state entry or no snapshot (NID 0) is recorded
// for the event. Errors from the state lookup, the snapshot query, the event
// JSON lookup and event parsing are returned unchanged, so sentinel errors from
// the table implementation remain detectable by the caller. A failure to
// resolve event IDs is tolerated: the events are then built with empty IDs.
func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) {
	eventStates, err := d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, []string{eventID})
	if err != nil {
		return nil, err
	}
	// Guard the index below: an empty result is treated like a missing snapshot
	// rather than causing an out-of-range panic.
	if len(eventStates) == 0 {
		return nil, nil
	}
	stateSnapshotNID := eventStates[0].BeforeStateSnapshotNID
	if stateSnapshotNID == 0 {
		return nil, nil
	}
	eventNIDs, err := d.StateSnapshotTable.BulkSelectStateForHistoryVisibility(ctx, nil, stateSnapshotNID, domain)
	if err != nil {
		return nil, err
	}
	// Capture the error so that the best-effort fallback to an empty map is
	// actually taken when the ID lookup fails.
	eventIDs, err := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
	if err != nil {
		eventIDs = map[types.EventNID]string{}
	}
	events := make([]*gomatrixserverlib.Event, 0, len(eventNIDs))
	for _, eventNID := range eventNIDs {
		// NOTE(review): one query per event; a single bulk call for all NIDs is
		// probably possible, confirm the ordering and missing-row behaviour of
		// BulkSelectEventJSON first.
		data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{eventNID})
		if err != nil {
			return nil, err
		}
		// NOTE(review): assumes at least one row comes back for a NID that the
		// snapshot query just returned; confirm.
		ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[eventNID], data[0].EventJSON, false, roomInfo.RoomVersion)
		if err != nil {
			return nil, err
		}
		events = append(events, ev)
	}
	return events, nil
}
// GetStateEvent returns the current state event of a given type for a given room with a given state key
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error

View file

@ -140,3 +140,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
}
return results, nil
}
// BulkSelectStateForHistoryVisibility is not implemented by this storage
// backend. It ignores its arguments and always returns a nil slice together
// with tables.OptimisationNotSupportedError.
func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
	ctx context.Context,
	txn *sql.Tx,
	stateSnapshotNID types.StateSnapshotNID,
	domain string,
) ([]types.EventNID, error) {
	var noEventNIDs []types.EventNID
	return noEventNIDs, tables.OptimisationNotSupportedError
}

View file

@ -3,12 +3,15 @@ package tables
import (
"context"
"database/sql"
"errors"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/tidwall/gjson"
)
// OptimisationNotSupportedError is a sentinel error returned by table
// implementations that do not provide an optional, engine-specific
// optimisation, so that callers can detect it and use another code path.
var OptimisationNotSupportedError = errors.New("optimisation not supported")
type EventJSONPair struct {
EventNID types.EventNID
EventJSON []byte
@ -80,6 +83,10 @@ type Rooms interface {
// StateSnapshot is the set of state snapshot table operations that a storage
// backend provides.
type StateSnapshot interface {
// InsertState takes a room NID and a list of state block NIDs and returns a
// state snapshot NID.
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
// BulkSelectStateBlockNIDs takes a list of state snapshot NIDs and returns a
// list of types.StateBlockNIDList values.
BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
// BulkSelectStateForHistoryVisibility is a PostgreSQL-only optimisation that
// returns event NIDs for the given state snapshot and domain, without having
// to load the entire room state. In the case of SQLite, this will return
// OptimisationNotSupportedError.
BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error)
}
type StateBlock interface {

View file

@ -23,6 +23,15 @@ func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.
assert.NoError(t, err)
switch dbType {
case test.DBTypePostgres:
// for the PostgreSQL history visibility optimisation to work,
// we also need some other tables to exist
err = postgres.CreateEventStateKeysTable(db)
assert.NoError(t, err)
err = postgres.CreateEventsTable(db)
assert.NoError(t, err)
err = postgres.CreateStateBlockTable(db)
assert.NoError(t, err)
// ... and then the snapshot table itself
err = postgres.CreateStateSnapshotTable(db)
assert.NoError(t, err)
tab, err = postgres.PrepareStateSnapshotTable(db)