diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/state.go b/src/github.com/matrix-org/dendrite/roomserver/input/state.go index 590bc2677..d917180aa 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/state.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/state.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "sort" ) // calculateAndStoreState calculates a snapshot of the state of a room before an event. @@ -99,7 +100,16 @@ func calculateAndStoreStateAfterManyEvents(db RoomEventDatabase, roomNID types.R if len(conflicts) > 0 { // 5) There are conflicting state events, for each conflict workout // what the appropriate state event is. - resolved, err := resolveConflicts(db, combined, conflicts) + + // Work out which entries aren't conflicted. + var notConflicted []types.StateEntry + for _, entry := range combined { + if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { + notConflicted = append(notConflicted, entry) + } + } + + resolved, err := resolveConflicts(db, notConflicted, conflicts) if err != nil { return 0, err } @@ -114,8 +124,88 @@ func calculateAndStoreStateAfterManyEvents(db RoomEventDatabase, roomNID types.R return db.AddState(roomNID, nil, state) } -func resolveConflicts(db RoomEventDatabase, combined, conflicted []types.StateEntry) ([]types.StateEntry, error) { - panic(fmt.Errorf("Not implemented")) +// loadStateEvents loads the matrix events for a list of state entries. +// Returns a list of state events in no particular order and a map from string event ID back to state entry. +// The map can be used to recover which numeric state entry a given event is for. +// Returns an error if there was a problem talking to the database. +func loadStateEvents(db RoomEventDatabase, entries []types.StateEntry) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { + eventNIDs := make([]types.EventNID, len(entries)) + for i := range entries { + eventNIDs[i] = entries[i].EventNID + } + events, err := db.Events(eventNIDs) + if err != nil { + return nil, nil, err + } + eventIDMap := map[string]types.StateEntry{} + result := make([]gomatrixserverlib.Event, len(entries)) + for i := range entries { + event, ok := eventMap(events).lookup(entries[i].EventNID) + if !ok { + panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID)) + } + result[i] = event.Event + eventIDMap[event.Event.EventID()] = entries[i] + } + return result, eventIDMap, nil +} + +// resolveConflicts resolves a list of conflicted state entries. It takes two lists. +// The first is a list of all state entries that are not conflicted. +// The second is a list of all state entries that are conflicted +// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple. +// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. +// The returned list is sorted by state key tuple. +// Returns an error if there was a problem talking to the database. +func resolveConflicts(db RoomEventDatabase, notConflicted, conflicted []types.StateEntry) ([]types.StateEntry, error) { + + // Load the conflicted events + conflictedEvents, eventIDMap, err := loadStateEvents(db, conflicted) + if err != nil { + return nil, err + } + + // Work out which auth events we need to load. + needed := gomatrixserverlib.StateNeededForAuth(conflictedEvents) + + // Find the numeric IDs for the necessary state keys. + var neededStateKeys []string + neededStateKeys = append(neededStateKeys, needed.Member...) + neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) + stateKeyNIDMap, err := db.EventStateKeyNIDs(neededStateKeys) + if err != nil { + return nil, err + } + + // Load the necessary auth events. + tuplesNeeded := stateKeyTuplesNeeded(stateKeyNIDMap, needed) + var authEntries []types.StateEntry + for _, tuple := range tuplesNeeded { + if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { + authEntries = append(authEntries, types.StateEntry{tuple, eventNID}) + } + } + authEvents, _, err := loadStateEvents(db, authEntries) + if err != nil { + return nil, err + } + + // Resolve the conflicts. + fmt.Println("Resolving", len(conflicted), "conflicts with", len(authEvents), "authEvents") + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) + + // Map from the full events back to numeric state entries. + for _, resolvedEvent := range resolvedEvents { + entry, ok := eventIDMap[resolvedEvent.EventID()] + if !ok { + panic(fmt.Errorf("Missing state entry for event ID %q", resolvedEvent.EventID())) + } + notConflicted = append(notConflicted, entry) + } + + // Sort the result so it can be searched. + sort.Sort(stateEntrySorter(notConflicted)) + return notConflicted, nil } // findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.