From 9be71347275fc1e6b047831ca65d934329b6c04c Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 14 Feb 2020 14:54:09 +0000 Subject: [PATCH] Break out some shared functions --- roomserver/state/shared/shared.go | 139 +++++++++++++++++++++++ roomserver/state/state.go | 40 ++----- roomserver/state/v1/state.go | 182 +++++------------------------- 3 files changed, 173 insertions(+), 188 deletions(-) create mode 100644 roomserver/state/shared/shared.go diff --git a/roomserver/state/shared/shared.go b/roomserver/state/shared/shared.go new file mode 100644 index 000000000..f04f233f9 --- /dev/null +++ b/roomserver/state/shared/shared.go @@ -0,0 +1,139 @@ +package shared + +import ( + "sort" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" +) + +// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. +// Returns a sorted list of those state entries. +func FindDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { + var result []types.StateEntry + // j is the starting index of a block of entries with the same state key tuple. + j := 0 + for i := 1; i < len(a); i++ { + // Check if the state key tuple matches the start of the block + if a[j].StateKeyTuple != a[i].StateKeyTuple { + // If the state key tuple is different then we've reached the end of a block of duplicates. + // Check if the size of the block is bigger than one. + // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result + if j+1 != i { + // Add the block to the result. + result = append(result, a[j:i]...) + } + // Start a new block for the next state key tuple. + j = i + } + } + // Check if the last block with the same state key tuple had more than one event in it. + if j+1 != len(a) { + result = append(result, a[j:]...) + } + return result +} + +type StateEntrySorter []types.StateEntry + +func (s StateEntrySorter) Len() int { return len(s) } +func (s StateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s StateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type StateBlockNIDListMap []types.StateBlockNIDList + +func (m StateBlockNIDListMap) Lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { + list := []types.StateBlockNIDList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateSnapshotNID >= stateNID + }) + if i < len(list) && list[i].StateSnapshotNID == stateNID { + ok = true + stateBlockNIDs = list[i].StateBlockNIDs + } + return +} + +type StateEntryListMap []types.StateEntryList + +func (m StateEntryListMap) Lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { + list := []types.StateEntryList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateBlockNID >= stateBlockNID + }) + if i < len(list) && list[i].StateBlockNID == stateBlockNID { + ok = true + stateEntries = list[i].StateEntries + } + return +} + +type StateEntryByStateKeySorter []types.StateEntry + +func (s StateEntryByStateKeySorter) Len() int { return len(s) } +func (s StateEntryByStateKeySorter) Less(i, j int) bool { + return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) +} +func (s StateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type StateNIDSorter []types.StateSnapshotNID + +func (s StateNIDSorter) Len() int { return len(s) } +func (s StateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s StateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func UniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { + return nids[:util.SortAndUnique(StateNIDSorter(nids))] +} + +type StateBlockNIDSorter []types.StateBlockNID + +func (s StateBlockNIDSorter) Len() int { return len(s) } +func (s StateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s StateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func UniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { + return nids[:util.SortAndUnique(StateBlockNIDSorter(nids))] +} + +// Map from event type, state key tuple to numeric event ID. +// Implemented using binary search on a sorted array. +type StateEntryMap []types.StateEntry + +// lookup an entry in the event map. +func (m StateEntryMap) Lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size and are controlled by us. + list := []types.StateEntry(m) + i := sort.Search(len(list), func(i int) bool { + return !list[i].StateKeyTuple.LessThan(stateKey) + }) + if i < len(list) && list[i].StateKeyTuple == stateKey { + ok = true + eventNID = list[i].EventNID + } + return +} + +// Map from numeric event ID to event. +// Implemented using binary search on a sorted array. +type EventMap []types.Event + +// lookup an entry in the event map. +func (m EventMap) Lookup(eventNID types.EventNID) (event *types.Event, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size are controlled by us. + list := []types.Event(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].EventNID >= eventNID + }) + if i < len(list) && list[i].EventNID == eventNID { + ok = true + event = &list[i] + } + return +} diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 687a120e3..05e8b4c26 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -46,36 +46,12 @@ func GetStateResolutionAlgorithm( } type StateResolutionImpl interface { - LoadStateAtSnapshot( - ctx context.Context, stateNID types.StateSnapshotNID, - ) ([]types.StateEntry, error) - LoadStateAtEvent( - ctx context.Context, eventID string, - ) ([]types.StateEntry, error) - LoadCombinedStateAfterEvents( - ctx context.Context, prevStates []types.StateAtEvent, - ) ([]types.StateEntry, error) - DifferenceBetweeenStateSnapshots( - ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, - ) (removed, added []types.StateEntry, err error) - LoadStateAtSnapshotForStringTuples( - ctx context.Context, - stateNID types.StateSnapshotNID, - stateKeyTuples []gomatrixserverlib.StateKeyTuple, - ) ([]types.StateEntry, error) - LoadStateAfterEventsForStringTuples( - ctx context.Context, - prevStates []types.StateAtEvent, - stateKeyTuples []gomatrixserverlib.StateKeyTuple, - ) ([]types.StateEntry, error) - CalculateAndStoreStateBeforeEvent( - ctx context.Context, - event gomatrixserverlib.Event, - roomNID types.RoomNID, - ) (types.StateSnapshotNID, error) - CalculateAndStoreStateAfterEvents( - ctx context.Context, - roomNID types.RoomNID, - prevStates []types.StateAtEvent, - ) (types.StateSnapshotNID, error) + LoadStateAtSnapshot(ctx context.Context, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) + LoadStateAtEvent(ctx context.Context, eventID string) ([]types.StateEntry, error) + LoadCombinedStateAfterEvents(ctx context.Context, prevStates []types.StateAtEvent) ([]types.StateEntry, error) + DifferenceBetweeenStateSnapshots(ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID) (removed, added []types.StateEntry, err error) + LoadStateAtSnapshotForStringTuples(ctx context.Context, stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple) ([]types.StateEntry, error) + LoadStateAfterEventsForStringTuples(ctx context.Context, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple) ([]types.StateEntry, error) + CalculateAndStoreStateBeforeEvent(ctx context.Context, event gomatrixserverlib.Event, roomNID types.RoomNID) (types.StateSnapshotNID, error) + CalculateAndStoreStateAfterEvents(ctx context.Context, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) } diff --git a/roomserver/state/v1/state.go b/roomserver/state/v1/state.go index 5683745bf..a2fa9bb6b 100644 --- a/roomserver/state/v1/state.go +++ b/roomserver/state/v1/state.go @@ -23,6 +23,7 @@ import ( "time" "github.com/matrix-org/dendrite/roomserver/state/database" + "github.com/matrix-org/dendrite/roomserver/state/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -56,13 +57,13 @@ func (v StateResolutionV1) LoadStateAtSnapshot( if err != nil { return nil, err } - stateEntriesMap := stateEntryListMap(stateEntryLists) + stateEntriesMap := shared.StateEntryListMap(stateEntryLists) // Combine all the state entries for this snapshot. // The order of state block NIDs in the list tells us the order to combine them in. var fullState []types.StateEntry for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) + entries, ok := stateEntriesMap.Lookup(stateBlockNID) if !ok { // This should only get hit if the database is corrupt. // It should be impossible for an event to reference a NID that doesn't exist @@ -73,9 +74,9 @@ func (v StateResolutionV1) LoadStateAtSnapshot( // Stable sort so that the most recent entry for each state key stays // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) + sort.Stable(shared.StateEntryByStateKeySorter(fullState)) // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + fullState = fullState[:util.Unique(shared.StateEntryByStateKeySorter(fullState))] return fullState, nil } @@ -109,7 +110,7 @@ func (v StateResolutionV1) LoadCombinedStateAfterEvents( // Deduplicate the IDs before passing them to the database. // There could be duplicates because the events could be state events where // the snapshot of the room state before them was the same. - stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs)) + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, shared.UniqueStateSnapshotNIDs(stateNIDs)) if err != nil { return nil, err } @@ -122,18 +123,18 @@ func (v StateResolutionV1) LoadCombinedStateAfterEvents( // Deduplicate the IDs before passing them to the database. // There could be duplicates because a block of state entries could be reused by // multiple snapshots. - stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) + stateEntryLists, err := v.db.StateEntries(ctx, shared.UniqueStateBlockNIDs(stateBlockNIDs)) if err != nil { return nil, err } - stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) - stateEntriesMap := stateEntryListMap(stateEntryLists) + stateBlockNIDsMap := shared.StateBlockNIDListMap(stateBlockNIDLists) + stateEntriesMap := shared.StateEntryListMap(stateEntryLists) // Combine the entries from all the snapshots of state after each prev event into a single list. var combined []types.StateEntry for _, prevState := range prevStates { // Grab the list of state data NIDs for this snapshot. - stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) + stateBlockNIDs, ok := stateBlockNIDsMap.Lookup(prevState.BeforeStateSnapshotNID) if !ok { // This should only get hit if the database is corrupt. // It should be impossible for an event to reference a NID that doesn't exist @@ -144,7 +145,7 @@ func (v StateResolutionV1) LoadCombinedStateAfterEvents( // The order of state block NIDs in the list tells us the order to combine them in. var fullState []types.StateEntry for _, stateBlockNID := range stateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) + entries, ok := stateEntriesMap.Lookup(stateBlockNID) if !ok { // This should only get hit if the database is corrupt. // It should be impossible for an event to reference a NID that doesn't exist @@ -160,9 +161,9 @@ func (v StateResolutionV1) LoadCombinedStateAfterEvents( // Stable sort so that the most recent entry for each state key stays // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) + sort.Stable(shared.StateEntryByStateKeySorter(fullState)) // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + fullState = fullState[:util.Unique(shared.StateEntryByStateKeySorter(fullState))] // Add the full state for this StateSnapshotNID. combined = append(combined, fullState...) } @@ -303,13 +304,13 @@ func (v StateResolutionV1) loadStateAtSnapshotForNumericTuples( if err != nil { return nil, err } - stateEntriesMap := stateEntryListMap(stateEntryLists) + stateEntriesMap := shared.StateEntryListMap(stateEntryLists) // Combine all the state entries for this snapshot. // The order of state block NIDs in the list tells us the order to combine them in. var fullState []types.StateEntry for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) + entries, ok := stateEntriesMap.Lookup(stateBlockNID) if !ok { // If the block is missing from the map it means that none of its entries matched a requested tuple. // This can happen if the block doesn't contain an update for one of the requested tuples. @@ -321,9 +322,9 @@ func (v StateResolutionV1) loadStateAtSnapshotForNumericTuples( // Stable sort so that the most recent entry for each state key stays // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) + sort.Stable(shared.StateEntryByStateKeySorter(fullState)) // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + fullState = fullState[:util.Unique(shared.StateEntryByStateKeySorter(fullState))] return fullState, nil } @@ -393,12 +394,12 @@ func (v StateResolutionV1) loadStateAfterEventsForNumericTuples( } // Sort the full state so we can use it as a map. - sort.Sort(stateEntrySorter(fullState)) + sort.Sort(shared.StateEntrySorter(fullState)) // Filter the full state down to the required tuples. var result []types.StateEntry for _, tuple := range stateKeyTuples { - eventNID, ok := stateEntryMap(fullState).lookup(tuple) + eventNID, ok := shared.StateEntryMap(fullState).Lookup(tuple) if ok { result = append(result, types.StateEntry{ StateKeyTuple: tuple, @@ -406,7 +407,7 @@ func (v StateResolutionV1) loadStateAfterEventsForNumericTuples( }) } } - sort.Sort(stateEntrySorter(result)) + sort.Sort(shared.StateEntrySorter(result)) return result, nil } @@ -627,10 +628,10 @@ func (v StateResolutionV1) calculateStateAfterManyEvents( // We don't care about the order here because the conflict resolution // algorithm doesn't depend on the order of the prev events. // Remove duplicate entires. - combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] + combined = combined[:util.SortAndUnique(shared.StateEntrySorter(combined))] // Find the conflicts - conflicts := findDuplicateStateKeys(combined) + conflicts := shared.FindDuplicateStateKeys(combined) if len(conflicts) > 0 { conflictLength = len(conflicts) @@ -641,7 +642,7 @@ func (v StateResolutionV1) calculateStateAfterManyEvents( // Work out which entries aren't conflicted. var notConflicted []types.StateEntry for _, entry := range combined { - if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { + if _, ok := shared.StateEntryMap(conflicts).Lookup(entry.StateKeyTuple); !ok { notConflicted = append(notConflicted, entry) } } @@ -696,7 +697,7 @@ func (v StateResolutionV1) resolveConflicts( tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) var authEntries []types.StateEntry for _, tuple := range tuplesNeeded { - if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { + if eventNID, ok := shared.StateEntryMap(notConflicted).Lookup(tuple); ok { authEntries = append(authEntries, types.StateEntry{ StateKeyTuple: tuple, EventNID: eventNID, @@ -721,7 +722,7 @@ func (v StateResolutionV1) resolveConflicts( } // Sort the result so it can be searched. - sort.Sort(stateEntrySorter(notConflicted)) + sort.Sort(shared.StateEntrySorter(notConflicted)) return notConflicted, nil } @@ -785,7 +786,7 @@ func (v StateResolutionV1) loadStateEvents( eventIDMap := map[string]types.StateEntry{} result := make([]gomatrixserverlib.Event, len(entries)) for i := range entries { - event, ok := eventMap(events).lookup(entries[i].EventNID) + event, ok := shared.EventMap(events).Lookup(entries[i].EventNID) if !ok { panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID)) } @@ -794,134 +795,3 @@ func (v StateResolutionV1) loadStateEvents( } return result, eventIDMap, nil } - -// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. -// Returns a sorted list of those state entries. -func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { - var result []types.StateEntry - // j is the starting index of a block of entries with the same state key tuple. - j := 0 - for i := 1; i < len(a); i++ { - // Check if the state key tuple matches the start of the block - if a[j].StateKeyTuple != a[i].StateKeyTuple { - // If the state key tuple is different then we've reached the end of a block of duplicates. - // Check if the size of the block is bigger than one. - // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result - if j+1 != i { - // Add the block to the result. - result = append(result, a[j:i]...) - } - // Start a new block for the next state key tuple. - j = i - } - } - // Check if the last block with the same state key tuple had more than one event in it. - if j+1 != len(a) { - result = append(result, a[j:]...) - } - return result -} - -type stateEntrySorter []types.StateEntry - -func (s stateEntrySorter) Len() int { return len(s) } -func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } -func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type stateBlockNIDListMap []types.StateBlockNIDList - -func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { - list := []types.StateBlockNIDList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateSnapshotNID >= stateNID - }) - if i < len(list) && list[i].StateSnapshotNID == stateNID { - ok = true - stateBlockNIDs = list[i].StateBlockNIDs - } - return -} - -type stateEntryListMap []types.StateEntryList - -func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { - list := []types.StateEntryList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateBlockNID >= stateBlockNID - }) - if i < len(list) && list[i].StateBlockNID == stateBlockNID { - ok = true - stateEntries = list[i].StateEntries - } - return -} - -type stateEntryByStateKeySorter []types.StateEntry - -func (s stateEntryByStateKeySorter) Len() int { return len(s) } -func (s stateEntryByStateKeySorter) Less(i, j int) bool { - return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) -} -func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type stateNIDSorter []types.StateSnapshotNID - -func (s stateNIDSorter) Len() int { return len(s) } -func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { - return nids[:util.SortAndUnique(stateNIDSorter(nids))] -} - -type stateBlockNIDSorter []types.StateBlockNID - -func (s stateBlockNIDSorter) Len() int { return len(s) } -func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { - return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))] -} - -// Map from event type, state key tuple to numeric event ID. -// Implemented using binary search on a sorted array. -type stateEntryMap []types.StateEntry - -// lookup an entry in the event map. -func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { - // Since the list is sorted we can implement this using binary search. - // This is faster than using a hash map. - // We don't have to worry about pathological cases because the keys are fixed - // size and are controlled by us. - list := []types.StateEntry(m) - i := sort.Search(len(list), func(i int) bool { - return !list[i].StateKeyTuple.LessThan(stateKey) - }) - if i < len(list) && list[i].StateKeyTuple == stateKey { - ok = true - eventNID = list[i].EventNID - } - return -} - -// Map from numeric event ID to event. -// Implemented using binary search on a sorted array. -type eventMap []types.Event - -// lookup an entry in the event map. -func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { - // Since the list is sorted we can implement this using binary search. - // This is faster than using a hash map. - // We don't have to worry about pathological cases because the keys are fixed - // size are controlled by us. - list := []types.Event(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].EventNID >= eventNID - }) - if i < len(list) && list[i].EventNID == eventNID { - ok = true - event = &list[i] - } - return -}