diff --git a/src/github.com/matrix-org/dendrite/roomserver/state/state.go b/src/github.com/matrix-org/dendrite/roomserver/state/state.go index e9657fd91..76092af78 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/state/state.go +++ b/src/github.com/matrix-org/dendrite/roomserver/state/state.go @@ -316,6 +316,76 @@ func loadStateAtSnapshotForNumericTuples( return fullState, nil } +// LoadStateAfterEventsForStringTuples loads the state for a list of event type +// and state key pairs after list of events. +// This is used when we only want to load a subset of the room state after a list of events. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func LoadStateAfterEventsForStringTuples( + db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateEntry, error) { + numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples) + if err != nil { + return nil, err + } + return loadStateAfterEventsForNumericTuples(db, prevStates, numericTuples) +} + +func loadStateAfterEventsForNumericTuples( + db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + if len(prevStates) == 1 { + // Fast path for a single event. + prevState := prevStates[0] + result, err := loadStateAtSnapshotForNumericTuples( + db, prevState.BeforeStateSnapshotNID, stateKeyTuples, + ) + if err != nil { + return nil, err + } + if prevState.IsStateEvent() { + for i := range result { + if result[i].StateKeyTuple == prevState.StateKeyTuple { + result[i] = prevState.StateEntry + } + } + } + return result, nil + } + + // Slow path for more that one event. + // Load the entire state so that we can do conflict resolution if we need to. + // TODO: The are some optimistations we could do here: + // 1) We only need to do conflict resolution if the is a conflict in the + // requested tuples so we might try loading just those tuples and then + // checking for conflicts. + // 2) When there is a conflict we still only need to load the state + // needed to do conflict resolution which would save us having to load + // the full state. + + // TODO: Add metrics for this as it could take a long time for big rooms + // with large conflicts. + fullState, _, _, err := calculateStateAfterManyEvents(db, prevStates) + if err != nil { + return nil, err + } + + // Sort the full state so we can use it as a map. + sort.Sort(stateEntrySorter(fullState)) + + // Filter the full state down to the required tuples. + var result []types.StateEntry + for _, tuple := range stateKeyTuples { + eventNID, ok := stateEntryMap(fullState).lookup(tuple) + if ok { + result = append(result, types.StateEntry{tuple, eventNID}) + } + } + sort.Sort(stateEntrySorter(result)) + return result, nil +} + var calculateStateDurations = prometheus.NewSummaryVec( prometheus.SummaryOpts{ Namespace: "dendrite", @@ -491,12 +561,30 @@ const maxStateBlockNIDs = 64 func calculateAndStoreStateAfterManyEvents( db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent, metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { + + state, algorithm, conflictLength, err := calculateStateAfterManyEvents(db, prevStates) + metrics.algorithm = algorithm + if err != nil { + return metrics.stop(0, err) + } + + // TODO: Check if we can encode the new state as a delta against the + // previous state. + metrics.conflictLength = conflictLength + metrics.fullStateLength = len(state) + return metrics.stop(db.AddState(roomNID, nil, state)) +} + +func calculateStateAfterManyEvents( + db RoomStateDatabase, prevStates []types.StateAtEvent, +) (state []types.StateEntry, algorithm string, conflictLength int, err error) { + var combined []types.StateEntry // Conflict resolution. // First stage: load the state after each of the prev events. - combined, err := LoadCombinedStateAfterEvents(db, prevStates) + combined, err = LoadCombinedStateAfterEvents(db, prevStates) if err != nil { - metrics.algorithm = "_load_combined_state" - return metrics.stop(0, err) + algorithm = "_load_combined_state" + return } // Collect all the entries with the same type and key together. @@ -508,9 +596,8 @@ func calculateAndStoreStateAfterManyEvents( // Find the conflicts conflicts := findDuplicateStateKeys(combined) - var state []types.StateEntry if len(conflicts) > 0 { - metrics.conflictLength = len(conflicts) + conflictLength = len(conflicts) // 5) There are conflicting state events, for each conflict workout // what the appropriate state event is. @@ -523,23 +610,20 @@ func calculateAndStoreStateAfterManyEvents( } } - resolved, err := resolveConflicts(db, notConflicted, conflicts) + var resolved []types.StateEntry + resolved, err = resolveConflicts(db, notConflicted, conflicts) if err != nil { - metrics.algorithm = "_resolve_conflicts" - return metrics.stop(0, err) + algorithm = "_resolve_conflicts" + return } - metrics.algorithm = "full_state_with_conflicts" + algorithm = "full_state_with_conflicts" state = resolved } else { - metrics.algorithm = "full_state_no_conflicts" + algorithm = "full_state_no_conflicts" // 6) There weren't any conflicts state = combined } - metrics.fullStateLength = len(state) - - // TODO: Check if we can encode the new state as a delta against the - // previous state. - return metrics.stop(db.AddState(roomNID, nil, state)) + return } // resolveConflicts resolves a list of conflicted state entries. It takes two lists.