mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-09 07:03:10 -06:00
Try optimising checking if server is allowed to see event
This commit is contained in:
parent
962b76da44
commit
2c2b1ef1d6
|
|
@ -2,10 +2,7 @@ package helpers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/auth"
|
"github.com/matrix-org/dendrite/roomserver/auth"
|
||||||
|
|
@ -236,55 +233,61 @@ func LoadStateEvents(
|
||||||
func CheckServerAllowedToSeeEvent(
|
func CheckServerAllowedToSeeEvent(
|
||||||
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
|
ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
roomState := state.NewStateResolution(db, info)
|
/*
|
||||||
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
roomState := state.NewStateResolution(db, info)
|
||||||
if err != nil {
|
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract all of the event state key NIDs from the room state.
|
||||||
|
var stateKeyNIDs []types.EventStateKeyNID
|
||||||
|
for _, entry := range stateEntries {
|
||||||
|
stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then request those state key NIDs from the database.
|
||||||
|
stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("db.EventStateKeys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the event state key doesn't match the given servername
|
||||||
|
// then we'll filter it out. This does preserve state keys that
|
||||||
|
// are "" since these will contain history visibility etc.
|
||||||
|
for nid, key := range stateKeys {
|
||||||
|
if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) {
|
||||||
|
delete(stateKeys, nid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now filter through all of the state events for the room.
|
||||||
|
// If the state key NID appears in the list of valid state
|
||||||
|
// keys then we'll add it to the list of filtered entries.
|
||||||
|
var filteredEntries []types.StateEntry
|
||||||
|
for _, entry := range stateEntries {
|
||||||
|
if _, ok := stateKeys[entry.EventStateKeyNID]; ok {
|
||||||
|
filteredEntries = append(filteredEntries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(filteredEntries) == 0 {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract all of the event state key NIDs from the room state.
|
stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries)
|
||||||
var stateKeyNIDs []types.EventStateKeyNID
|
if err != nil {
|
||||||
for _, entry := range stateEntries {
|
return false, err
|
||||||
stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then request those state key NIDs from the database.
|
|
||||||
stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("db.EventStateKeys: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the event state key doesn't match the given servername
|
|
||||||
// then we'll filter it out. This does preserve state keys that
|
|
||||||
// are "" since these will contain history visibility etc.
|
|
||||||
for nid, key := range stateKeys {
|
|
||||||
if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) {
|
|
||||||
delete(stateKeys, nid)
|
|
||||||
}
|
}
|
||||||
}
|
*/
|
||||||
|
|
||||||
// Now filter through all of the state events for the room.
|
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
|
||||||
// If the state key NID appears in the list of valid state
|
|
||||||
// keys then we'll add it to the list of filtered entries.
|
|
||||||
var filteredEntries []types.StateEntry
|
|
||||||
for _, entry := range stateEntries {
|
|
||||||
if _, ok := stateKeys[entry.EventStateKeyNID]; ok {
|
|
||||||
filteredEntries = append(filteredEntries, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(filteredEntries) == 0 {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
|
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -124,6 +124,29 @@ func (v *StateResolution) LoadStateAtEvent(
|
||||||
return stateEntries, nil
|
return stateEntries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoadStateAtEvent loads the full state of a room before a particular event.
|
||||||
|
func (v *StateResolution) LoadStateAtEventForHistoryVisibility(
|
||||||
|
ctx context.Context, eventID string,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
|
||||||
|
}
|
||||||
|
if snapshotNID == 0 {
|
||||||
|
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return stateEntries, nil
|
||||||
|
}
|
||||||
|
|
||||||
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
|
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
|
||||||
// and combines those snapshots together into a single list. At this point it is
|
// and combines those snapshots together into a single list. At this point it is
|
||||||
// possible to run into duplicate (type, state key) tuples.
|
// possible to run into duplicate (type, state key) tuples.
|
||||||
|
|
|
||||||
|
|
@ -166,4 +166,6 @@ type Database interface {
|
||||||
GetKnownRooms(ctx context.Context) ([]string, error)
|
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||||
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
|
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
|
||||||
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
|
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
|
||||||
|
|
||||||
|
lityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -72,9 +72,32 @@ const bulkSelectStateBlockNIDsSQL = "" +
|
||||||
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
|
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
|
||||||
" WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC"
|
" WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC"
|
||||||
|
|
||||||
|
// Looks up both the history visibility event and relevant membership events from
|
||||||
|
// a given domain name from a given state snapshot. This is used to optimise the
|
||||||
|
// helpers.CheckServerAllowedToSeeEvent function.
|
||||||
|
// TODO: There's a sequence scan here because of the hash join strategy, which is
|
||||||
|
// probably O(n) on state key entries, so there must be a way to avoid that somehow.
|
||||||
|
const bulkSelectStateForHistoryVisibilitySQL = `
|
||||||
|
SELECT event_nid FROM (
|
||||||
|
SELECT event_nid, event_type_nid, event_state_key_nid FROM roomserver_events
|
||||||
|
WHERE (event_type_nid = 5 OR event_type_nid = 7)
|
||||||
|
AND event_nid = ANY(
|
||||||
|
SELECT UNNEST(event_nids) FROM roomserver_state_block
|
||||||
|
WHERE state_block_nid = ANY(
|
||||||
|
SELECT UNNEST(state_block_nids) FROM roomserver_state_snapshots
|
||||||
|
WHERE state_snapshot_nid = $1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
) AS roomserver_events
|
||||||
|
INNER JOIN roomserver_event_state_keys
|
||||||
|
ON roomserver_events.event_state_key_nid = roomserver_event_state_keys.event_state_key_nid
|
||||||
|
AND (event_type_nid = 7 OR event_state_key LIKE '%:$2');
|
||||||
|
`
|
||||||
|
|
||||||
type stateSnapshotStatements struct {
|
type stateSnapshotStatements struct {
|
||||||
insertStateStmt *sql.Stmt
|
insertStateStmt *sql.Stmt
|
||||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||||
|
bulkSelectStateForHistoryVisibilityStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateStateSnapshotTable(db *sql.DB) error {
|
func CreateStateSnapshotTable(db *sql.DB) error {
|
||||||
|
|
@ -88,6 +111,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||||
return s, sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertStateStmt, insertStateSQL},
|
{&s.insertStateStmt, insertStateSQL},
|
||||||
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
||||||
|
{&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -136,3 +160,26 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
|
||||||
|
ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string,
|
||||||
|
) ([]types.EventNID, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateForHistoryVisibilityStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, stateSnapshotNID, domain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close() // nolint: errcheck
|
||||||
|
results := make([]types.EventNID, 0, 16)
|
||||||
|
for i := 0; rows.Next(); i++ {
|
||||||
|
var eventNID types.EventNID
|
||||||
|
if err = rows.Scan(&eventNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
results = append(results, eventNID)
|
||||||
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -988,6 +988,34 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
|
||||||
return &evs[0]
|
return &evs[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) {
|
||||||
|
eventStates, err := d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, []string{eventID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
eventNIDs, err := d.StateSnapshotTable.BulkSelectStateForHistoryVisibility(ctx, nil, eventStates[0].BeforeStateSnapshotNID, domain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
eventIDs = map[types.EventNID]string{}
|
||||||
|
}
|
||||||
|
events := make([]*gomatrixserverlib.Event, 0, len(eventNIDs))
|
||||||
|
for _, eventNID := range eventNIDs {
|
||||||
|
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{eventNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ev, err := gomatrixserverlib.NewEventFromTrustedJSONWithEventID(eventIDs[eventNID], data[0].EventJSON, false, roomInfo.RoomVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
events = append(events, ev)
|
||||||
|
}
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetStateEvent returns the current state event of a given type for a given room with a given state key
|
// GetStateEvent returns the current state event of a given type for a given room with a given state key
|
||||||
// If no event could be found, returns nil
|
// If no event could be found, returns nil
|
||||||
// If there was an issue during the retrieval, returns an error
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
|
|
||||||
|
|
@ -140,3 +140,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||||
}
|
}
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
|
||||||
|
ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string,
|
||||||
|
) ([]types.EventNID, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,7 @@ type Rooms interface {
|
||||||
type StateSnapshot interface {
|
type StateSnapshot interface {
|
||||||
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
|
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
|
||||||
BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||||
|
BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateBlock interface {
|
type StateBlock interface {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue