Calculate and store the state at each event (#6)

* Calculate and store the state at each event

* Use type aliases for numeric IDs
This commit is contained in:
Mark Haines 2017-02-15 11:05:45 +00:00 committed by GitHub
parent 832f604b94
commit 39264cbf4b
9 changed files with 935 additions and 78 deletions

View file

@ -0,0 +1,59 @@
# RoomServer
## RoomServer Internals
### Numeric IDs
To save space matrix string identifiers are mapped to local numeric IDs.
The numeric IDs are more efficient to manipulate and use less space to store.
The numeric IDs are never exposed in the API the room server exposes.
The numeric IDs are converted to string IDs before they leave the room server.
The numeric ID for a string ID is never 0 to avoid being confused with go's
default zero value.
Zero is used to indicate that there was no corresponding string ID.
Well-known event types and event state keys are preassigned numeric IDs.
### State Snapshot Storage
The room server stores the state of the matrix room at each event.
For efficiency the state is stored as blocks of 3-tuples of numeric IDs for the
event type, event state key and event ID. For further efficiency the state
snapshots are stored as the combination of up to 64 these blocks. This allows
blocks of the room state to be reused in multiple snapshots.
The resulting database tables look something like this:
+-------------------------------------------------------------------+
| Events |
+---------+-------------------+------------------+------------------+
| EventNID| EventTypeNID | EventStateKeyNID | StateSnapshotNID |
+---------+-------------------+------------------+------------------+
| 1 | m.room.create 1 | "" 1 | <nil> 0 |
| 2 | m.room.member 2 | "@user:foo" 2 | <nil> 0 |
| 3 | m.room.member 2 | "@user:bar" 3 | {1,2} 1 |
| 4 | m.room.message 3 | <nil> 0 | {1,2,3} 2 |
| 5 | m.room.member 2 | "@user:foo" 2 | {1,2,3} 2 |
| 6 | m.room.message 3 | <nil> 0 | {1,3,6} 3 |
+---------+-------------------+------------------+------------------+
+----------------------------------------+
| State Snapshots |
+-----------------------+----------------+
| EventStateSnapshotNID | StateBlockNIDs |
+-----------------------+----------------|
| 1 | {1} |
| 2 | {1,2} |
| 3 | {1,2,3} |
+-----------------------+----------------+
+-----------------------------------------------------------------+
| State Blocks |
+---------------+-------------------+------------------+----------+
| StateBlockNID | EventTypeNID | EventStateKeyNID | EventNID |
+---------------+-------------------+------------------+----------+
| 1 | m.room.create 1 | "" 1 | 1 |
| 1 | m.room.member 2 | "@user:foo" 2 | 2 |
| 2 | m.room.member 2 | "@user:bar" 3 | 3 |
| 3 | m.room.member 2 | "@user:foo" 2 | 6 |
+---------------+-------------------+------------------+----------+

View file

@ -8,7 +8,7 @@ import (
// checkAuthEvents checks that the event passes authentication checks // checkAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events. // Returns the numeric IDs for the auth events.
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]int64, error) { func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]types.EventNID, error) {
// Grab the numeric IDs for the supplied auth state events from the database. // Grab the numeric IDs for the supplied auth state events from the database.
authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs) authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs)
if err != nil { if err != nil {
@ -31,7 +31,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv
} }
// Return the numeric IDs for the auth events. // Return the numeric IDs for the auth events.
result := make([]int64, len(authStateEntries)) result := make([]types.EventNID, len(authStateEntries))
for i := range authStateEntries { for i := range authStateEntries {
result[i] = authStateEntries[i].EventNID result[i] = authStateEntries[i].EventNID
} }
@ -39,7 +39,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv
} }
type authEvents struct { type authEvents struct {
stateKeyNIDMap map[string]int64 stateKeyNIDMap map[string]types.EventStateKeyNID
state stateEntryMap state stateEntryMap
events eventMap events eventMap
} }
@ -69,7 +69,7 @@ func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Even
return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
} }
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserverlib.Event { func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID}) eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID})
if !ok { if !ok {
return nil return nil
@ -81,7 +81,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserve
return &event.Event return &event.Event
} }
func (ae *authEvents) lookupEvent(typeNID int64, stateKey string) *gomatrixserverlib.Event { func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event {
stateKeyNID, ok := ae.stateKeyNIDMap[stateKey] stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
if !ok { if !ok {
return nil return nil
@ -113,7 +113,7 @@ func loadAuthEvents(
// Load the events we need. // Load the events we need.
result.state = state result.state = state
var eventNIDs []int64 var eventNIDs []types.EventNID
keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed) keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
for _, keyTuple := range keyTuplesNeeded { for _, keyTuple := range keyTuplesNeeded {
eventNID, ok := result.state.lookup(keyTuple) eventNID, ok := result.state.lookup(keyTuple)
@ -128,7 +128,7 @@ func loadAuthEvents(
} }
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. // stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
var keyTuples []types.StateKeyTuple var keyTuples []types.StateKeyTuple
if stateNeeded.Create { if stateNeeded.Create {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID}) keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
@ -159,7 +159,7 @@ func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixs
type stateEntryMap []types.StateEntry type stateEntryMap []types.StateEntry
// lookup an entry in the event map. // lookup an entry in the event map.
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok bool) { func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
// Since the list is sorted we can implement this using binary search. // Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map. // This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed // We don't have to worry about pathological cases because the keys are fixed
@ -180,7 +180,7 @@ func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok
type eventMap []types.Event type eventMap []types.Event
// lookup an entry in the event map. // lookup an entry in the event map.
func (m eventMap) lookup(eventNID int64) (event *types.Event, ok bool) { func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
// Since the list is sorted we can implement this using binary search. // Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map. // This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed // We don't have to worry about pathological cases because the keys are fixed

View file

@ -8,13 +8,18 @@ import (
func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) { func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
var list []types.StateEntry var list []types.StateEntry
for i := int64(0); i < entries; i++ { for i := int64(0); i < entries; i++ {
list = append(list, types.StateEntry{types.StateKeyTuple{i, i}, i}) list = append(list, types.StateEntry{types.StateKeyTuple{
types.EventTypeNID(i),
types.EventStateKeyNID(i),
}, types.EventNID(i)})
} }
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
entryMap := stateEntryMap(list) entryMap := stateEntryMap(list)
for j := int64(0); j < lookups; j++ { for j := int64(0); j < lookups; j++ {
entryMap.lookup(types.StateKeyTuple{j, j}) entryMap.lookup(types.StateKeyTuple{
types.EventTypeNID(j), types.EventStateKeyNID(j),
})
} }
} }
} }
@ -43,10 +48,10 @@ func TestStateEntryMap(t *testing.T) {
}) })
testCases := []struct { testCases := []struct {
inputTypeNID int64 inputTypeNID types.EventTypeNID
inputStateKey int64 inputStateKey types.EventStateKeyNID
wantOK bool wantOK bool
wantEventNID int64 wantEventNID types.EventNID
}{ }{
// Check that tuples that in the array are in the map. // Check that tuples that in the array are in the map.
{1, 1, true, 1}, {1, 1, true, 1},
@ -80,7 +85,7 @@ func TestEventMap(t *testing.T) {
}) })
testCases := []struct { testCases := []struct {
inputEventNID int64 inputEventNID types.EventNID
wantOK bool wantOK bool
wantEvent *types.Event wantEvent *types.Event
}{ }{

View file

@ -9,18 +9,31 @@ import (
// A RoomEventDatabase has the storage APIs needed to store a room event. // A RoomEventDatabase has the storage APIs needed to store a room event.
type RoomEventDatabase interface { type RoomEventDatabase interface {
// Stores a matrix room event in the database // Stores a matrix room event in the database
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error)
// Lookup the state entries for a list of string event IDs // Lookup the state entries for a list of string event IDs
// Returns a sorted list of state entries. // Returns an error if the there is an error talking to the database
// Returns a error if the there is an error talking to the database
// or if the event IDs aren't in the database. // or if the event IDs aren't in the database.
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
// Lookup the numeric IDs for a list of string event state keys. // Lookup the numeric IDs for a list of string event state keys.
// Returns a map from string state key to numeric ID for the state key. // Returns a map from string state key to numeric ID for the state key.
EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
// Lookup the Events for a list of numeric event IDs. // Lookup the Events for a list of numeric event IDs.
// Returns a sorted list of events. // Returns a sorted list of events.
Events(eventNIDs []int64) ([]types.Event, error) Events(eventNIDs []types.EventNID) ([]types.Event, error)
// Lookup the state of a room at each event for a list of string event IDs.
// Returns an error if there is an error talking to the database
// or if the room state for the event IDs aren't in the database
StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error)
// Lookup the numeric state data IDs for each numeric state snapshot ID
// The returned slice is sorted by numeric state snapshot ID.
StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
// Lookup the state data for each numeric state data ID
// The returned slice is sorted by numeric state data ID.
StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
// Store the room state at an event in the database
AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
// Set the state at an event.
SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error
} }
func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
@ -37,7 +50,8 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
} }
// Store the event // Store the event
if err := db.StoreEvent(event, authEventNIDs); err != nil { roomNID, stateAtEvent, err := db.StoreEvent(event, authEventNIDs)
if err != nil {
return err return err
} }
@ -48,6 +62,29 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
return nil return nil
} }
if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Lets calculate one.
if input.StateEventIDs != nil {
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
entries, err := db.StateEntriesForEventIDs(input.StateEventIDs)
if err != nil {
return err
}
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(roomNID, nil, entries); err != nil {
return nil
}
} else {
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = calculateAndStoreState(db, event, roomNID); err != nil {
return err
}
}
db.SetState(stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
}
// TODO: // TODO:
// * Calcuate the state at the event if necessary. // * Calcuate the state at the event if necessary.
// * Store the state at the event. // * Store the state at the event.

View file

@ -0,0 +1,305 @@
package input
import (
"fmt"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"sort"
)
// calculateAndStoreState calculates a snapshot of the state of a room before an event.
// Stores the snapshot of the state in the database.
// Returns a numeric ID for that snapshot.
func calculateAndStoreState(
db RoomEventDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID,
) (types.StateSnapshotNID, error) {
// Load the state at the prev events.
prevEventRefs := event.PrevEvents()
prevEventIDs := make([]string, len(prevEventRefs))
for i := range prevEventRefs {
prevEventIDs[i] = prevEventRefs[i].EventID
}
prevStates, err := db.StateAtEventIDs(prevEventIDs)
if err != nil {
return 0, err
}
if len(prevStates) == 0 {
// 2) There weren't any prev_events for this event so the state is
// empty.
return db.AddState(roomNID, nil, nil)
}
if len(prevStates) == 1 {
prevState := prevStates[0]
if prevState.EventStateKeyNID == 0 {
// 3) None of the previous events were state events and they all
// have the same state, so this event has exactly the same state
// as the previous events.
// This should be the common case.
return prevState.BeforeStateSnapshotNID, nil
}
// The previous event was a state event so we need to store a copy
// of the previous state updated with that event.
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{prevState.BeforeStateSnapshotNID})
if err != nil {
return 0, err
}
stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs
if len(stateBlockNIDs) < maxStateBlockNIDs {
// 4) The number of state data blocks is small enough that we can just
// add the state event as a block of size one to the end of the blocks.
return db.AddState(
roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
)
}
// If there are too many deltas then we need to calculate the full state
// So fall through to calculateAndStoreStateMany
}
return calculateAndStoreStateMany(db, roomNID, prevStates)
}
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
// Increasing this number means that we can encode more of the state changes as simple deltas which means that
// we need fewer entries in the state data table. However making this number bigger will increase the size of
// the rows in the state table itself and will require more index lookups when retrieving a snapshot.
// TODO: Tune this to get the right balance between size and lookup performance.
const maxStateBlockNIDs = 64
// calculateAndStoreStateMany calculates the state of the room before an event
// using the states at each of the event's prev events.
// Stores the resulting state and returns a numeric ID for the snapshot.
func calculateAndStoreStateMany(db RoomEventDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) {
// Conflict resolution.
// First stage: load the state after each of the prev events.
combined, err := loadCombinedStateAfterEvents(db, prevStates)
if err != nil {
return 0, err
}
// Collect all the entries with the same type and key together.
// We don't care about the order here because the conflict resolution
// algorithm doesn't depend on the order of the prev events.
sort.Sort(stateEntrySorter(combined))
// Remove duplicate entires.
combined = combined[:unique(stateEntrySorter(combined))]
// Find the conflicts
conflicts := findDuplicateStateKeys(combined)
var state []types.StateEntry
if len(conflicts) > 0 {
// 5) There are conflicting state events, for each conflict workout
// what the appropriate state event is.
resolved, err := resolveConflicts(db, combined, conflicts)
if err != nil {
return 0, err
}
state = resolved
} else {
// 6) There weren't any conflicts
state = combined
}
// TODO: Check if we can encode the new state as a delta against the
// previous state.
return db.AddState(roomNID, nil, state)
}
// loadCombinedStateAfterEvents loads a snapshot of the state after each of the events
// and combines those snapshots together into a single list.
func loadCombinedStateAfterEvents(db RoomEventDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) {
stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
for i, state := range prevStates {
stateNIDs[i] = state.BeforeStateSnapshotNID
}
// Fetch the state snapshots for the state before the each prev event from the database.
// Deduplicate the IDs before passing them to the database.
// There could be duplicates because the events could be state events where
// the snapshot of the room state before them was the same.
stateBlockNIDLists, err := db.StateBlockNIDs(uniqueStateSnapshotNIDs(stateNIDs))
if err != nil {
return nil, err
}
var stateBlockNIDs []types.StateBlockNID
for _, list := range stateBlockNIDLists {
stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...)
}
// Fetch the state entries that will be combined to create the snapshots.
// Deduplicate the IDs before passing them to the database.
// There could be duplicates because a block of state entries could be reused by
// multiple snapshots.
stateEntryLists, err := db.StateEntries(uniqueStateBlockNIDs(stateBlockNIDs))
if err != nil {
return nil, err
}
stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists)
stateEntriesMap := stateEntryListMap(stateEntryLists)
// Combine the entries from all the snapshots of state after each prev event into a single list.
var combined []types.StateEntry
for _, prevState := range prevStates {
// Grab the list of state data NIDs for this snapshot.
stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID)
if !ok {
// This should only get hit if the database is corrupt.
// It should be impossible for an event to reference a NID that doesn't exist
panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateSnapshotNID))
}
// Combined all the state entries for this snapshot.
// The order of state data NIDs in the list tells us the order to combine them in.
var fullState []types.StateEntry
for _, stateBlockNID := range stateBlockNIDs {
entries, ok := stateEntriesMap.lookup(stateBlockNID)
if !ok {
// This should only get hit if the database is corrupt.
// It should be impossible for an event to reference a NID that doesn't exist
panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateSnapshotNID))
}
fullState = append(fullState, entries...)
}
if prevState.IsStateEvent() {
// If the prev event was a state event then add an entry for the event itself
// so that we get the state after the event rather than the state before.
fullState = append(fullState, prevState.StateEntry)
}
// Stable sort so that the most recent entry for each state key stays
// remains later in the list than the older entries for the same state key.
sort.Stable(stateEntryByStateKeySorter(fullState))
// Unique returns the last entry and hence the most recent entry for each state key.
fullState = fullState[:unique(stateEntryByStateKeySorter(fullState))]
// Add the full state for this StateSnapshotNID.
combined = append(combined, fullState...)
}
return combined, nil
}
func resolveConflicts(db RoomEventDatabase, combined, conflicted []types.StateEntry) ([]types.StateEntry, error) {
panic(fmt.Errorf("Not implemented"))
}
// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list.
// Returns a sorted list of those state entries.
func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry {
var result []types.StateEntry
// j is the starting index of a block of entries with the same state key tuple.
j := 0
for i := 1; i < len(a); i++ {
// Check if the state key tuple matches the start of the block
if a[j].StateKeyTuple != a[i].StateKeyTuple {
// If the state key tuple is different then we've reached the end of a block of duplicates.
// Check if the size of the block is bigger than one.
// If the size is one then there was only a single entry with that state key tuple so we don't add it to the result
if j+1 != i {
// Add the block to the result.
result = append(result, a[j:i]...)
}
// Start a new block for the next state key tuple.
j = i
}
}
// Check if the last block with the same state key tuple had more than one event in it.
if j+1 != len(a) {
result = append(result, a[j:]...)
}
return result
}
type stateBlockNIDListMap []types.StateBlockNIDList
func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
list := []types.StateBlockNIDList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateSnapshotNID >= stateNID
})
if i < len(list) && list[i].StateSnapshotNID == stateNID {
ok = true
stateBlockNIDs = list[i].StateBlockNIDs
}
return
}
type stateEntryListMap []types.StateEntryList
func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
list := []types.StateEntryList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateBlockNID >= stateBlockNID
})
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
ok = true
stateEntries = list[i].StateEntries
}
return
}
type stateEntryByStateKeySorter []types.StateEntry
func (s stateEntryByStateKeySorter) Len() int { return len(s) }
func (s stateEntryByStateKeySorter) Less(i, j int) bool {
return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
}
func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type stateEntrySorter []types.StateEntry
func (s stateEntrySorter) Len() int { return len(s) }
func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type stateNIDSorter []types.StateSnapshotNID
func (s stateNIDSorter) Len() int { return len(s) }
func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID {
sort.Sort(stateNIDSorter(nids))
return nids[:unique(stateNIDSorter(nids))]
}
type stateBlockNIDSorter []types.StateBlockNID
func (s stateBlockNIDSorter) Len() int { return len(s) }
func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID {
sort.Sort(stateBlockNIDSorter(nids))
return nids[:unique(stateBlockNIDSorter(nids))]
}
// Remove duplicate items from a sorted list.
// Takes the same interface as sort.Sort
// Returns the length of the data without duplicates
// Uses the last occurance of a duplicate.
// O(n).
func unique(data sort.Interface) int {
if data.Len() == 0 {
return 0
}
length := data.Len()
// j is the next index to output an element to.
j := 0
for i := 1; i < length; i++ {
// If the previous element is less than this element then they are
// not equal. Otherwise they must be equal because the list is sorted.
// If they are equal then we move onto the next element.
if data.Less(i-1, i) {
// "Write" the previous element to the output position by swaping
// the elements.
// Note that if the list has no duplicates then i-1 == j so the
// swap does nothing. (This assumes that data.Swap(a,b) nops if a==b)
data.Swap(i-1, j)
// Advance to the next output position in the list.
j++
}
}
// Output the last element.
data.Swap(length-1, j)
return j + 1
}

View file

@ -0,0 +1,67 @@
package input
import (
"github.com/matrix-org/dendrite/roomserver/types"
"testing"
)
type sortBytes []byte
func (s sortBytes) Len() int { return len(s) }
func (s sortBytes) Less(i, j int) bool { return s[i] < s[j] }
func (s sortBytes) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func TestUnique(t *testing.T) {
testCases := []struct {
Input string
Want string
}{
{"", ""},
{"abc", "abc"},
{"aaabbbccc", "abc"},
}
for _, test := range testCases {
input := []byte(test.Input)
want := string(test.Want)
got := string(input[:unique(sortBytes(input))])
if got != want {
t.Fatal("Wanted ", want, " got ", got)
}
}
}
func TestFindDuplicateStateKeys(t *testing.T) {
testCases := []struct {
Input []types.StateEntry
Want []types.StateEntry
}{{
Input: []types.StateEntry{
{types.StateKeyTuple{1, 1}, 1},
{types.StateKeyTuple{1, 1}, 2},
{types.StateKeyTuple{2, 2}, 3},
},
Want: []types.StateEntry{
{types.StateKeyTuple{1, 1}, 1},
{types.StateKeyTuple{1, 1}, 2},
},
}, {
Input: []types.StateEntry{
{types.StateKeyTuple{1, 1}, 1},
{types.StateKeyTuple{1, 2}, 2},
},
Want: nil,
}}
for _, test := range testCases {
got := findDuplicateStateKeys(test.Input)
if len(got) != len(test.Want) {
t.Fatalf("Wanted %v, got %v", test.Want, got)
}
for i := range got {
if got[i] != test.Want[i] {
t.Fatalf("Wanted %v, got %v", test.Want, got)
}
}
}
}

View file

@ -19,8 +19,15 @@ type statements struct {
selectRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt
bulkSelectStateAtEventByIDStmt *sql.Stmt
updateEventStateStmt *sql.Stmt
insertEventJSONStmt *sql.Stmt insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt *sql.Stmt
insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateDataEntriesStmt *sql.Stmt
} }
func (s *statements) prepare(db *sql.DB) error { func (s *statements) prepare(db *sql.DB) error {
@ -180,14 +187,16 @@ const insertEventTypeNIDSQL = "" +
const selectEventTypeNIDSQL = "" + const selectEventTypeNIDSQL = "" +
"SELECT event_type_nid FROM event_types WHERE event_type = $1" "SELECT event_type_nid FROM event_types WHERE event_type = $1"
func (s *statements) insertEventTypeNID(eventType string) (eventTypeNID int64, err error) { func (s *statements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) {
err = s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) var eventTypeNID int64
return err := s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
return types.EventTypeNID(eventTypeNID), err
} }
func (s *statements) selectEventTypeNID(eventType string) (eventTypeNID int64, err error) { func (s *statements) selectEventTypeNID(eventType string) (types.EventTypeNID, error) {
err = s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID) var eventTypeNID int64
return err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
return types.EventTypeNID(eventTypeNID), err
} }
func (s *statements) prepareEventStateKeys(db *sql.DB) (err error) { func (s *statements) prepareEventStateKeys(db *sql.DB) (err error) {
@ -244,31 +253,33 @@ const bulkSelectEventStateKeyNIDSQL = "" +
"SELECT event_state_key, event_state_key_nid FROM event_state_keys" + "SELECT event_state_key, event_state_key_nid FROM event_state_keys" +
" WHERE event_state_key = ANY($1)" " WHERE event_state_key = ANY($1)"
func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) { func (s *statements) insertEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) {
err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) var eventStateKeyNID int64
return err := s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err
} }
func (s *statements) selectEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) { func (s *statements) selectEventStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) {
err = s.selectEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) var eventStateKeyNID int64
return err := s.selectEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
return types.EventStateKeyNID(eventStateKeyNID), err
} }
func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]int64, error) { func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys)) rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
result := make(map[string]int64, len(eventStateKeys)) result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
for rows.Next() { for rows.Next() {
var stateKey string var stateKey string
var stateKeyNID int64 var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
return nil, err return nil, err
} }
result[stateKey] = stateKeyNID result[stateKey] = types.EventStateKeyNID(stateKeyNID)
} }
return result, nil return result, nil
} }
@ -307,14 +318,16 @@ const insertRoomNIDSQL = "" +
const selectRoomNIDSQL = "" + const selectRoomNIDSQL = "" +
"SELECT room_nid FROM rooms WHERE room_id = $1" "SELECT room_nid FROM rooms WHERE room_id = $1"
func (s *statements) insertRoomNID(roomID string) (roomNID int64, err error) { func (s *statements) insertRoomNID(roomID string) (types.RoomNID, error) {
err = s.insertRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) var roomNID int64
return err := s.insertRoomNIDStmt.QueryRow(roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err
} }
func (s *statements) selectRoomNID(roomID string) (roomNID int64, err error) { func (s *statements) selectRoomNID(roomID string) (types.RoomNID, error) {
err = s.selectRoomNIDStmt.QueryRow(roomID).Scan(&roomNID) var roomNID int64
return err := s.selectRoomNIDStmt.QueryRow(roomID).Scan(&roomNID)
return types.RoomNID(roomNID), err
} }
const eventsSchema = ` const eventsSchema = `
@ -333,6 +346,13 @@ CREATE TABLE IF NOT EXISTS events (
-- Local numeric ID for the state_key of the event -- Local numeric ID for the state_key of the event
-- This is 0 if the event is not a state event. -- This is 0 if the event is not a state event.
event_state_key_nid BIGINT NOT NULL, event_state_key_nid BIGINT NOT NULL,
-- Local numeric ID for the state at the event.
-- This is 0 if we don't know the state at the event.
-- If the state is not 0 then this event is part of the contiguous
-- part of the event graph
-- Since many different events can have the same state we store the
-- state into a separate state table and refer to it by numeric ID.
state_snapshot_nid bigint NOT NULL DEFAULT 0,
-- The textual event id. -- The textual event id.
-- Used to lookup the numeric ID when processing requests. -- Used to lookup the numeric ID when processing requests.
-- Needed for state resolution. -- Needed for state resolution.
@ -342,7 +362,7 @@ CREATE TABLE IF NOT EXISTS events (
-- Needed for setting reference hashes when sending new events. -- Needed for setting reference hashes when sending new events.
reference_sha256 BYTEA NOT NULL, reference_sha256 BYTEA NOT NULL,
-- A list of numeric IDs for events that can authenticate this event. -- A list of numeric IDs for events that can authenticate this event.
auth_event_nids BIGINT[] NOT NULL, auth_event_nids BIGINT[] NOT NULL
); );
` `
@ -351,7 +371,7 @@ const insertEventSQL = "" +
" VALUES ($1, $2, $3, $4, $5, $6)" + " VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT event_id_unique" + " ON CONFLICT ON CONSTRAINT event_id_unique" +
" DO UPDATE SET event_id = $1" + " DO UPDATE SET event_id = $1" +
" RETURNING event_nid" " RETURNING event_nid, state_snapshot_nid"
// Bulk lookup of events by string ID. // Bulk lookup of events by string ID.
// Sort by the numeric IDs for event type and state key. // Sort by the numeric IDs for event type and state key.
@ -361,6 +381,13 @@ const bulkSelectStateEventByIDSQL = "" +
" WHERE event_id = ANY($1)" + " WHERE event_id = ANY($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC" " ORDER BY event_type_nid, event_state_key_nid ASC"
const bulkSelectStateAtEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM events" +
" WHERE event_id = ANY($1)"
const updateEventStateSQL = "" +
"UPDATE events SET state_snapshot_nid = $2 WHERE event_nid = $1"
func (s *statements) prepareEvents(db *sql.DB) (err error) { func (s *statements) prepareEvents(db *sql.DB) (err error) {
_, err = db.Exec(eventsSchema) _, err = db.Exec(eventsSchema)
if err != nil { if err != nil {
@ -372,20 +399,32 @@ func (s *statements) prepareEvents(db *sql.DB) (err error) {
if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil { if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil {
return return
} }
if s.bulkSelectStateAtEventByIDStmt, err = db.Prepare(bulkSelectStateAtEventByIDSQL); err != nil {
return
}
if s.updateEventStateStmt, err = db.Prepare(updateEventStateSQL); err != nil {
return
}
return return
} }
func (s *statements) insertEvent( func (s *statements) insertEvent(
roomNID, eventTypeNID, eventStateKeyNID int64, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
eventID string, eventID string,
referenceSHA256 []byte, referenceSHA256 []byte,
authEventNIDs []int64, authEventNIDs []types.EventNID,
) (eventNID int64, err error) { ) (types.EventNID, types.StateSnapshotNID, error) {
err = s.insertEventStmt.QueryRow( nids := make([]int64, len(authEventNIDs))
roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256, for i := range authEventNIDs {
pq.Int64Array(authEventNIDs), nids[i] = int64(authEventNIDs[i])
).Scan(&eventNID) }
return var eventNID int64
var stateNID int64
err := s.insertEventStmt.QueryRow(
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
pq.Int64Array(nids),
).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
} }
func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) { func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) {
@ -421,6 +460,39 @@ func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateE
return results, err return results, err
} }
func (s *statements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.StateAtEvent, error) {
rows, err := s.bulkSelectStateAtEventByIDStmt.Query(pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
defer rows.Close()
results := make([]types.StateAtEvent, len(eventIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
if err = rows.Scan(
&result.EventNID,
&result.EventTypeNID,
&result.EventStateKeyNID,
&result.BeforeStateSnapshotNID,
); err != nil {
return nil, err
}
if result.BeforeStateSnapshotNID == 0 {
return nil, fmt.Errorf("storage: missing state for event NID %d", result.EventNID)
}
}
if i != len(eventIDs) {
return nil, fmt.Errorf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs))
}
return results, err
}
func (s *statements) updateEventState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error {
_, err := s.updateEventStateStmt.Exec(int64(eventNID), int64(stateNID))
return err
}
func (s *statements) prepareEventJSON(db *sql.DB) (err error) { func (s *statements) prepareEventJSON(db *sql.DB) (err error) {
_, err = db.Exec(eventJSONSchema) _, err = db.Exec(eventJSONSchema)
if err != nil { if err != nil {
@ -464,18 +536,22 @@ const bulkSelectEventJSONSQL = "" +
" WHERE event_nid = ANY($1)" + " WHERE event_nid = ANY($1)" +
" ORDER BY event_nid ASC" " ORDER BY event_nid ASC"
func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error { func (s *statements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error {
_, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON) _, err := s.insertEventJSONStmt.Exec(int64(eventNID), eventJSON)
return err return err
} }
type eventJSONPair struct { type eventJSONPair struct {
EventNID int64 EventNID types.EventNID
EventJSON []byte EventJSON []byte
} }
func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, error) { func (s *statements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(eventNIDs)) nids := make([]int64, len(eventNIDs))
for i := range eventNIDs {
nids[i] = int64(eventNIDs[i])
}
rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(nids))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -488,9 +564,223 @@ func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, er
results := make([]eventJSONPair, len(eventNIDs)) results := make([]eventJSONPair, len(eventNIDs))
i := 0 i := 0
for ; rows.Next(); i++ { for ; rows.Next(); i++ {
if err := rows.Scan(&results[i].EventNID, &results[i].EventJSON); err != nil { result := &results[i]
var eventNID int64
if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
return nil, err return nil, err
} }
result.EventNID = types.EventNID(eventNID)
} }
return results[:i], nil return results[:i], nil
} }
const stateSchema = `
-- The state of a room before an event.
-- Stored as a list of state_block entries stored in a separate table.
-- The actual state is constructed by combining all the state_block entries
-- referenced by state_block_nids together. If the same state key tuple appears
-- multiple times then the entry from the later state_block clobbers the earlier
-- entries.
-- This encoding format allows us to implement a delta encoding which is useful
-- because room state tends to accumulate small changes over time. Although if
-- the list of deltas becomes too long it becomes more efficient to encode
-- the full state under single state_block_nid.
CREATE SEQUENCE IF NOT EXISTS state_snapshot_nid_seq;
CREATE TABLE IF NOT EXISTS state_snapshots (
-- Local numeric ID for the state.
state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('state_snapshot_nid_seq'),
-- Local numeric ID of the room this state is for.
-- Unused in normal operation, but useful for background work or ad-hoc debugging.
room_nid bigint NOT NULL,
-- List of state_block_nids, stored sorted by state_block_nid.
state_block_nids bigint[] NOT NULL
);
`
const insertStateSQL = "" +
"INSERT INTO state_snapshots (room_nid, state_block_nids)" +
" VALUES ($1, $2)" +
" RETURNING state_snapshot_nid"
// Bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result
// to lookup the state data NIDs for a state snapshot NID.
const bulkSelectStateBlockNIDsSQL = "" +
"SELECT state_snapshot_nid, state_block_nids FROM state_snapshots" +
" WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC"
func (s *statements) prepareState(db *sql.DB) (err error) {
_, err = db.Exec(stateSchema)
if err != nil {
return
}
if s.insertStateStmt, err = db.Prepare(insertStateSQL); err != nil {
return
}
if s.bulkSelectStateBlockNIDsStmt, err = db.Prepare(bulkSelectStateBlockNIDsSQL); err != nil {
return
}
return
}
func (s *statements) insertState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
err = s.insertStateStmt.QueryRow(int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
return
}
func (s *statements) bulkSelectStateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
}
rows, err := s.bulkSelectStateBlockNIDsStmt.Query(pq.Int64Array(nids))
if err != nil {
return nil, err
}
defer rows.Close()
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDs pq.Int64Array
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err
}
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
for k := range stateBlockNIDs {
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
}
}
if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
}
return results, nil
}
const stateDataSchema = `
-- The state data map.
-- Designed to give enough information to run the state resolution algorithm
-- without hitting the database in the common case.
-- TODO: Is it worth replacing the unique btree index with a covering index so
-- that postgres could lookup the state using an index-only scan?
-- The type and state_key are included in the index to make it easier to
-- lookup a specific (type, state_key) pair for an event. It also makes it easy
-- to read the state for a given state_block_nid ordered by (type, state_key)
-- which in turn makes it easier to merge state data blocks.
CREATE SEQUENCE IF NOT EXISTS state_block_nid_seq;
CREATE TABLE IF NOT EXISTS state_block (
-- Local numeric ID for this state data.
state_block_nid bigint NOT NULL,
event_type_nid bigint NOT NULL,
event_state_key_nid bigint NOT NULL,
event_nid bigint NOT NULL,
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
);
`
const insertStateDataSQL = "" +
"INSERT INTO state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
" VALUES ($1, $2, $3, $4)"
const selectNextStateBlockNIDSQL = "" +
"SELECT nextval('state_block_nid_seq')"
// Bulk state lookup by numeric event ID.
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
// This means that all the entries for a given state_block_nid will appear
// together in the list and those entries will sorted by event_type_nid
// and event_state_key_nid. This property makes it easier to merge two
// state data blocks together.
const bulkSelectStateDataEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
" FROM state_block WHERE state_block_nid = ANY($1)" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
func (s *statements) prepareStateData(db *sql.DB) (err error) {
_, err = db.Exec(stateDataSchema)
if err != nil {
return
}
if s.insertStateDataStmt, err = db.Prepare(insertStateDataSQL); err != nil {
return
}
if s.selectNextStateBlockNIDStmt, err = db.Prepare(selectNextStateBlockNIDSQL); err != nil {
return
}
if s.bulkSelectStateDataEntriesStmt, err = db.Prepare(bulkSelectStateDataEntriesSQL); err != nil {
return
}
return
}
func (s *statements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error {
for _, entry := range entries {
_, err := s.insertStateDataStmt.Exec(
int64(stateBlockNID),
int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID),
int64(entry.EventNID),
)
if err != nil {
return err
}
}
return nil
}
func (s *statements) selectNextStateBlockNID() (types.StateBlockNID, error) {
var stateBlockNID int64
err := s.selectNextStateBlockNIDStmt.QueryRow().Scan(&stateBlockNID)
return types.StateBlockNID(stateBlockNID), err
}
func (s *statements) bulkSelectStateDataEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(nids))
if err != nil {
return nil, err
}
defer rows.Close()
results := make([]types.StateEntryList, len(stateBlockNIDs))
// current is a pointer to the StateEntryList to append the state entries to.
var current *types.StateEntryList
i := 0
for rows.Next() {
var (
stateBlockNID int64
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
}
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
// The state entry row is for a different state data block to the current one.
// So we start appending to the next entry in the list.
current = &results[i]
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
i++
}
current.StateEntries = append(current.StateEntries, entry)
}
if i != len(stateBlockNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
}
return results, nil
}

View file

@ -38,21 +38,22 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
} }
// StoreEvent implements input.EventDatabase // StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error { func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) {
var ( var (
roomNID int64 roomNID types.RoomNID
eventTypeNID int64 eventTypeNID types.EventTypeNID
eventStateKeyNID int64 eventStateKeyNID types.EventStateKeyNID
eventNID int64 eventNID types.EventNID
stateNID types.StateSnapshotNID
err error err error
) )
if roomNID, err = d.assignRoomNID(event.RoomID()); err != nil { if roomNID, err = d.assignRoomNID(event.RoomID()); err != nil {
return err return 0, types.StateAtEvent{}, err
} }
if eventTypeNID, err = d.assignEventTypeNID(event.Type()); err != nil { if eventTypeNID, err = d.assignEventTypeNID(event.Type()); err != nil {
return err return 0, types.StateAtEvent{}, err
} }
eventStateKey := event.StateKey() eventStateKey := event.StateKey()
@ -60,11 +61,11 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int
// Otherwise set the numeric ID for the state_key to 0. // Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil { if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(*eventStateKey); err != nil { if eventStateKeyNID, err = d.assignStateKeyNID(*eventStateKey); err != nil {
return err return 0, types.StateAtEvent{}, err
} }
} }
if eventNID, err = d.statements.insertEvent( if eventNID, stateNID, err = d.statements.insertEvent(
roomNID, roomNID,
eventTypeNID, eventTypeNID,
eventStateKeyNID, eventStateKeyNID,
@ -72,13 +73,26 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int
event.EventReference().EventSHA256, event.EventReference().EventSHA256,
authEventNIDs, authEventNIDs,
); err != nil { ); err != nil {
return err return 0, types.StateAtEvent{}, err
} }
return d.statements.insertEventJSON(eventNID, event.JSON()) if err = d.statements.insertEventJSON(eventNID, event.JSON()); err != nil {
return 0, types.StateAtEvent{}, err
} }
func (d *Database) assignRoomNID(roomID string) (int64, error) { return roomNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID,
StateEntry: types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: eventTypeNID,
EventStateKeyNID: eventStateKeyNID,
},
EventNID: eventNID,
},
}, nil
}
func (d *Database) assignRoomNID(roomID string) (types.RoomNID, error) {
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
roomNID, err := d.statements.selectRoomNID(roomID) roomNID, err := d.statements.selectRoomNID(roomID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -91,7 +105,7 @@ func (d *Database) assignRoomNID(roomID string) (int64, error) {
return roomNID, nil return roomNID, nil
} }
func (d *Database) assignEventTypeNID(eventType string) (int64, error) { func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) {
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventTypeNID, err := d.statements.selectEventTypeNID(eventType) eventTypeNID, err := d.statements.selectEventTypeNID(eventType)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -104,7 +118,7 @@ func (d *Database) assignEventTypeNID(eventType string) (int64, error) {
return eventTypeNID, nil return eventTypeNID, nil
} }
func (d *Database) assignStateKeyNID(eventStateKey string) (int64, error) { func (d *Database) assignStateKeyNID(eventStateKey string) (types.EventStateKeyNID, error) {
// Check if we already have a numeric ID in the database. // Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(eventStateKey) eventStateKeyNID, err := d.statements.selectEventStateKeyNID(eventStateKey)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -123,12 +137,12 @@ func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntr
} }
// EventStateKeyNIDs implements input.EventDatabase // EventStateKeyNIDs implements input.EventDatabase
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) { func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys) return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
} }
// Events implements input.EventDatabase // Events implements input.EventDatabase
func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) { func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs) eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -145,3 +159,39 @@ func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) {
} }
return results, nil return results, nil
} }
// AddState implements input.EventDatabase
func (d *Database) AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) {
if len(state) > 0 {
stateBlockNID, err := d.statements.selectNextStateBlockNID()
if err != nil {
return 0, err
}
if err = d.statements.bulkInsertStateData(stateBlockNID, state); err != nil {
return 0, err
}
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
}
return d.statements.insertState(roomNID, stateBlockNIDs)
}
// SetState implements input.EventDatabase
func (d *Database) SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error {
return d.statements.updateEventState(eventNID, stateNID)
}
// StateAtEventIDs implements input.EventDatabase
func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) {
return d.statements.bulkSelectStateAtEventByID(eventIDs)
}
// StateBlockNIDs implements input.EventDatabase
func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
return d.statements.bulkSelectStateBlockNIDs(stateNIDs)
}
// StateEntries implements input.EventDatabase
func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
return d.statements.bulkSelectStateDataEntries(stateBlockNIDs)
}

View file

@ -13,13 +13,32 @@ type PartitionOffset struct {
Offset int64 Offset int64
} }
// EventTypeNID is a numeric ID for an event type.
type EventTypeNID int64
// EventStateKeyNID is a numeric ID for an event state_key.
type EventStateKeyNID int64
// EventNID is a numeric ID for an event.
type EventNID int64
// RoomNID is a numeric ID for a room.
type RoomNID int64
// StateSnapshotNID is a numeric ID for the state at an event.
type StateSnapshotNID int64
// StateBlockNID is a numeric ID for a block of state data.
// These blocks of state data are combined to form the actual state.
type StateBlockNID int64
// A StateKeyTuple is a pair of a numeric event type and a numeric state key. // A StateKeyTuple is a pair of a numeric event type and a numeric state key.
// It is used to lookup state entries. // It is used to lookup state entries.
type StateKeyTuple struct { type StateKeyTuple struct {
// The numeric ID for the event type. // The numeric ID for the event type.
EventTypeNID int64 EventTypeNID EventTypeNID
// The numeric ID for the state key. // The numeric ID for the state key.
EventStateKeyNID int64 EventStateKeyNID EventStateKeyNID
} }
// LessThan returns true if this state key is less than the other state key. // LessThan returns true if this state key is less than the other state key.
@ -35,7 +54,7 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
type StateEntry struct { type StateEntry struct {
StateKeyTuple StateKeyTuple
// The numeric ID for the event. // The numeric ID for the event.
EventNID int64 EventNID EventNID
} }
// LessThan returns true if this state entry is less than the other state entry. // LessThan returns true if this state entry is less than the other state entry.
@ -47,10 +66,23 @@ func (a StateEntry) LessThan(b StateEntry) bool {
return a.EventNID < b.EventNID return a.EventNID < b.EventNID
} }
// StateAtEvent is the state before and after a matrix event.
type StateAtEvent struct {
// The state before the event.
BeforeStateSnapshotNID StateSnapshotNID
// The state entry for the event itself, allows us to calculate the state after the event.
StateEntry
}
// IsStateEvent returns whether the event the state is at is a state event.
func (s StateAtEvent) IsStateEvent() bool {
return s.EventStateKeyNID != 0
}
// An Event is a gomatrixserverlib.Event with the numeric event ID attached. // An Event is a gomatrixserverlib.Event with the numeric event ID attached.
// It is when performing bulk event lookup in the database. // It is when performing bulk event lookup in the database.
type Event struct { type Event struct {
EventNID int64 EventNID EventNID
gomatrixserverlib.Event gomatrixserverlib.Event
} }
@ -75,3 +107,15 @@ const (
// EmptyStateKeyNID is the numeric ID for the empty state key. // EmptyStateKeyNID is the numeric ID for the empty state key.
EmptyStateKeyNID = 1 EmptyStateKeyNID = 1
) )
// StateBlockNIDList is used to return the result of bulk StateBlockNID lookups from the database.
type StateBlockNIDList struct {
StateSnapshotNID StateSnapshotNID
StateBlockNIDs []StateBlockNID
}
// StateEntryList is used to return the result of bulk state entry lookups from the database.
type StateEntryList struct {
StateBlockNID StateBlockNID
StateEntries []StateEntry
}