From 6ccfee63af057a4ee7fe6c2420a127a60a3fa704 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 5 Oct 2020 16:46:28 +0100 Subject: [PATCH] Resolve state after event against current room state when determining latest state changes --- .../internal/input/input_latest_events.go | 23 ++++++++++++++++--- roomserver/state/state.go | 8 ++++--- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 5c2a1de6a..2e9f3b4e4 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -215,10 +215,27 @@ func (u *latestEventsUpdater) latestState() error { var err error roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) - // Get a list of the current latest events. - latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) + // Get a list of the current room state events if available. + var currentState []types.StateEntry + if u.roomInfo.StateSnapshotNID != 0 { + currentState, _ = roomState.LoadStateAtSnapshot(u.ctx, u.roomInfo.StateSnapshotNID) + } + + // Get a list of the current latest events. This will include both + // the current room state and the latest events after the input event. + // The idea is that we will perform state resolution on this set and + // any conflicting events will be resolved properly. + latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)+len(currentState)) + offset := 0 + for i := range currentState { + latestStateAtEvents[i] = types.StateAtEvent{ + BeforeStateSnapshotNID: u.roomInfo.StateSnapshotNID, + StateEntry: currentState[i], + } + offset++ + } for i := range u.latest { - latestStateAtEvents[i] = u.latest[i].StateAtEvent + latestStateAtEvents[offset+i] = u.latest[i].StateAtEvent } // Takes the NIDs of the latest events and creates a state snapshot diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 0663499e7..2944f71c1 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -118,7 +118,7 @@ func (v StateResolution) LoadCombinedStateAfterEvents( // the snapshot of the room state before them was the same. stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs)) if err != nil { - return nil, err + return nil, fmt.Errorf("v.db.StateBlockNIDs: %w", err) } var stateBlockNIDs []types.StateBlockNID @@ -131,7 +131,7 @@ func (v StateResolution) LoadCombinedStateAfterEvents( // multiple snapshots. stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) if err != nil { - return nil, err + return nil, fmt.Errorf("v.db.StateEntries: %w", err) } stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) stateEntriesMap := stateEntryListMap(stateEntryLists) @@ -623,7 +623,7 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents( v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) metrics.algorithm = algorithm if err != nil { - return metrics.stop(0, err) + return metrics.stop(0, fmt.Errorf("v.calculateStateAfterManyEvents: %w", err)) } // TODO: Check if we can encode the new state as a delta against the @@ -642,6 +642,7 @@ func (v StateResolution) calculateStateAfterManyEvents( // First stage: load the state after each of the prev events. combined, err = v.LoadCombinedStateAfterEvents(ctx, prevStates) if err != nil { + err = fmt.Errorf("v.LoadCombinedStateAfterEvents: %w", err) algorithm = "_load_combined_state" return } @@ -672,6 +673,7 @@ func (v StateResolution) calculateStateAfterManyEvents( var resolved []types.StateEntry resolved, err = v.resolveConflicts(ctx, roomVersion, notConflicted, conflicts) if err != nil { + err = fmt.Errorf("v.resolveConflits: %w", err) algorithm = "_resolve_conflicts" return }