From 39264cbf4b6e4008ab34a85335871b7e1a7c8631 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 15 Feb 2017 11:05:45 +0000 Subject: [PATCH] Calculate and store the state at each event (#6) * Calculate and store the state at each event * Use type aliases for numeric IDs --- .../matrix-org/dendrite/roomserver/README.md | 59 +++ .../dendrite/roomserver/input/authevents.go | 18 +- .../roomserver/input/authevents_test.go | 17 +- .../dendrite/roomserver/input/events.go | 49 ++- .../dendrite/roomserver/input/state.go | 305 +++++++++++++++ .../dendrite/roomserver/input/state_test.go | 67 ++++ .../dendrite/roomserver/storage/sql.go | 364 ++++++++++++++++-- .../dendrite/roomserver/storage/storage.go | 82 +++- .../dendrite/roomserver/types/types.go | 52 ++- 9 files changed, 935 insertions(+), 78 deletions(-) create mode 100644 src/github.com/matrix-org/dendrite/roomserver/README.md create mode 100644 src/github.com/matrix-org/dendrite/roomserver/input/state.go create mode 100644 src/github.com/matrix-org/dendrite/roomserver/input/state_test.go diff --git a/src/github.com/matrix-org/dendrite/roomserver/README.md b/src/github.com/matrix-org/dendrite/roomserver/README.md new file mode 100644 index 000000000..5a2757603 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/roomserver/README.md @@ -0,0 +1,59 @@ +# RoomServer + + +## RoomServer Internals + +### Numeric IDs + +To save space matrix string identifiers are mapped to local numeric IDs. +The numeric IDs are more efficient to manipulate and use less space to store. +The numeric IDs are never exposed in the API the room server exposes. +The numeric IDs are converted to string IDs before they leave the room server. +The numeric ID for a string ID is never 0 to avoid being confused with go's +default zero value. +Zero is used to indicate that there was no corresponding string ID. +Well-known event types and event state keys are preassigned numeric IDs. + +### State Snapshot Storage + +The room server stores the state of the matrix room at each event. +For efficiency the state is stored as blocks of 3-tuples of numeric IDs for the +event type, event state key and event ID. For further efficiency the state +snapshots are stored as the combination of up to 64 these blocks. This allows +blocks of the room state to be reused in multiple snapshots. + +The resulting database tables look something like this: + + +-------------------------------------------------------------------+ + | Events | + +---------+-------------------+------------------+------------------+ + | EventNID| EventTypeNID | EventStateKeyNID | StateSnapshotNID | + +---------+-------------------+------------------+------------------+ + | 1 | m.room.create 1 | "" 1 | 0 | + | 2 | m.room.member 2 | "@user:foo" 2 | 0 | + | 3 | m.room.member 2 | "@user:bar" 3 | {1,2} 1 | + | 4 | m.room.message 3 | 0 | {1,2,3} 2 | + | 5 | m.room.member 2 | "@user:foo" 2 | {1,2,3} 2 | + | 6 | m.room.message 3 | 0 | {1,3,6} 3 | + +---------+-------------------+------------------+------------------+ + + +----------------------------------------+ + | State Snapshots | + +-----------------------+----------------+ + | EventStateSnapshotNID | StateBlockNIDs | + +-----------------------+----------------| + | 1 | {1} | + | 2 | {1,2} | + | 3 | {1,2,3} | + +-----------------------+----------------+ + + +-----------------------------------------------------------------+ + | State Blocks | + +---------------+-------------------+------------------+----------+ + | StateBlockNID | EventTypeNID | EventStateKeyNID | EventNID | + +---------------+-------------------+------------------+----------+ + | 1 | m.room.create 1 | "" 1 | 1 | + | 1 | m.room.member 2 | "@user:foo" 2 | 2 | + | 2 | m.room.member 2 | "@user:bar" 3 | 3 | + | 3 | m.room.member 2 | "@user:foo" 2 | 6 | + +---------------+-------------------+------------------+----------+ diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go b/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go index bb15750b7..7dcaca915 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/authevents.go @@ -8,7 +8,7 @@ import ( // checkAuthEvents checks that the event passes authentication checks // Returns the numeric IDs for the auth events. -func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]int64, error) { +func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]types.EventNID, error) { // Grab the numeric IDs for the supplied auth state events from the database. authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs) if err != nil { @@ -31,7 +31,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv } // Return the numeric IDs for the auth events. - result := make([]int64, len(authStateEntries)) + result := make([]types.EventNID, len(authStateEntries)) for i := range authStateEntries { result[i] = authStateEntries[i].EventNID } @@ -39,7 +39,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv } type authEvents struct { - stateKeyNIDMap map[string]int64 + stateKeyNIDMap map[string]types.EventStateKeyNID state stateEntryMap events eventMap } @@ -69,7 +69,7 @@ func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Even return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil } -func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserverlib.Event { +func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event { eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID}) if !ok { return nil @@ -81,7 +81,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserve return &event.Event } -func (ae *authEvents) lookupEvent(typeNID int64, stateKey string) *gomatrixserverlib.Event { +func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event { stateKeyNID, ok := ae.stateKeyNIDMap[stateKey] if !ok { return nil @@ -113,7 +113,7 @@ func loadAuthEvents( // Load the events we need. result.state = state - var eventNIDs []int64 + var eventNIDs []types.EventNID keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed) for _, keyTuple := range keyTuplesNeeded { eventNID, ok := result.state.lookup(keyTuple) @@ -128,7 +128,7 @@ func loadAuthEvents( } // stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. -func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { +func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { var keyTuples []types.StateKeyTuple if stateNeeded.Create { keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID}) @@ -159,7 +159,7 @@ func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixs type stateEntryMap []types.StateEntry // lookup an entry in the event map. -func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok bool) { +func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { // Since the list is sorted we can implement this using binary search. // This is faster than using a hash map. // We don't have to worry about pathological cases because the keys are fixed @@ -180,7 +180,7 @@ func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok type eventMap []types.Event // lookup an entry in the event map. -func (m eventMap) lookup(eventNID int64) (event *types.Event, ok bool) { +func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { // Since the list is sorted we can implement this using binary search. // This is faster than using a hash map. // We don't have to worry about pathological cases because the keys are fixed diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/authevents_test.go b/src/github.com/matrix-org/dendrite/roomserver/input/authevents_test.go index aba1de092..69be65d78 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/authevents_test.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/authevents_test.go @@ -8,13 +8,18 @@ import ( func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) { var list []types.StateEntry for i := int64(0); i < entries; i++ { - list = append(list, types.StateEntry{types.StateKeyTuple{i, i}, i}) + list = append(list, types.StateEntry{types.StateKeyTuple{ + types.EventTypeNID(i), + types.EventStateKeyNID(i), + }, types.EventNID(i)}) } for i := 0; i < b.N; i++ { entryMap := stateEntryMap(list) for j := int64(0); j < lookups; j++ { - entryMap.lookup(types.StateKeyTuple{j, j}) + entryMap.lookup(types.StateKeyTuple{ + types.EventTypeNID(j), types.EventStateKeyNID(j), + }) } } } @@ -43,10 +48,10 @@ func TestStateEntryMap(t *testing.T) { }) testCases := []struct { - inputTypeNID int64 - inputStateKey int64 + inputTypeNID types.EventTypeNID + inputStateKey types.EventStateKeyNID wantOK bool - wantEventNID int64 + wantEventNID types.EventNID }{ // Check that tuples that in the array are in the map. {1, 1, true, 1}, @@ -80,7 +85,7 @@ func TestEventMap(t *testing.T) { }) testCases := []struct { - inputEventNID int64 + inputEventNID types.EventNID wantOK bool wantEvent *types.Event }{ diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index 8aaf9f841..49328edf4 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -9,18 +9,31 @@ import ( // A RoomEventDatabase has the storage APIs needed to store a room event. type RoomEventDatabase interface { // Stores a matrix room event in the database - StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error + StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) // Lookup the state entries for a list of string event IDs - // Returns a sorted list of state entries. - // Returns a error if the there is an error talking to the database + // Returns an error if the there is an error talking to the database // or if the event IDs aren't in the database. StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) // Lookup the numeric IDs for a list of string event state keys. // Returns a map from string state key to numeric ID for the state key. - EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) + EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) // Lookup the Events for a list of numeric event IDs. // Returns a sorted list of events. - Events(eventNIDs []int64) ([]types.Event, error) + Events(eventNIDs []types.EventNID) ([]types.Event, error) + // Lookup the state of a room at each event for a list of string event IDs. + // Returns an error if there is an error talking to the database + // or if the room state for the event IDs aren't in the database + StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) + // Lookup the numeric state data IDs for each numeric state snapshot ID + // The returned slice is sorted by numeric state snapshot ID. + StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + // Lookup the state data for each numeric state data ID + // The returned slice is sorted by numeric state data ID. + StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + // Store the room state at an event in the database + AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + // Set the state at an event. + SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error } func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { @@ -37,7 +50,8 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { } // Store the event - if err := db.StoreEvent(event, authEventNIDs); err != nil { + roomNID, stateAtEvent, err := db.StoreEvent(event, authEventNIDs) + if err != nil { return err } @@ -48,6 +62,29 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { return nil } + if stateAtEvent.BeforeStateSnapshotNID == 0 { + // We haven't calculated a state for this event yet. + // Lets calculate one. + if input.StateEventIDs != nil { + // We've been told what the state at the event is so we don't need to calculate it. + // Check that those state events are in the database and store the state. + entries, err := db.StateEntriesForEventIDs(input.StateEventIDs) + if err != nil { + return err + } + + if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(roomNID, nil, entries); err != nil { + return nil + } + } else { + // We haven't been told what the state at the event is so we need to calculate it from the prev_events + if stateAtEvent.BeforeStateSnapshotNID, err = calculateAndStoreState(db, event, roomNID); err != nil { + return err + } + } + db.SetState(stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + } + // TODO: // * Calcuate the state at the event if necessary. // * Store the state at the event. diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/state.go b/src/github.com/matrix-org/dendrite/roomserver/input/state.go new file mode 100644 index 000000000..c46dc6e14 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/roomserver/input/state.go @@ -0,0 +1,305 @@ +package input + +import ( + "fmt" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "sort" +) + +// calculateAndStoreState calculates a snapshot of the state of a room before an event. +// Stores the snapshot of the state in the database. +// Returns a numeric ID for that snapshot. +func calculateAndStoreState( + db RoomEventDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID, +) (types.StateSnapshotNID, error) { + // Load the state at the prev events. + prevEventRefs := event.PrevEvents() + prevEventIDs := make([]string, len(prevEventRefs)) + for i := range prevEventRefs { + prevEventIDs[i] = prevEventRefs[i].EventID + } + + prevStates, err := db.StateAtEventIDs(prevEventIDs) + if err != nil { + return 0, err + } + + if len(prevStates) == 0 { + // 2) There weren't any prev_events for this event so the state is + // empty. + return db.AddState(roomNID, nil, nil) + } + + if len(prevStates) == 1 { + prevState := prevStates[0] + if prevState.EventStateKeyNID == 0 { + // 3) None of the previous events were state events and they all + // have the same state, so this event has exactly the same state + // as the previous events. + // This should be the common case. + return prevState.BeforeStateSnapshotNID, nil + } + // The previous event was a state event so we need to store a copy + // of the previous state updated with that event. + stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}) + if err != nil { + return 0, err + } + stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs + if len(stateBlockNIDs) < maxStateBlockNIDs { + // 4) The number of state data blocks is small enough that we can just + // add the state event as a block of size one to the end of the blocks. + return db.AddState( + roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + ) + } + // If there are too many deltas then we need to calculate the full state + // So fall through to calculateAndStoreStateMany + } + return calculateAndStoreStateMany(db, roomNID, prevStates) +} + +// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. +// Increasing this number means that we can encode more of the state changes as simple deltas which means that +// we need fewer entries in the state data table. However making this number bigger will increase the size of +// the rows in the state table itself and will require more index lookups when retrieving a snapshot. +// TODO: Tune this to get the right balance between size and lookup performance. +const maxStateBlockNIDs = 64 + +// calculateAndStoreStateMany calculates the state of the room before an event +// using the states at each of the event's prev events. +// Stores the resulting state and returns a numeric ID for the snapshot. +func calculateAndStoreStateMany(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) { + // Conflict resolution. + // First stage: load the state after each of the prev events. + combined, err := loadCombinedStateAfterEvents(db, prevStates) + if err != nil { + return 0, err + } + + // Collect all the entries with the same type and key together. + // We don't care about the order here because the conflict resolution + // algorithm doesn't depend on the order of the prev events. + sort.Sort(stateEntrySorter(combined)) + // Remove duplicate entires. + combined = combined[:unique(stateEntrySorter(combined))] + + // Find the conflicts + conflicts := findDuplicateStateKeys(combined) + + var state []types.StateEntry + if len(conflicts) > 0 { + // 5) There are conflicting state events, for each conflict workout + // what the appropriate state event is. + resolved, err := resolveConflicts(db, combined, conflicts) + if err != nil { + return 0, err + } + state = resolved + } else { + // 6) There weren't any conflicts + state = combined + } + + // TODO: Check if we can encode the new state as a delta against the + // previous state. + return db.AddState(roomNID, nil, state) +} + +// loadCombinedStateAfterEvents loads a snapshot of the state after each of the events +// and combines those snapshots together into a single list. +func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) { + stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) + for i, state := range prevStates { + stateNIDs[i] = state.BeforeStateSnapshotNID + } + // Fetch the state snapshots for the state before the each prev event from the database. + // Deduplicate the IDs before passing them to the database. + // There could be duplicates because the events could be state events where + // the snapshot of the room state before them was the same. + stateBlockNIDLists, err := db.StateBlockNIDs(uniqueStateSnapshotNIDs(stateNIDs)) + if err != nil { + return nil, err + } + + var stateBlockNIDs []types.StateBlockNID + for _, list := range stateBlockNIDLists { + stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) + } + // Fetch the state entries that will be combined to create the snapshots. + // Deduplicate the IDs before passing them to the database. + // There could be duplicates because a block of state entries could be reused by + // multiple snapshots. + stateEntryLists, err := db.StateEntries(uniqueStateBlockNIDs(stateBlockNIDs)) + if err != nil { + return nil, err + } + stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine the entries from all the snapshots of state after each prev event into a single list. + var combined []types.StateEntry + for _, prevState := range prevStates { + // Grab the list of state data NIDs for this snapshot. + stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateSnapshotNID)) + } + + // Combined all the state entries for this snapshot. + // The order of state data NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateSnapshotNID)) + } + fullState = append(fullState, entries...) + } + if prevState.IsStateEvent() { + // If the prev event was a state event then add an entry for the event itself + // so that we get the state after the event rather than the state before. + fullState = append(fullState, prevState.StateEntry) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:unique(stateEntryByStateKeySorter(fullState))] + // Add the full state for this StateSnapshotNID. + combined = append(combined, fullState...) + } + return combined, nil +} + +func resolveConflicts(db RoomEventDatabase, combined, conflicted []types.StateEntry) ([]types.StateEntry, error) { + panic(fmt.Errorf("Not implemented")) +} + +// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. +// Returns a sorted list of those state entries. +func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { + var result []types.StateEntry + // j is the starting index of a block of entries with the same state key tuple. + j := 0 + for i := 1; i < len(a); i++ { + // Check if the state key tuple matches the start of the block + if a[j].StateKeyTuple != a[i].StateKeyTuple { + // If the state key tuple is different then we've reached the end of a block of duplicates. + // Check if the size of the block is bigger than one. + // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result + if j+1 != i { + // Add the block to the result. + result = append(result, a[j:i]...) + } + // Start a new block for the next state key tuple. + j = i + } + } + // Check if the last block with the same state key tuple had more than one event in it. + if j+1 != len(a) { + result = append(result, a[j:]...) + } + return result +} + +type stateBlockNIDListMap []types.StateBlockNIDList + +func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { + list := []types.StateBlockNIDList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateSnapshotNID >= stateNID + }) + if i < len(list) && list[i].StateSnapshotNID == stateNID { + ok = true + stateBlockNIDs = list[i].StateBlockNIDs + } + return +} + +type stateEntryListMap []types.StateEntryList + +func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { + list := []types.StateEntryList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateBlockNID >= stateBlockNID + }) + if i < len(list) && list[i].StateBlockNID == stateBlockNID { + ok = true + stateEntries = list[i].StateEntries + } + return +} + +type stateEntryByStateKeySorter []types.StateEntry + +func (s stateEntryByStateKeySorter) Len() int { return len(s) } +func (s stateEntryByStateKeySorter) Less(i, j int) bool { + return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) +} +func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stateEntrySorter []types.StateEntry + +func (s stateEntrySorter) Len() int { return len(s) } +func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stateNIDSorter []types.StateSnapshotNID + +func (s stateNIDSorter) Len() int { return len(s) } +func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { + sort.Sort(stateNIDSorter(nids)) + return nids[:unique(stateNIDSorter(nids))] +} + +type stateBlockNIDSorter []types.StateBlockNID + +func (s stateBlockNIDSorter) Len() int { return len(s) } +func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { + sort.Sort(stateBlockNIDSorter(nids)) + return nids[:unique(stateBlockNIDSorter(nids))] +} + +// Remove duplicate items from a sorted list. +// Takes the same interface as sort.Sort +// Returns the length of the data without duplicates +// Uses the last occurance of a duplicate. +// O(n). +func unique(data sort.Interface) int { + if data.Len() == 0 { + return 0 + } + length := data.Len() + // j is the next index to output an element to. + j := 0 + for i := 1; i < length; i++ { + // If the previous element is less than this element then they are + // not equal. Otherwise they must be equal because the list is sorted. + // If they are equal then we move onto the next element. + if data.Less(i-1, i) { + // "Write" the previous element to the output position by swaping + // the elements. + // Note that if the list has no duplicates then i-1 == j so the + // swap does nothing. (This assumes that data.Swap(a,b) nops if a==b) + data.Swap(i-1, j) + // Advance to the next output position in the list. + j++ + } + } + // Output the last element. + data.Swap(length-1, j) + return j + 1 +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/state_test.go b/src/github.com/matrix-org/dendrite/roomserver/input/state_test.go new file mode 100644 index 000000000..e5707ff1a --- /dev/null +++ b/src/github.com/matrix-org/dendrite/roomserver/input/state_test.go @@ -0,0 +1,67 @@ +package input + +import ( + "github.com/matrix-org/dendrite/roomserver/types" + "testing" +) + +type sortBytes []byte + +func (s sortBytes) Len() int { return len(s) } +func (s sortBytes) Less(i, j int) bool { return s[i] < s[j] } +func (s sortBytes) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func TestUnique(t *testing.T) { + testCases := []struct { + Input string + Want string + }{ + {"", ""}, + {"abc", "abc"}, + {"aaabbbccc", "abc"}, + } + + for _, test := range testCases { + input := []byte(test.Input) + want := string(test.Want) + got := string(input[:unique(sortBytes(input))]) + if got != want { + t.Fatal("Wanted ", want, " got ", got) + } + } +} + +func TestFindDuplicateStateKeys(t *testing.T) { + testCases := []struct { + Input []types.StateEntry + Want []types.StateEntry + }{{ + Input: []types.StateEntry{ + {types.StateKeyTuple{1, 1}, 1}, + {types.StateKeyTuple{1, 1}, 2}, + {types.StateKeyTuple{2, 2}, 3}, + }, + Want: []types.StateEntry{ + {types.StateKeyTuple{1, 1}, 1}, + {types.StateKeyTuple{1, 1}, 2}, + }, + }, { + Input: []types.StateEntry{ + {types.StateKeyTuple{1, 1}, 1}, + {types.StateKeyTuple{1, 2}, 2}, + }, + Want: nil, + }} + + for _, test := range testCases { + got := findDuplicateStateKeys(test.Input) + if len(got) != len(test.Want) { + t.Fatalf("Wanted %v, got %v", test.Want, got) + } + for i := range got { + if got[i] != test.Want[i] { + t.Fatalf("Wanted %v, got %v", test.Want, got) + } + } + } +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go index b373f8309..6b23f1e36 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go @@ -19,8 +19,15 @@ type statements struct { selectRoomNIDStmt *sql.Stmt insertEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt insertEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt *sql.Stmt + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt + insertStateDataStmt *sql.Stmt + selectNextStateBlockNIDStmt *sql.Stmt + bulkSelectStateDataEntriesStmt *sql.Stmt } func (s *statements) prepare(db *sql.DB) error { @@ -180,14 +187,16 @@ const insertEventTypeNIDSQL = "" + const selectEventTypeNIDSQL = "" + "SELECT event_type_nid FROM event_types WHERE event_type = $1" -func (s *statements) insertEventTypeNID(eventType string) (eventTypeNID int64, err error) { - err = s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) - return +func (s *statements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) { + var eventTypeNID int64 + err := s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err } -func (s *statements) selectEventTypeNID(eventType string) (eventTypeNID int64, err error) { - err = s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) - return +func (s *statements) selectEventTypeNID(eventType string) (types.EventTypeNID, error) { + var eventTypeNID int64 + err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err } func (s *statements) prepareEventStateKeys(db *sql.DB) (err error) { @@ -244,31 +253,33 @@ const bulkSelectEventStateKeyNIDSQL = "" + "SELECT event_state_key, event_state_key_nid FROM event_state_keys" + " WHERE event_state_key = ANY($1)" -func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) { - err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) - return +func (s *statements) insertEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + err := s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *statements) selectEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) { - err = s.selectEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) - return +func (s *statements) selectEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + err := s.selectEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err } -func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]int64, error) { +func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) { rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys)) if err != nil { return nil, err } defer rows.Close() - result := make(map[string]int64, len(eventStateKeys)) + result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) for rows.Next() { var stateKey string var stateKeyNID int64 if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { return nil, err } - result[stateKey] = stateKeyNID + result[stateKey] = types.EventStateKeyNID(stateKeyNID) } return result, nil } @@ -307,14 +318,16 @@ const insertRoomNIDSQL = "" + const selectRoomNIDSQL = "" + "SELECT room_nid FROM rooms WHERE room_id = $1" -func (s *statements) insertRoomNID(roomID string) (roomNID int64, err error) { - err = s.insertRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) - return +func (s *statements) insertRoomNID(roomID string) (types.RoomNID, error) { + var roomNID int64 + err := s.insertRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err } -func (s *statements) selectRoomNID(roomID string) (roomNID int64, err error) { - err = s.selectRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) - return +func (s *statements) selectRoomNID(roomID string) (types.RoomNID, error) { + var roomNID int64 + err := s.selectRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err } const eventsSchema = ` @@ -333,6 +346,13 @@ CREATE TABLE IF NOT EXISTS events ( -- Local numeric ID for the state_key of the event -- This is 0 if the event is not a state event. event_state_key_nid BIGINT NOT NULL, + -- Local numeric ID for the state at the event. + -- This is 0 if we don't know the state at the event. + -- If the state is not 0 then this event is part of the contiguous + -- part of the event graph + -- Since many different events can have the same state we store the + -- state into a separate state table and refer to it by numeric ID. + state_snapshot_nid bigint NOT NULL DEFAULT 0, -- The textual event id. -- Used to lookup the numeric ID when processing requests. -- Needed for state resolution. @@ -342,7 +362,7 @@ CREATE TABLE IF NOT EXISTS events ( -- Needed for setting reference hashes when sending new events. reference_sha256 BYTEA NOT NULL, -- A list of numeric IDs for events that can authenticate this event. - auth_event_nids BIGINT[] NOT NULL, + auth_event_nids BIGINT[] NOT NULL ); ` @@ -351,7 +371,7 @@ const insertEventSQL = "" + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT ON CONSTRAINT event_id_unique" + " DO UPDATE SET event_id = $1" + - " RETURNING event_nid" + " RETURNING event_nid, state_snapshot_nid" // Bulk lookup of events by string ID. // Sort by the numeric IDs for event type and state key. @@ -361,6 +381,13 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id = ANY($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +const bulkSelectStateAtEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM events" + + " WHERE event_id = ANY($1)" + +const updateEventStateSQL = "" + + "UPDATE events SET state_snapshot_nid = $2 WHERE event_nid = $1" + func (s *statements) prepareEvents(db *sql.DB) (err error) { _, err = db.Exec(eventsSchema) if err != nil { @@ -372,20 +399,32 @@ func (s *statements) prepareEvents(db *sql.DB) (err error) { if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil { return } + if s.bulkSelectStateAtEventByIDStmt, err = db.Prepare(bulkSelectStateAtEventByIDSQL); err != nil { + return + } + if s.updateEventStateStmt, err = db.Prepare(updateEventStateSQL); err != nil { + return + } return } func (s *statements) insertEvent( - roomNID, eventTypeNID, eventStateKeyNID int64, + roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, eventID string, referenceSHA256 []byte, - authEventNIDs []int64, -) (eventNID int64, err error) { - err = s.insertEventStmt.QueryRow( - roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256, - pq.Int64Array(authEventNIDs), - ).Scan(&eventNID) - return + authEventNIDs []types.EventNID, +) (types.EventNID, types.StateSnapshotNID, error) { + nids := make([]int64, len(authEventNIDs)) + for i := range authEventNIDs { + nids[i] = int64(authEventNIDs[i]) + } + var eventNID int64 + var stateNID int64 + err := s.insertEventStmt.QueryRow( + int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, + pq.Int64Array(nids), + ).Scan(&eventNID, &stateNID) + return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) { @@ -421,6 +460,39 @@ func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateE return results, err } +func (s *statements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.StateAtEvent, error) { + rows, err := s.bulkSelectStateAtEventByIDStmt.Query(pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() + results := make([]types.StateAtEvent, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventNID, + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.BeforeStateSnapshotNID, + ); err != nil { + return nil, err + } + if result.BeforeStateSnapshotNID == 0 { + return nil, fmt.Errorf("storage: missing state for event NID %d", result.EventNID) + } + } + if i != len(eventIDs) { + return nil, fmt.Errorf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)) + } + return results, err +} + +func (s *statements) updateEventState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error { + _, err := s.updateEventStateStmt.Exec(int64(eventNID), int64(stateNID)) + return err +} + func (s *statements) prepareEventJSON(db *sql.DB) (err error) { _, err = db.Exec(eventJSONSchema) if err != nil { @@ -464,18 +536,22 @@ const bulkSelectEventJSONSQL = "" + " WHERE event_nid = ANY($1)" + " ORDER BY event_nid ASC" -func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error { - _, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON) +func (s *statements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error { + _, err := s.insertEventJSONStmt.Exec(int64(eventNID), eventJSON) return err } type eventJSONPair struct { - EventNID int64 + EventNID types.EventNID EventJSON []byte } -func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, error) { - rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(eventNIDs)) +func (s *statements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) { + nids := make([]int64, len(eventNIDs)) + for i := range eventNIDs { + nids[i] = int64(eventNIDs[i]) + } + rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(nids)) if err != nil { return nil, err } @@ -488,9 +564,223 @@ func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, er results := make([]eventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { - if err := rows.Scan(&results[i].EventNID, &results[i].EventJSON); err != nil { + result := &results[i] + var eventNID int64 + if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { return nil, err } + result.EventNID = types.EventNID(eventNID) } return results[:i], nil } + +const stateSchema = ` +-- The state of a room before an event. +-- Stored as a list of state_block entries stored in a separate table. +-- The actual state is constructed by combining all the state_block entries +-- referenced by state_block_nids together. If the same state key tuple appears +-- multiple times then the entry from the later state_block clobbers the earlier +-- entries. +-- This encoding format allows us to implement a delta encoding which is useful +-- because room state tends to accumulate small changes over time. Although if +-- the list of deltas becomes too long it becomes more efficient to encode +-- the full state under single state_block_nid. +CREATE SEQUENCE IF NOT EXISTS state_snapshot_nid_seq; +CREATE TABLE IF NOT EXISTS state_snapshots ( + -- Local numeric ID for the state. + state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('state_snapshot_nid_seq'), + -- Local numeric ID of the room this state is for. + -- Unused in normal operation, but useful for background work or ad-hoc debugging. + room_nid bigint NOT NULL, + -- List of state_block_nids, stored sorted by state_block_nid. + state_block_nids bigint[] NOT NULL +); +` + +const insertStateSQL = "" + + "INSERT INTO state_snapshots (room_nid, state_block_nids)" + + " VALUES ($1, $2)" + + " RETURNING state_snapshot_nid" + +// Bulk state data NID lookup. +// Sorting by state_snapshot_nid means we can use binary search over the result +// to lookup the state data NIDs for a state snapshot NID. +const bulkSelectStateBlockNIDsSQL = "" + + "SELECT state_snapshot_nid, state_block_nids FROM state_snapshots" + + " WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC" + +func (s *statements) prepareState(db *sql.DB) (err error) { + _, err = db.Exec(stateSchema) + if err != nil { + return + } + if s.insertStateStmt, err = db.Prepare(insertStateSQL); err != nil { + return + } + if s.bulkSelectStateBlockNIDsStmt, err = db.Prepare(bulkSelectStateBlockNIDsSQL); err != nil { + return + } + return +} + +func (s *statements) insertState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + err = s.insertStateStmt.QueryRow(int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + return +} + +func (s *statements) bulkSelectStateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) { + nids := make([]int64, len(stateNIDs)) + for i := range stateNIDs { + nids[i] = int64(stateNIDs[i]) + } + rows, err := s.bulkSelectStateBlockNIDsStmt.Query(pq.Int64Array(nids)) + if err != nil { + return nil, err + } + defer rows.Close() + results := make([]types.StateBlockNIDList, len(stateNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var stateBlockNIDs pq.Int64Array + if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { + return nil, err + } + result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs)) + for k := range stateBlockNIDs { + result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k]) + } + } + if i != len(stateNIDs) { + return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) + } + return results, nil +} + +const stateDataSchema = ` +-- The state data map. +-- Designed to give enough information to run the state resolution algorithm +-- without hitting the database in the common case. +-- TODO: Is it worth replacing the unique btree index with a covering index so +-- that postgres could lookup the state using an index-only scan? +-- The type and state_key are included in the index to make it easier to +-- lookup a specific (type, state_key) pair for an event. It also makes it easy +-- to read the state for a given state_block_nid ordered by (type, state_key) +-- which in turn makes it easier to merge state data blocks. +CREATE SEQUENCE IF NOT EXISTS state_block_nid_seq; +CREATE TABLE IF NOT EXISTS state_block ( + -- Local numeric ID for this state data. + state_block_nid bigint NOT NULL, + event_type_nid bigint NOT NULL, + event_state_key_nid bigint NOT NULL, + event_nid bigint NOT NULL, + UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) +); +` + +const insertStateDataSQL = "" + + "INSERT INTO state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + + " VALUES ($1, $2, $3, $4)" + +const selectNextStateBlockNIDSQL = "" + + "SELECT nextval('state_block_nid_seq')" + +// Bulk state lookup by numeric event ID. +// Sort by the state_block_nid, event_type_nid, event_state_key_nid +// This means that all the entries for a given state_block_nid will appear +// together in the list and those entries will sorted by event_type_nid +// and event_state_key_nid. This property makes it easier to merge two +// state data blocks together. +const bulkSelectStateDataEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM state_block WHERE state_block_nid = ANY($1)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + +func (s *statements) prepareStateData(db *sql.DB) (err error) { + _, err = db.Exec(stateDataSchema) + if err != nil { + return + } + if s.insertStateDataStmt, err = db.Prepare(insertStateDataSQL); err != nil { + return + } + if s.selectNextStateBlockNIDStmt, err = db.Prepare(selectNextStateBlockNIDSQL); err != nil { + return + } + + if s.bulkSelectStateDataEntriesStmt, err = db.Prepare(bulkSelectStateDataEntriesSQL); err != nil { + return + } + return +} + +func (s *statements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error { + for _, entry := range entries { + _, err := s.insertStateDataStmt.Exec( + int64(stateBlockNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), + ) + if err != nil { + return err + } + } + return nil +} + +func (s *statements) selectNextStateBlockNID() (types.StateBlockNID, error) { + var stateBlockNID int64 + err := s.selectNextStateBlockNIDStmt.QueryRow().Scan(&stateBlockNID) + return types.StateBlockNID(stateBlockNID), err +} + +func (s *statements) bulkSelectStateDataEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(nids)) + if err != nil { + return nil, err + } + defer rows.Close() + + results := make([]types.StateEntryList, len(stateBlockNIDs)) + // current is a pointer to the StateEntryList to append the state entries to. + var current *types.StateEntryList + i := 0 + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we start appending to the next entry in the list. + current = &results[i] + current.StateBlockNID = types.StateBlockNID(stateBlockNID) + i++ + } + current.StateEntries = append(current.StateEntries, entry) + } + if i != len(stateBlockNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) + } + return results, nil +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index d42c1f191..6162fcb7e 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -38,21 +38,22 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6 } // StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error { +func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) { var ( - roomNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateSnapshotNID err error ) if roomNID, err = d.assignRoomNID(event.RoomID()); err != nil { - return err + return 0, types.StateAtEvent{}, err } if eventTypeNID, err = d.assignEventTypeNID(event.Type()); err != nil { - return err + return 0, types.StateAtEvent{}, err } eventStateKey := event.StateKey() @@ -60,11 +61,11 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int // Otherwise set the numeric ID for the state_key to 0. if eventStateKey != nil { if eventStateKeyNID, err = d.assignStateKeyNID(*eventStateKey); err != nil { - return err + return 0, types.StateAtEvent{}, err } } - if eventNID, err = d.statements.insertEvent( + if eventNID, stateNID, err = d.statements.insertEvent( roomNID, eventTypeNID, eventStateKeyNID, @@ -72,13 +73,26 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int event.EventReference().EventSHA256, authEventNIDs, ); err != nil { - return err + return 0, types.StateAtEvent{}, err } - return d.statements.insertEventJSON(eventNID, event.JSON()) + if err = d.statements.insertEventJSON(eventNID, event.JSON()); err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateSnapshotNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil } -func (d *Database) assignRoomNID(roomID string) (int64, error) { +func (d *Database) assignRoomNID(roomID string) (types.RoomNID, error) { // Check if we already have a numeric ID in the database. roomNID, err := d.statements.selectRoomNID(roomID) if err == sql.ErrNoRows { @@ -91,7 +105,7 @@ func (d *Database) assignRoomNID(roomID string) (int64, error) { return roomNID, nil } -func (d *Database) assignEventTypeNID(eventType string) (int64, error) { +func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) { // Check if we already have a numeric ID in the database. eventTypeNID, err := d.statements.selectEventTypeNID(eventType) if err == sql.ErrNoRows { @@ -104,7 +118,7 @@ func (d *Database) assignEventTypeNID(eventType string) (int64, error) { return eventTypeNID, nil } -func (d *Database) assignStateKeyNID(eventStateKey string) (int64, error) { +func (d *Database) assignStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) { // Check if we already have a numeric ID in the database. eventStateKeyNID, err := d.statements.selectEventStateKeyNID(eventStateKey) if err == sql.ErrNoRows { @@ -123,12 +137,12 @@ func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntr } // EventStateKeyNIDs implements input.EventDatabase -func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) { +func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) { return d.statements.bulkSelectEventStateKeyNID(eventStateKeys) } // Events implements input.EventDatabase -func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) { +func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) { eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs) if err != nil { return nil, err @@ -145,3 +159,39 @@ func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) { } return results, nil } + +// AddState implements input.EventDatabase +func (d *Database) AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) { + if len(state) > 0 { + stateBlockNID, err := d.statements.selectNextStateBlockNID() + if err != nil { + return 0, err + } + if err = d.statements.bulkInsertStateData(stateBlockNID, state); err != nil { + return 0, err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + } + + return d.statements.insertState(roomNID, stateBlockNIDs) +} + +// SetState implements input.EventDatabase +func (d *Database) SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error { + return d.statements.updateEventState(eventNID, stateNID) +} + +// StateAtEventIDs implements input.EventDatabase +func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) { + return d.statements.bulkSelectStateAtEventByID(eventIDs) +} + +// StateBlockNIDs implements input.EventDatabase +func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) { + return d.statements.bulkSelectStateBlockNIDs(stateNIDs) +} + +// StateEntries implements input.EventDatabase +func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) { + return d.statements.bulkSelectStateDataEntries(stateBlockNIDs) +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/types/types.go b/src/github.com/matrix-org/dendrite/roomserver/types/types.go index 0c43baf89..000238495 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/types/types.go +++ b/src/github.com/matrix-org/dendrite/roomserver/types/types.go @@ -13,13 +13,32 @@ type PartitionOffset struct { Offset int64 } +// EventTypeNID is a numeric ID for an event type. +type EventTypeNID int64 + +// EventStateKeyNID is a numeric ID for an event state_key. +type EventStateKeyNID int64 + +// EventNID is a numeric ID for an event. +type EventNID int64 + +// RoomNID is a numeric ID for a room. +type RoomNID int64 + +// StateSnapshotNID is a numeric ID for the state at an event. +type StateSnapshotNID int64 + +// StateBlockNID is a numeric ID for a block of state data. +// These blocks of state data are combined to form the actual state. +type StateBlockNID int64 + // A StateKeyTuple is a pair of a numeric event type and a numeric state key. // It is used to lookup state entries. type StateKeyTuple struct { // The numeric ID for the event type. - EventTypeNID int64 + EventTypeNID EventTypeNID // The numeric ID for the state key. - EventStateKeyNID int64 + EventStateKeyNID EventStateKeyNID } // LessThan returns true if this state key is less than the other state key. @@ -35,7 +54,7 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool { type StateEntry struct { StateKeyTuple // The numeric ID for the event. - EventNID int64 + EventNID EventNID } // LessThan returns true if this state entry is less than the other state entry. @@ -47,10 +66,23 @@ func (a StateEntry) LessThan(b StateEntry) bool { return a.EventNID < b.EventNID } +// StateAtEvent is the state before and after a matrix event. +type StateAtEvent struct { + // The state before the event. + BeforeStateSnapshotNID StateSnapshotNID + // The state entry for the event itself, allows us to calculate the state after the event. + StateEntry +} + +// IsStateEvent returns whether the event the state is at is a state event. +func (s StateAtEvent) IsStateEvent() bool { + return s.EventStateKeyNID != 0 +} + // An Event is a gomatrixserverlib.Event with the numeric event ID attached. // It is when performing bulk event lookup in the database. type Event struct { - EventNID int64 + EventNID EventNID gomatrixserverlib.Event } @@ -75,3 +107,15 @@ const ( // EmptyStateKeyNID is the numeric ID for the empty state key. EmptyStateKeyNID = 1 ) + +// StateBlockNIDList is used to return the result of bulk StateBlockNID lookups from the database. +type StateBlockNIDList struct { + StateSnapshotNID StateSnapshotNID + StateBlockNIDs []StateBlockNID +} + +// StateEntryList is used to return the result of bulk state entry lookups from the database. +type StateEntryList struct { + StateBlockNID StateBlockNID + StateEntries []StateEntry +}