Break out some shared functions

This commit is contained in:
Neil Alexander 2020-02-14 14:54:09 +00:00
parent 3dabf4d4ed
commit 9be7134727
3 changed files with 173 additions and 188 deletions

View file

@ -0,0 +1,139 @@
package shared
import (
"sort"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util"
)
// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.
// Returns a sorted list of those state entries.
func FindDuplicateStateKeys(a []types.StateEntry) []types.StateEntry {
var result []types.StateEntry
// j is the starting index of a block of entries with the same state key tuple.
j := 0
for i := 1; i < len(a); i++ {
// Check if the state key tuple matches the start of the block
if a[j].StateKeyTuple != a[i].StateKeyTuple {
// If the state key tuple is different then we've reached the end of a block of duplicates.
// Check if the size of the block is bigger than one.
// If the size is one then there was only a single entry with that state key tuple so we don't add it to the result
if j+1 != i {
// Add the block to the result.
result = append(result, a[j:i]...)
}
// Start a new block for the next state key tuple.
j = i
}
}
// Check if the last block with the same state key tuple had more than one event in it.
if j+1 != len(a) {
result = append(result, a[j:]...)
}
return result
}
type StateEntrySorter []types.StateEntry
func (s StateEntrySorter) Len() int { return len(s) }
func (s StateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s StateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type StateBlockNIDListMap []types.StateBlockNIDList
func (m StateBlockNIDListMap) Lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
list := []types.StateBlockNIDList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateSnapshotNID >= stateNID
})
if i < len(list) && list[i].StateSnapshotNID == stateNID {
ok = true
stateBlockNIDs = list[i].StateBlockNIDs
}
return
}
type StateEntryListMap []types.StateEntryList
func (m StateEntryListMap) Lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
list := []types.StateEntryList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateBlockNID >= stateBlockNID
})
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
ok = true
stateEntries = list[i].StateEntries
}
return
}
type StateEntryByStateKeySorter []types.StateEntry
func (s StateEntryByStateKeySorter) Len() int { return len(s) }
func (s StateEntryByStateKeySorter) Less(i, j int) bool {
return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
}
func (s StateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type StateNIDSorter []types.StateSnapshotNID
func (s StateNIDSorter) Len() int { return len(s) }
func (s StateNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
func (s StateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func UniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID {
return nids[:util.SortAndUnique(StateNIDSorter(nids))]
}
type StateBlockNIDSorter []types.StateBlockNID
func (s StateBlockNIDSorter) Len() int { return len(s) }
func (s StateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
func (s StateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func UniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID {
return nids[:util.SortAndUnique(StateBlockNIDSorter(nids))]
}
// Map from event type, state key tuple to numeric event ID.
// Implemented using binary search on a sorted array.
type StateEntryMap []types.StateEntry
// lookup an entry in the event map.
func (m StateEntryMap) Lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size and are controlled by us.
list := []types.StateEntry(m)
i := sort.Search(len(list), func(i int) bool {
return !list[i].StateKeyTuple.LessThan(stateKey)
})
if i < len(list) && list[i].StateKeyTuple == stateKey {
ok = true
eventNID = list[i].EventNID
}
return
}
// Map from numeric event ID to event.
// Implemented using binary search on a sorted array.
type EventMap []types.Event
// lookup an entry in the event map.
func (m EventMap) Lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size are controlled by us.
list := []types.Event(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].EventNID >= eventNID
})
if i < len(list) && list[i].EventNID == eventNID {
ok = true
event = &list[i]
}
return
}

View file

@ -46,36 +46,12 @@ func GetStateResolutionAlgorithm(
}
type StateResolutionImpl interface {
LoadStateAtSnapshot(
ctx context.Context, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error)
LoadStateAtEvent(
ctx context.Context, eventID string,
) ([]types.StateEntry, error)
LoadCombinedStateAfterEvents(
ctx context.Context, prevStates []types.StateAtEvent,
) ([]types.StateEntry, error)
DifferenceBetweeenStateSnapshots(
ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID,
) (removed, added []types.StateEntry, err error)
LoadStateAtSnapshotForStringTuples(
ctx context.Context,
stateNID types.StateSnapshotNID,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error)
LoadStateAfterEventsForStringTuples(
ctx context.Context,
prevStates []types.StateAtEvent,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error)
CalculateAndStoreStateBeforeEvent(
ctx context.Context,
event gomatrixserverlib.Event,
roomNID types.RoomNID,
) (types.StateSnapshotNID, error)
CalculateAndStoreStateAfterEvents(
ctx context.Context,
roomNID types.RoomNID,
prevStates []types.StateAtEvent,
) (types.StateSnapshotNID, error)
LoadStateAtSnapshot(ctx context.Context, stateNID types.StateSnapshotNID) ([]types.StateEntry, error)
LoadStateAtEvent(ctx context.Context, eventID string) ([]types.StateEntry, error)
LoadCombinedStateAfterEvents(ctx context.Context, prevStates []types.StateAtEvent) ([]types.StateEntry, error)
DifferenceBetweeenStateSnapshots(ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID) (removed, added []types.StateEntry, err error)
LoadStateAtSnapshotForStringTuples(ctx context.Context, stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple) ([]types.StateEntry, error)
LoadStateAfterEventsForStringTuples(ctx context.Context, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple) ([]types.StateEntry, error)
CalculateAndStoreStateBeforeEvent(ctx context.Context, event gomatrixserverlib.Event, roomNID types.RoomNID) (types.StateSnapshotNID, error)
CalculateAndStoreStateAfterEvents(ctx context.Context, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error)
}

View file

@ -23,6 +23,7 @@ import (
"time"
"github.com/matrix-org/dendrite/roomserver/state/database"
"github.com/matrix-org/dendrite/roomserver/state/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -56,13 +57,13 @@ func (v StateResolutionV1) LoadStateAtSnapshot(
if err != nil {
return nil, err
}
stateEntriesMap := stateEntryListMap(stateEntryLists)
stateEntriesMap := shared.StateEntryListMap(stateEntryLists)
// Combine all the state entries for this snapshot.
// The order of state block NIDs in the list tells us the order to combine them in.
var fullState []types.StateEntry
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
entries, ok := stateEntriesMap.lookup(stateBlockNID)
entries, ok := stateEntriesMap.Lookup(stateBlockNID)
if !ok {
// This should only get hit if the database is corrupt.
// It should be impossible for an event to reference a NID that doesn't exist
@ -73,9 +74,9 @@ func (v StateResolutionV1) LoadStateAtSnapshot(
// Stable sort so that the most recent entry for each state key stays
// remains later in the list than the older entries for the same state key.
sort.Stable(stateEntryByStateKeySorter(fullState))
sort.Stable(shared.StateEntryByStateKeySorter(fullState))
// Unique returns the last entry and hence the most recent entry for each state key.
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
fullState = fullState[:util.Unique(shared.StateEntryByStateKeySorter(fullState))]
return fullState, nil
}
@ -109,7 +110,7 @@ func (v StateResolutionV1) LoadCombinedStateAfterEvents(
// Deduplicate the IDs before passing them to the database.
// There could be duplicates because the events could be state events where
// the snapshot of the room state before them was the same.
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs))
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, shared.UniqueStateSnapshotNIDs(stateNIDs))
if err != nil {
return nil, err
}
@ -122,18 +123,18 @@ func (v StateResolutionV1) LoadCombinedStateAfterEvents(
// Deduplicate the IDs before passing them to the database.
// There could be duplicates because a block of state entries could be reused by
// multiple snapshots.
stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs))
stateEntryLists, err := v.db.StateEntries(ctx, shared.UniqueStateBlockNIDs(stateBlockNIDs))
if err != nil {
return nil, err
}
stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists)
stateEntriesMap := stateEntryListMap(stateEntryLists)
stateBlockNIDsMap := shared.StateBlockNIDListMap(stateBlockNIDLists)
stateEntriesMap := shared.StateEntryListMap(stateEntryLists)
// Combine the entries from all the snapshots of state after each prev event into a single list.
var combined []types.StateEntry
for _, prevState := range prevStates {
// Grab the list of state data NIDs for this snapshot.
stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID)
stateBlockNIDs, ok := stateBlockNIDsMap.Lookup(prevState.BeforeStateSnapshotNID)
if !ok {
// This should only get hit if the database is corrupt.
// It should be impossible for an event to reference a NID that doesn't exist
@ -144,7 +145,7 @@ func (v StateResolutionV1) LoadCombinedStateAfterEvents(
// The order of state block NIDs in the list tells us the order to combine them in.
var fullState []types.StateEntry
for _, stateBlockNID := range stateBlockNIDs {
entries, ok := stateEntriesMap.lookup(stateBlockNID)
entries, ok := stateEntriesMap.Lookup(stateBlockNID)
if !ok {
// This should only get hit if the database is corrupt.
// It should be impossible for an event to reference a NID that doesn't exist
@ -160,9 +161,9 @@ func (v StateResolutionV1) LoadCombinedStateAfterEvents(
// Stable sort so that the most recent entry for each state key stays
// remains later in the list than the older entries for the same state key.
sort.Stable(stateEntryByStateKeySorter(fullState))
sort.Stable(shared.StateEntryByStateKeySorter(fullState))
// Unique returns the last entry and hence the most recent entry for each state key.
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
fullState = fullState[:util.Unique(shared.StateEntryByStateKeySorter(fullState))]
// Add the full state for this StateSnapshotNID.
combined = append(combined, fullState...)
}
@ -303,13 +304,13 @@ func (v StateResolutionV1) loadStateAtSnapshotForNumericTuples(
if err != nil {
return nil, err
}
stateEntriesMap := stateEntryListMap(stateEntryLists)
stateEntriesMap := shared.StateEntryListMap(stateEntryLists)
// Combine all the state entries for this snapshot.
// The order of state block NIDs in the list tells us the order to combine them in.
var fullState []types.StateEntry
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
entries, ok := stateEntriesMap.lookup(stateBlockNID)
entries, ok := stateEntriesMap.Lookup(stateBlockNID)
if !ok {
// If the block is missing from the map it means that none of its entries matched a requested tuple.
// This can happen if the block doesn't contain an update for one of the requested tuples.
@ -321,9 +322,9 @@ func (v StateResolutionV1) loadStateAtSnapshotForNumericTuples(
// Stable sort so that the most recent entry for each state key stays
// remains later in the list than the older entries for the same state key.
sort.Stable(stateEntryByStateKeySorter(fullState))
sort.Stable(shared.StateEntryByStateKeySorter(fullState))
// Unique returns the last entry and hence the most recent entry for each state key.
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
fullState = fullState[:util.Unique(shared.StateEntryByStateKeySorter(fullState))]
return fullState, nil
}
@ -393,12 +394,12 @@ func (v StateResolutionV1) loadStateAfterEventsForNumericTuples(
}
// Sort the full state so we can use it as a map.
sort.Sort(stateEntrySorter(fullState))
sort.Sort(shared.StateEntrySorter(fullState))
// Filter the full state down to the required tuples.
var result []types.StateEntry
for _, tuple := range stateKeyTuples {
eventNID, ok := stateEntryMap(fullState).lookup(tuple)
eventNID, ok := shared.StateEntryMap(fullState).Lookup(tuple)
if ok {
result = append(result, types.StateEntry{
StateKeyTuple: tuple,
@ -406,7 +407,7 @@ func (v StateResolutionV1) loadStateAfterEventsForNumericTuples(
})
}
}
sort.Sort(stateEntrySorter(result))
sort.Sort(shared.StateEntrySorter(result))
return result, nil
}
@ -627,10 +628,10 @@ func (v StateResolutionV1) calculateStateAfterManyEvents(
// We don't care about the order here because the conflict resolution
// algorithm doesn't depend on the order of the prev events.
// Remove duplicate entires.
combined = combined[:util.SortAndUnique(stateEntrySorter(combined))]
combined = combined[:util.SortAndUnique(shared.StateEntrySorter(combined))]
// Find the conflicts
conflicts := findDuplicateStateKeys(combined)
conflicts := shared.FindDuplicateStateKeys(combined)
if len(conflicts) > 0 {
conflictLength = len(conflicts)
@ -641,7 +642,7 @@ func (v StateResolutionV1) calculateStateAfterManyEvents(
// Work out which entries aren't conflicted.
var notConflicted []types.StateEntry
for _, entry := range combined {
if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok {
if _, ok := shared.StateEntryMap(conflicts).Lookup(entry.StateKeyTuple); !ok {
notConflicted = append(notConflicted, entry)
}
}
@ -696,7 +697,7 @@ func (v StateResolutionV1) resolveConflicts(
tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed)
var authEntries []types.StateEntry
for _, tuple := range tuplesNeeded {
if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
if eventNID, ok := shared.StateEntryMap(notConflicted).Lookup(tuple); ok {
authEntries = append(authEntries, types.StateEntry{
StateKeyTuple: tuple,
EventNID: eventNID,
@ -721,7 +722,7 @@ func (v StateResolutionV1) resolveConflicts(
}
// Sort the result so it can be searched.
sort.Sort(stateEntrySorter(notConflicted))
sort.Sort(shared.StateEntrySorter(notConflicted))
return notConflicted, nil
}
@ -785,7 +786,7 @@ func (v StateResolutionV1) loadStateEvents(
eventIDMap := map[string]types.StateEntry{}
result := make([]gomatrixserverlib.Event, len(entries))
for i := range entries {
event, ok := eventMap(events).lookup(entries[i].EventNID)
event, ok := shared.EventMap(events).Lookup(entries[i].EventNID)
if !ok {
panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID))
}
@ -794,134 +795,3 @@ func (v StateResolutionV1) loadStateEvents(
}
return result, eventIDMap, nil
}
// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.
// Returns a sorted list of those state entries.
func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry {
var result []types.StateEntry
// j is the starting index of a block of entries with the same state key tuple.
j := 0
for i := 1; i < len(a); i++ {
// Check if the state key tuple matches the start of the block
if a[j].StateKeyTuple != a[i].StateKeyTuple {
// If the state key tuple is different then we've reached the end of a block of duplicates.
// Check if the size of the block is bigger than one.
// If the size is one then there was only a single entry with that state key tuple so we don't add it to the result
if j+1 != i {
// Add the block to the result.
result = append(result, a[j:i]...)
}
// Start a new block for the next state key tuple.
j = i
}
}
// Check if the last block with the same state key tuple had more than one event in it.
if j+1 != len(a) {
result = append(result, a[j:]...)
}
return result
}
type stateEntrySorter []types.StateEntry
func (s stateEntrySorter) Len() int { return len(s) }
func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type stateBlockNIDListMap []types.StateBlockNIDList
func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
list := []types.StateBlockNIDList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateSnapshotNID >= stateNID
})
if i < len(list) && list[i].StateSnapshotNID == stateNID {
ok = true
stateBlockNIDs = list[i].StateBlockNIDs
}
return
}
type stateEntryListMap []types.StateEntryList
func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
list := []types.StateEntryList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateBlockNID >= stateBlockNID
})
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
ok = true
stateEntries = list[i].StateEntries
}
return
}
type stateEntryByStateKeySorter []types.StateEntry
func (s stateEntryByStateKeySorter) Len() int { return len(s) }
func (s stateEntryByStateKeySorter) Less(i, j int) bool {
return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
}
func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type stateNIDSorter []types.StateSnapshotNID
func (s stateNIDSorter) Len() int { return len(s) }
func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID {
return nids[:util.SortAndUnique(stateNIDSorter(nids))]
}
type stateBlockNIDSorter []types.StateBlockNID
func (s stateBlockNIDSorter) Len() int { return len(s) }
func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID {
return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))]
}
// Map from event type, state key tuple to numeric event ID.
// Implemented using binary search on a sorted array.
type stateEntryMap []types.StateEntry
// lookup an entry in the event map.
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size and are controlled by us.
list := []types.StateEntry(m)
i := sort.Search(len(list), func(i int) bool {
return !list[i].StateKeyTuple.LessThan(stateKey)
})
if i < len(list) && list[i].StateKeyTuple == stateKey {
ok = true
eventNID = list[i].EventNID
}
return
}
// Map from numeric event ID to event.
// Implemented using binary search on a sorted array.
type eventMap []types.Event
// lookup an entry in the event map.
func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size are controlled by us.
list := []types.Event(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].EventNID >= eventNID
})
if i < len(list) && list[i].EventNID == eventNID {
ok = true
event = &list[i]
}
return
}