From 84682b33c968f469d201bbe08b90b4ff6e120dd4 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 8 Mar 2017 13:27:21 +0000 Subject: [PATCH 1/2] Use Unique from github.com/matrix-org/util (#28) * Update github.com/matrix-org/util * Use Unique from github.com/matrix-org/util --- .../dendrite/roomserver/input/state.go | 45 ++------- .../dendrite/roomserver/input/state_test.go | 26 ----- vendor/manifest | 2 +- vendor/src/github.com/matrix-org/util/json.go | 6 +- .../github.com/matrix-org/util/json_test.go | 22 +++++ .../src/github.com/matrix-org/util/unique.go | 57 +++++++++++ .../github.com/matrix-org/util/unique_test.go | 96 +++++++++++++++++++ 7 files changed, 185 insertions(+), 69 deletions(-) create mode 100644 vendor/src/github.com/matrix-org/util/unique.go create mode 100644 vendor/src/github.com/matrix-org/util/unique_test.go diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/state.go b/src/github.com/matrix-org/dendrite/roomserver/input/state.go index 36ab43b1c..a7e24701e 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/state.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/state.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "sort" ) @@ -88,9 +89,8 @@ func calculateAndStoreStateAfterManyEvents(db RoomEventDatabase, roomNID types.R // Collect all the entries with the same type and key together. // We don't care about the order here because the conflict resolution // algorithm doesn't depend on the order of the prev events. - sort.Sort(stateEntrySorter(combined)) // Remove duplicate entires. 
- combined = combined[:unique(stateEntrySorter(combined))] + combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] // Find the conflicts conflicts := findDuplicateStateKeys(combined) @@ -202,7 +202,7 @@ func loadStateAtSnapshot(db RoomEventDatabase, stateNID types.StateSnapshotNID) // remains later in the list than the older entries for the same state key. sort.Stable(stateEntryByStateKeySorter(fullState)) // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:unique(stateEntryByStateKeySorter(fullState))] + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] return fullState, nil } @@ -270,7 +270,7 @@ func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.State // remains later in the list than the older entries for the same state key. sort.Stable(stateEntryByStateKeySorter(fullState)) // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:unique(stateEntryByStateKeySorter(fullState))] + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] // Add the full state for this StateSnapshotNID. combined = append(combined, fullState...) 
} @@ -357,8 +357,7 @@ func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { - sort.Sort(stateNIDSorter(nids)) - return nids[:unique(stateNIDSorter(nids))] + return nids[:util.SortAndUnique(stateNIDSorter(nids))] } type stateBlockNIDSorter []types.StateBlockNID @@ -368,37 +367,5 @@ func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { - sort.Sort(stateBlockNIDSorter(nids)) - return nids[:unique(stateBlockNIDSorter(nids))] -} - -// Remove duplicate items from a sorted list. -// Takes the same interface as sort.Sort -// Returns the length of the data without duplicates -// Uses the last occurance of a duplicate. -// O(n). -func unique(data sort.Interface) int { - if data.Len() == 0 { - return 0 - } - length := data.Len() - // j is the next index to output an element to. - j := 0 - for i := 1; i < length; i++ { - // If the previous element is less than this element then they are - // not equal. Otherwise they must be equal because the list is sorted. - // If they are equal then we move onto the next element. - if data.Less(i-1, i) { - // "Write" the previous element to the output position by swaping - // the elements. - // Note that if the list has no duplicates then i-1 == j so the - // swap does nothing. (This assumes that data.Swap(a,b) nops if a==b) - data.Swap(i-1, j) - // Advance to the next output position in the list. - j++ - } - } - // Output the last element. 
- data.Swap(length-1, j) - return j + 1 + return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))] } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/state_test.go b/src/github.com/matrix-org/dendrite/roomserver/input/state_test.go index e5707ff1a..c89576881 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/state_test.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/state_test.go @@ -5,32 +5,6 @@ import ( "testing" ) -type sortBytes []byte - -func (s sortBytes) Len() int { return len(s) } -func (s sortBytes) Less(i, j int) bool { return s[i] < s[j] } -func (s sortBytes) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -func TestUnique(t *testing.T) { - testCases := []struct { - Input string - Want string - }{ - {"", ""}, - {"abc", "abc"}, - {"aaabbbccc", "abc"}, - } - - for _, test := range testCases { - input := []byte(test.Input) - want := string(test.Want) - got := string(input[:unique(sortBytes(input))]) - if got != want { - t.Fatal("Wanted ", want, " got ", got) - } - } -} - func TestFindDuplicateStateKeys(t *testing.T) { testCases := []struct { Input []types.StateEntry diff --git a/vendor/manifest b/vendor/manifest index 35065c88a..26f8a5571 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -98,7 +98,7 @@ { "importpath": "github.com/matrix-org/util", "repository": "https://github.com/matrix-org/util", - "revision": "28bd7491c8aafbf346ca23821664f0f9911ef52b", + "revision": "ec8896cd7d9ba6de6143c5f123d1e45413657e7d", "branch": "master" }, { diff --git a/vendor/src/github.com/matrix-org/util/json.go b/vendor/src/github.com/matrix-org/util/json.go index 46c5396f5..3323b526b 100644 --- a/vendor/src/github.com/matrix-org/util/json.go +++ b/vendor/src/github.com/matrix-org/util/json.go @@ -80,7 +80,7 @@ func Protect(handler http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { defer func() { if r := recover(); r != nil { - logger := req.Context().Value(ctxValueLogger).(*log.Entry) + 
logger := GetLogger(req.Context()) logger.WithFields(log.Fields{ "panic": r, }).Errorf( @@ -108,7 +108,7 @@ func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc { ctx = context.WithValue(ctx, ctxValueRequestID, reqID) req = req.WithContext(ctx) - logger := req.Context().Value(ctxValueLogger).(*log.Entry) + logger := GetLogger(req.Context()) logger.Print("Incoming request") res := handler.OnIncomingRequest(req) @@ -122,7 +122,7 @@ func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc { } func respond(w http.ResponseWriter, req *http.Request, res JSONResponse) { - logger := req.Context().Value(ctxValueLogger).(*log.Entry) + logger := GetLogger(req.Context()) // Set custom headers if res.Headers != nil { diff --git a/vendor/src/github.com/matrix-org/util/json_test.go b/vendor/src/github.com/matrix-org/util/json_test.go index 3ce03a883..aeb5a9e55 100644 --- a/vendor/src/github.com/matrix-org/util/json_test.go +++ b/vendor/src/github.com/matrix-org/util/json_test.go @@ -194,6 +194,28 @@ func TestProtect(t *testing.T) { } } +func TestProtectWithoutLogger(t *testing.T) { + log.SetLevel(log.PanicLevel) // suppress logs in test output + mockWriter := httptest.NewRecorder() + mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + h := Protect(func(w http.ResponseWriter, req *http.Request) { + panic("oh noes!") + }) + + h(mockWriter, mockReq) + + expectCode := 500 + if mockWriter.Code != expectCode { + t.Errorf("TestProtect wanted HTTP status %d, got %d", expectCode, mockWriter.Code) + } + + expectBody := `{"message":"Internal Server Error"}` + actualBody := mockWriter.Body.String() + if actualBody != expectBody { + t.Errorf("TestProtect wanted body %s, got %s", expectBody, actualBody) + } +} + func TestWithCORSOptions(t *testing.T) { log.SetLevel(log.PanicLevel) // suppress logs in test output mockWriter := httptest.NewRecorder() diff --git a/vendor/src/github.com/matrix-org/util/unique.go b/vendor/src/github.com/matrix-org/util/unique.go new 
file mode 100644 index 000000000..401c55609 --- /dev/null +++ b/vendor/src/github.com/matrix-org/util/unique.go @@ -0,0 +1,57 @@ +package util + +import ( + "fmt" + "sort" +) + +// Unique removes duplicate items from a sorted list in place. +// Takes the same interface as sort.Sort +// Returns the length of the data without duplicates +// Uses the last occurrence of a duplicate. +// O(n). +func Unique(data sort.Interface) int { + if !sort.IsSorted(data) { + panic(fmt.Errorf("util: the input to Unique() must be sorted")) + } + + if data.Len() == 0 { + return 0 + } + length := data.Len() + // j is the next index to output an element to. + j := 0 + for i := 1; i < length; i++ { + // If the previous element is less than this element then they are + // not equal. Otherwise they must be equal because the list is sorted. + // If they are equal then we move onto the next element. + if data.Less(i-1, i) { + // "Write" the previous element to the output position by swapping + // the elements. + // Note that if the list has no duplicates then i-1 == j so the + // swap does nothing. (This assumes that data.Swap(a,b) nops if a==b) + data.Swap(i-1, j) + // Advance to the next output position in the list. + j++ + } + } + // Output the last element. + data.Swap(length-1, j) + return j + 1 +} + +// SortAndUnique sorts a list and removes duplicate entries in place. +// Takes the same interface as sort.Sort +// Returns the length of the data without duplicates +// Uses the last occurrence of a duplicate. +// O(nlog(n)) +func SortAndUnique(data sort.Interface) int { + sort.Sort(data) + return Unique(data) +} + +// UniqueStrings turns a list of strings into a sorted list of unique strings. 
+// O(nlog(n)) +func UniqueStrings(strings []string) []string { + return strings[:SortAndUnique(sort.StringSlice(strings))] +} diff --git a/vendor/src/github.com/matrix-org/util/unique_test.go b/vendor/src/github.com/matrix-org/util/unique_test.go new file mode 100644 index 000000000..721624d37 --- /dev/null +++ b/vendor/src/github.com/matrix-org/util/unique_test.go @@ -0,0 +1,96 @@ +package util + +import ( + "sort" + "testing" +) + +type sortBytes []byte + +func (s sortBytes) Len() int { return len(s) } +func (s sortBytes) Less(i, j int) bool { return s[i] < s[j] } +func (s sortBytes) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func TestUnique(t *testing.T) { + testCases := []struct { + Input string + Want string + }{ + {"", ""}, + {"abc", "abc"}, + {"aaabbbccc", "abc"}, + } + + for _, test := range testCases { + input := []byte(test.Input) + want := string(test.Want) + got := string(input[:Unique(sortBytes(input))]) + if got != want { + t.Fatal("Wanted ", want, " got ", got) + } + } +} + +type sortByFirstByte []string + +func (s sortByFirstByte) Len() int { return len(s) } +func (s sortByFirstByte) Less(i, j int) bool { return s[i][0] < s[j][0] } +func (s sortByFirstByte) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func TestUniquePicksLastDuplicate(t *testing.T) { + input := []string{ + "aardvark", + "avacado", + "cat", + "cucumber", + } + want := []string{ + "avacado", + "cucumber", + } + got := input[:Unique(sortByFirstByte(input))] + + if len(want) != len(got) { + t.Errorf("Wanted %#v got %#v", want, got) + } + for i := range want { + if want[i] != got[i] { + t.Errorf("Wanted %#v got %#v", want, got) + } + } +} + +func TestUniquePanicsIfNotSorted(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected Unique() to panic on unsorted input but it didn't") + } + }() + Unique(sort.StringSlice{"out", "of", "order"}) +} + +func TestUniqueStrings(t *testing.T) { + input := []string{ + "badger", "badger", "badger", "badger", 
"badger", "badger", "badger", "badger", + "badger", "badger", "badger", "badger", + "mushroom", "mushroom", + "badger", "badger", "badger", "badger", + "badger", "badger", "badger", "badger", + "badger", "badger", "badger", "badger", + "snake", "snake", + } + + want := []string{"badger", "mushroom", "snake"} + + got := UniqueStrings(input) + + if len(want) != len(got) { + t.Errorf("Wanted %#v got %#v", want, got) + } + for i := range want { + if want[i] != got[i] { + t.Errorf("Wanted %#v got %#v", want, got) + } + } +} From 1d18da11894dd98b41817d3ed289f32b6a8cb522 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 8 Mar 2017 15:10:26 +0000 Subject: [PATCH 2/2] Move the functions for reading room state to a separate package. (#29) This should: 1) Make the input package a bit cleaner. 2) Allow us to reuse the state reading code from the query package. --- .../dendrite/roomserver/input/events.go | 8 +- .../roomserver/input/latest_events.go | 3 +- .../dendrite/roomserver/input/state.go | 223 +--------------- .../dendrite/roomserver/state/state.go | 239 ++++++++++++++++++ 4 files changed, 245 insertions(+), 228 deletions(-) create mode 100644 src/github.com/matrix-org/dendrite/roomserver/state/state.go diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index adc25661d..600ec6ae9 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -2,12 +2,14 @@ package input import ( "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) // A RoomEventDatabase has the storage APIs needed to store a room event. 
type RoomEventDatabase interface { + state.RoomStateDatabase // Stores a matrix room event in the database StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) // Lookup the state entries for a list of string event IDs @@ -24,12 +26,6 @@ type RoomEventDatabase interface { // Returns an error if there is an error talking to the database // or if the room state for the event IDs aren't in the database StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) - // Lookup the numeric state data IDs for each numeric state snapshot ID - // The returned slice is sorted by numeric state snapshot ID. - StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) - // Lookup the state data for each numeric state data ID - // The returned slice is sorted by numeric state data ID. - StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) // Store the room state at an event in the database AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) // Set the state at an event. 
diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go b/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go index feaeccdb0..19479d44d 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/latest_events.go @@ -3,6 +3,7 @@ package input import ( "bytes" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -90,7 +91,7 @@ func doUpdateLatestEvents( return err } - removed, added, err := differenceBetweeenStateSnapshots(db, oldStateNID, newStateNID) + removed, added, err := state.DifferenceBetweeenStateSnapshots(db, oldStateNID, newStateNID) if err != nil { return err } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/state.go b/src/github.com/matrix-org/dendrite/roomserver/input/state.go index a7e24701e..590bc2677 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/state.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/state.go @@ -2,10 +2,10 @@ package input import ( "fmt" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "sort" ) // calculateAndStoreState calculates a snapshot of the state of a room before an event. @@ -81,7 +81,7 @@ const maxStateBlockNIDs = 64 func calculateAndStoreStateAfterManyEvents(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) { // Conflict resolution. // First stage: load the state after each of the prev events. 
- combined, err := loadCombinedStateAfterEvents(db, prevStates) + combined, err := state.LoadCombinedStateAfterEvents(db, prevStates) if err != nil { return 0, err } @@ -114,169 +114,6 @@ func calculateAndStoreStateAfterManyEvents(db RoomEventDatabase, roomNID types.R return db.AddState(roomNID, nil, state) } -// differenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. -func differenceBetweeenStateSnapshots(db RoomEventDatabase, oldStateNID, newStateNID types.StateSnapshotNID) ( - removed, added []types.StateEntry, err error, -) { - if oldStateNID == newStateNID { - // If the snapshot NIDs are the same then nothing has changed - return nil, nil, nil - } - - var oldEntries []types.StateEntry - var newEntries []types.StateEntry - if oldStateNID != 0 { - oldEntries, err = loadStateAtSnapshot(db, oldStateNID) - if err != nil { - return nil, nil, err - } - } - if newStateNID != 0 { - newEntries, err = loadStateAtSnapshot(db, newStateNID) - if err != nil { - return nil, nil, err - } - } - - var oldI int - var newI int - for { - switch { - case oldI == len(oldEntries): - // We've reached the end of the old entries. - // The rest of the new list must have been newly added. - added = append(added, newEntries[newI:]...) - return - case newI == len(newEntries): - // We've reached the end of the new entries. - // The rest of the old list must be have been removed. - removed = append(removed, oldEntries[oldI:]...) - return - case oldEntries[oldI] == newEntries[newI]: - // The entry is in both lists so skip over it. - oldI++ - newI++ - case oldEntries[oldI].LessThan(newEntries[newI]): - // The lists are sorted so the old entry being less than the new entry means that it only appears in the old list. - removed = append(removed, oldEntries[oldI]) - oldI++ - default: - // Reaching the default case implies that the new entry is less than the old entry. 
- // Since the lists are sorted this means that it only appears in the new list. - added = append(added, newEntries[newI]) - newI++ - } - } -} - -// loadStateAtSnapshot loads the full state of a room at a particular snapshot. -// This is typically the state before an event or the current state of a room. -// Returns a sorted list of state entries or an error if there was a problem talking to the database. -func loadStateAtSnapshot(db RoomEventDatabase, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) { - stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID}) - if err != nil { - return nil, err - } - stateBlockNIDList := stateBlockNIDLists[0] - - stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs) - if err != nil { - return nil, err - } - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combined all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) - } - fullState = append(fullState, entries...) - } - - // Stable sort so that the most recent entry for each state key stays - // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) - // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] - return fullState, nil -} - -// loadCombinedStateAfterEvents loads a snapshot of the state after each of the events -// and combines those snapshots together into a single list. 
-func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) { - stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) - for i, state := range prevStates { - stateNIDs[i] = state.BeforeStateSnapshotNID - } - // Fetch the state snapshots for the state before the each prev event from the database. - // Deduplicate the IDs before passing them to the database. - // There could be duplicates because the events could be state events where - // the snapshot of the room state before them was the same. - stateBlockNIDLists, err := db.StateBlockNIDs(uniqueStateSnapshotNIDs(stateNIDs)) - if err != nil { - return nil, err - } - - var stateBlockNIDs []types.StateBlockNID - for _, list := range stateBlockNIDLists { - stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) - } - // Fetch the state entries that will be combined to create the snapshots. - // Deduplicate the IDs before passing them to the database. - // There could be duplicates because a block of state entries could be reused by - // multiple snapshots. - stateEntryLists, err := db.StateEntries(uniqueStateBlockNIDs(stateBlockNIDs)) - if err != nil { - return nil, err - } - stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine the entries from all the snapshots of state after each prev event into a single list. - var combined []types.StateEntry - for _, prevState := range prevStates { - // Grab the list of state data NIDs for this snapshot. - stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) - } - - // Combined all the state entries for this snapshot. 
- // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) - } - fullState = append(fullState, entries...) - } - if prevState.IsStateEvent() { - // If the prev event was a state event then add an entry for the event itself - // so that we get the state after the event rather than the state before. - fullState = append(fullState, prevState.StateEntry) - } - - // Stable sort so that the most recent entry for each state key stays - // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) - // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] - // Add the full state for this StateSnapshotNID. - combined = append(combined, fullState...) 
- } - return combined, nil -} - func resolveConflicts(db RoomEventDatabase, combined, conflicted []types.StateEntry) ([]types.StateEntry, error) { panic(fmt.Errorf("Not implemented")) } @@ -308,64 +145,8 @@ func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { return result } -type stateBlockNIDListMap []types.StateBlockNIDList - -func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { - list := []types.StateBlockNIDList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateSnapshotNID >= stateNID - }) - if i < len(list) && list[i].StateSnapshotNID == stateNID { - ok = true - stateBlockNIDs = list[i].StateBlockNIDs - } - return -} - -type stateEntryListMap []types.StateEntryList - -func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { - list := []types.StateEntryList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateBlockNID >= stateBlockNID - }) - if i < len(list) && list[i].StateBlockNID == stateBlockNID { - ok = true - stateEntries = list[i].StateEntries - } - return -} - -type stateEntryByStateKeySorter []types.StateEntry - -func (s stateEntryByStateKeySorter) Len() int { return len(s) } -func (s stateEntryByStateKeySorter) Less(i, j int) bool { - return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) -} -func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - type stateEntrySorter []types.StateEntry func (s stateEntrySorter) Len() int { return len(s) } func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type stateNIDSorter []types.StateSnapshotNID - -func (s stateNIDSorter) Len() int { return len(s) } -func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -func 
uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { - return nids[:util.SortAndUnique(stateNIDSorter(nids))] -} - -type stateBlockNIDSorter []types.StateBlockNID - -func (s stateBlockNIDSorter) Len() int { return len(s) } -func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { - return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))] -} diff --git a/src/github.com/matrix-org/dendrite/roomserver/state/state.go b/src/github.com/matrix-org/dendrite/roomserver/state/state.go new file mode 100644 index 000000000..aadc5550f --- /dev/null +++ b/src/github.com/matrix-org/dendrite/roomserver/state/state.go @@ -0,0 +1,239 @@ +// Package state provides functions for reading state from the database. +// The functions for writing state to the database are the input package. +package state + +import ( + "fmt" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" + "sort" +) + +// A RoomStateDatabase has the storage APIs needed to load state from the database +type RoomStateDatabase interface { + // Lookup the numeric state data IDs for each numeric state snapshot ID + // The returned slice is sorted by numeric state snapshot ID. + StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + // Lookup the state data for each numeric state data ID + // The returned slice is sorted by numeric state data ID. + StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) +} + +// LoadStateAtSnapshot loads the full state of a room at a particular snapshot. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. 
+func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) { + stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combined all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events +// and combines those snapshots together into a single list. +func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) { + stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) + for i, state := range prevStates { + stateNIDs[i] = state.BeforeStateSnapshotNID + } + // Fetch the state snapshots for the state before the each prev event from the database. 
+ // Deduplicate the IDs before passing them to the database. + // There could be duplicates because the events could be state events where + // the snapshot of the room state before them was the same. + stateBlockNIDLists, err := db.StateBlockNIDs(uniqueStateSnapshotNIDs(stateNIDs)) + if err != nil { + return nil, err + } + + var stateBlockNIDs []types.StateBlockNID + for _, list := range stateBlockNIDLists { + stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) + } + // Fetch the state entries that will be combined to create the snapshots. + // Deduplicate the IDs before passing them to the database. + // There could be duplicates because a block of state entries could be reused by + // multiple snapshots. + stateEntryLists, err := db.StateEntries(uniqueStateBlockNIDs(stateBlockNIDs)) + if err != nil { + return nil, err + } + stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine the entries from all the snapshots of state after each prev event into a single list. + var combined []types.StateEntry + for _, prevState := range prevStates { + // Grab the list of state data NIDs for this snapshot. + stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) + } + + // Combined all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. 
+				// It should be impossible for an event to reference a NID that doesn't exist
+				panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
+			}
+			fullState = append(fullState, entries...)
+		}
+		if prevState.IsStateEvent() {
+			// If the prev event was a state event then add an entry for the event itself
+			// so that we get the state after the event rather than the state before.
+			fullState = append(fullState, prevState.StateEntry)
+		}
+
+		// Stable sort so that the most recent entry for each state key
+		// remains later in the list than the older entries for the same state key.
+		sort.Stable(stateEntryByStateKeySorter(fullState))
+		// Unique returns the last entry and hence the most recent entry for each state key.
+		fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
+		// Add the full state for this StateSnapshotNID.
+		combined = append(combined, fullState...)
+	}
+	return combined, nil
+}
+
+// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
+//
+// If both snapshot NIDs are equal it returns (nil, nil, nil) without touching the database.
+// A snapshot NID of 0 is never loaded and is treated as an empty list of entries.
+// Any error from LoadStateAtSnapshot is returned with both slices nil.
+//
+// On success removed holds the entries that appear only in the old snapshot and
+// added holds the entries that appear only in the new snapshot.
+// The comparison is a single merge pass over the two lists, so it relies on both lists
+// being ordered consistently with StateEntry.LessThan. The comments in the loop state
+// that the lists are sorted; the sorting itself is not visible in this function.
+func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID) (
+	removed, added []types.StateEntry, err error,
+) {
+	if oldStateNID == newStateNID {
+		// If the snapshot NIDs are the same then nothing has changed
+		return nil, nil, nil
+	}
+
+	// Both lists start out nil, which is what a snapshot NID of 0 leaves them as.
+	var oldEntries []types.StateEntry
+	var newEntries []types.StateEntry
+	if oldStateNID != 0 {
+		oldEntries, err = LoadStateAtSnapshot(db, oldStateNID)
+		if err != nil {
+			return nil, nil, err
+		}
+	}
+	if newStateNID != 0 {
+		newEntries, err = LoadStateAtSnapshot(db, newStateNID)
+		if err != nil {
+			return nil, nil, err
+		}
+	}
+
+	// Walk both lists in step; oldI and newI are the read positions in each list.
+	var oldI int
+	var newI int
+	for {
+		switch {
+		case oldI == len(oldEntries):
+			// We've reached the end of the old entries.
+			// The rest of the new list must have been newly added.
+			added = append(added, newEntries[newI:]...)
+			// Naked return of the named results; err is nil at this point.
+			return
+		case newI == len(newEntries):
+			// We've reached the end of the new entries.
+			// The rest of the old list must have been removed.
+			removed = append(removed, oldEntries[oldI:]...)
+			// Naked return of the named results; err is nil at this point.
+			return
+		case oldEntries[oldI] == newEntries[newI]:
+			// The entry is in both lists so skip over it.
+			oldI++
+			newI++
+		case oldEntries[oldI].LessThan(newEntries[newI]):
+			// The lists are sorted so the old entry being less than the new entry means that it only appears in the old list.
+			removed = append(removed, oldEntries[oldI])
+			oldI++
+		default:
+			// Reaching the default case implies that the new entry is less than the old entry.
+			// Since the lists are sorted this means that it only appears in the new list.
+			added = append(added, newEntries[newI])
+			newI++
+		}
+	}
+}
+
+// stateBlockNIDListMap lets a slice of StateBlockNIDList be searched by StateSnapshotNID.
+type stateBlockNIDListMap []types.StateBlockNIDList
+
+// lookup returns the StateBlockNIDs of the element whose StateSnapshotNID equals stateNID.
+// ok is false, and stateBlockNIDs is nil, when there is no such element.
+// The search uses sort.Search (a binary search), which only gives the right answer when
+// the slice is in ascending StateSnapshotNID order.
+// NOTE(review): that ordering is assumed of whoever built the slice - confirm against the database layer.
+func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
+	list := []types.StateBlockNIDList(m)
+	// Index of the first element whose StateSnapshotNID is >= stateNID, or len(list) if there is none.
+	i := sort.Search(len(list), func(i int) bool {
+		return list[i].StateSnapshotNID >= stateNID
+	})
+	// sort.Search only finds a position, so check that the element there is an exact match.
+	if i < len(list) && list[i].StateSnapshotNID == stateNID {
+		ok = true
+		stateBlockNIDs = list[i].StateBlockNIDs
+	}
+	return
+}
+
+// stateEntryListMap lets a slice of StateEntryList be searched by StateBlockNID.
+type stateEntryListMap []types.StateEntryList
+
+// lookup returns the StateEntries of the element whose StateBlockNID equals stateBlockNID.
+// ok is false, and stateEntries is nil, when there is no such element.
+// The search uses sort.Search (a binary search), which only gives the right answer when
+// the slice is in ascending StateBlockNID order.
+// NOTE(review): that ordering is assumed of whoever built the slice - confirm against the database layer.
+func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
+	list := []types.StateEntryList(m)
+	// Index of the first element whose StateBlockNID is >= stateBlockNID, or len(list) if there is none.
+	i := sort.Search(len(list), func(i int) bool {
+		return list[i].StateBlockNID >= stateBlockNID
+	})
+	// sort.Search only finds a position, so check that the element there is an exact match.
+	if i < len(list) && list[i].StateBlockNID == stateBlockNID {
+		ok = true
+		stateEntries = list[i].StateEntries
+	}
+	return
+}
+
+// stateEntryByStateKeySorter implements sort.Interface over state entries.
+// Less compares the StateKeyTuple only, so entries that share a tuple are not ordered
+// relative to each other by this sorter.
+type stateEntryByStateKeySorter []types.StateEntry
+
+func (s stateEntryByStateKeySorter) Len() int { return len(s) }
+func (s stateEntryByStateKeySorter) Less(i, j int) bool {
+	return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
+}
+func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// stateNIDSorter implements sort.Interface, ordering StateSnapshotNIDs ascending.
+type stateNIDSorter []types.StateSnapshotNID
+
+func (s stateNIDSorter) Len() int           { return len(s) }
+func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
+func (s stateNIDSorter) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
+
+// uniqueStateSnapshotNIDs hands nids to util.SortAndUnique and truncates nids to the length it returns.
+// The stateNIDSorter conversion shares nids' backing array, so the caller's slice is
+// rearranged in place and the returned slice aliases it.
+// NOTE(review): util.SortAndUnique is defined in github.com/matrix-org/util; from its use here it is
+// assumed to sort, keep the distinct values at the front and return how many there are - confirm.
+func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID {
+	return nids[:util.SortAndUnique(stateNIDSorter(nids))]
+}
+
+// stateBlockNIDSorter implements sort.Interface, ordering StateBlockNIDs ascending.
+type stateBlockNIDSorter []types.StateBlockNID
+
+func (s stateBlockNIDSorter) Len() int           { return len(s) }
+func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
+func (s stateBlockNIDSorter) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
+
+// uniqueStateBlockNIDs hands nids to util.SortAndUnique and truncates nids to the length it returns.
+// The stateBlockNIDSorter conversion shares nids' backing array, so the caller's slice is
+// rearranged in place and the returned slice aliases it.
+// NOTE(review): util.SortAndUnique is defined in github.com/matrix-org/util; from its use here it is
+// assumed to sort, keep the distinct values at the front and return how many there are - confirm.
+func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID {
+	return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))]
+}