From bab3ca5f5fe6551033c4980f6f455a5c3c13963e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 13 Feb 2017 16:07:02 +0000 Subject: [PATCH] Calculate and store the state at each event --- .../dendrite/roomserver/input/events.go | 31 ++- .../dendrite/roomserver/input/state.go | 247 ++++++++++++++++++ .../dendrite/roomserver/storage/sql.go | 246 ++++++++++++++++- .../dendrite/roomserver/storage/storage.go | 64 ++++- .../dendrite/roomserver/types/types.go | 20 ++ 5 files changed, 595 insertions(+), 13 deletions(-) create mode 100644 src/github.com/matrix-org/dendrite/roomserver/input/state.go diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index 8aaf9f841..56369150d 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -9,10 +9,10 @@ import ( // A RoomEventDatabase has the storage APIs needed to store a room event. type RoomEventDatabase interface { // Stores a matrix room event in the database - StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error + StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) (roomNID int64, stateAtEvent types.StateAtEvent, err error) // Lookup the state entries for a list of string event IDs // Returns a sorted list of state entries. - // Returns a error if the there is an error talking to the database + // Returns an error if the there is an error talking to the database // or if the event IDs aren't in the database. StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) // Lookup the numeric IDs for a list of string event state keys. @@ -21,6 +21,21 @@ type RoomEventDatabase interface { // Lookup the Events for a list of numeric event IDs. // Returns a sorted list of events. Events(eventNIDs []int64) ([]types.Event, error) + // Lookup the state of a room at each event for a list of string event IDs. + // Returns a sorted list of state at each event. + // Returns an error if there is an error talking to the database + // or if the room state for the event IDs aren't in the database + StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) + // Lookup the numeric state data IDs for the each numeric state ID + // The returned slice is sorted by numeric state ID. + StateDataNIDs(stateNIDs []int64) ([]types.StateDataNIDList, error) + // Lookup the state data for each numeric state data ID + // The returned slice is sorted by numeric state data ID. + StateEntries(stateDataNIDs []int64) ([]types.StateEntryList, error) + // Store the room state at an event in the database + AddState(roomNID int64, stateDataNIDs []int64, state []types.StateEntry) (stateNID int64, err error) + // Set the state at an event. + SetState(eventNID, stateNID int64) error } func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { @@ -37,7 +52,8 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { } // Store the event - if err := db.StoreEvent(event, authEventNIDs); err != nil { + roomNID, stateAtEvent, err := db.StoreEvent(event, authEventNIDs) + if err != nil { return err } @@ -48,6 +64,15 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { return nil } + if stateAtEvent.BeforeStateNID == 0 { + // We haven't calculated a state for this event yet. + // Lets calculate one. + if stateAtEvent.BeforeStateNID, err = calculateAndStoreState(db, event, roomNID, input.StateEventIDs); err != nil { + return err + } + db.SetState(stateAtEvent.EventNID, stateAtEvent.BeforeStateNID) + } + // TODO: // * Calcuate the state at the event if necessary. // * Store the state at the event. diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/state.go b/src/github.com/matrix-org/dendrite/roomserver/input/state.go new file mode 100644 index 000000000..b8cf80550 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/roomserver/input/state.go @@ -0,0 +1,247 @@ +package input + +import ( + "fmt" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "sort" +) + +func calculateAndStoreState( + db RoomEventDatabase, event gomatrixserverlib.Event, roomNID int64, stateEventIDs []string, +) (int64, error) { + if stateEventIDs != nil { + // 1) We've been told what the state at the event is. + // Check that those state events are in the database and store the state. + entries, err := db.StateEntriesForEventIDs(stateEventIDs) + if err != nil { + return 0, err + } + + return db.AddState(roomNID, nil, entries) + } + + // Load the state at the prev events. + prevEventRefs := event.PrevEvents() + prevEventIDs := make([]string, len(prevEventRefs)) + for i := range prevEventRefs { + prevEventIDs[i] = prevEventRefs[i].EventID + } + + prevStates, err := db.StateAtEventIDs(prevEventIDs) + if err != nil { + return 0, err + } + + if len(prevStates) == 0 { + // 2) There weren't any prev_events for this event so the state is + // empty. + return db.AddState(roomNID, nil, nil) + } + + if len(prevStates) == 1 { + prevState := prevStates[0] + if prevState.EventStateKeyNID == 0 { + // 3) None of the previous events were state events and they all + // have the same state, so this event has exactly the same state + // as the previous events. + // This should be the common case. + return prevState.BeforeStateNID, nil + } + // The previous event was a state event so we need to store a copy + // of the previous state updated with that event. + stateDataNIDLists, err := db.StateDataNIDs([]int64{prevState.BeforeStateNID}) + if err != nil { + return 0, err + } + stateDataNIDs := stateDataNIDLists[0].StateDataNIDs + if len(stateDataNIDs) < maxStateDataNIDs { + // 4) The number of state data blocks is small enough that we can just + // add the state event as a block of size one to the end of the blocks. + return db.AddState( + roomNID, stateDataNIDs, []types.StateEntry{prevState.StateEntry}, + ) + } + // If there are too many deltas then we need to calculate the full state + // So fall through to calculateAndStoreStateMany + } + return calculateAndStoreStateMany(db, roomNID, prevStates) +} + +const maxStateDataNIDs = 64 + +func calculateAndStoreStateMany(db RoomEventDatabase, roomNID int64, prevStates []types.StateAtEvent) (int64, error) { + // Conflict resolution. + // First stage: load the state datablocks for the prev events. + stateNIDs := make([]int64, len(prevStates)) + for i, state := range prevStates { + stateNIDs[i] = state.BeforeStateNID + } + stateDataNIDLists, err := db.StateDataNIDs(uniqueNIDs(stateNIDs)) + if err != nil { + return 0, err + } + + var stateDataNIDs []int64 + for _, list := range stateDataNIDLists { + stateDataNIDs = append(stateDataNIDs, list.StateDataNIDs...) + } + stateEntryLists, err := db.StateEntries(uniqueNIDs(stateDataNIDs)) + if err != nil { + return 0, err + } + stateDataNIDsMap := stateDataNIDListMap(stateDataNIDLists) + stateEntriesMap := stateEntryListMap(stateEntryLists) + + var combined []types.StateEntry + for _, prevState := range prevStates { + list, ok := stateDataNIDsMap.lookup(prevState.BeforeStateNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateNID)) + } + + var fullState []types.StateEntry + for _, stateDataNID := range list { + entries, ok := stateEntriesMap.lookup(stateDataNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state numeric ID %d", prevState.BeforeStateNID)) + } + fullState = append(fullState, entries...) + } + if prevState.EventStateKeyNID != 0 { + fullState = append(fullState, prevState.StateEntry) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry for each state key. + fullState = fullState[:unique(stateEntryByStateKeySorter(fullState))] + // Add the full state for this StateNID. + combined = append(combined, fullState...) + } + + // Collect all the entries with the same type and key together. + // We don't care about the order here. + sort.Sort(stateEntrySorter(combined)) + // Remove duplicate entires. + combined = combined[:unique(stateEntrySorter(combined))] + + // Find the conflicts + conflicts := duplicateStateKeys(combined) + + var state []types.StateEntry + if len(conflicts) > 0 { + // 5) There are conflicting state events, for each conflict workout + // what the appropriate state event is. + resolved, err := resolveConflicts(db, combined, conflicts) + if err != nil { + return 0, err + } + state = resolved + } else { + // 6) There weren't any conflicts + state = combined + } + + // TODO: Check if we can encode the new state as a delta against the + // previous state. + return db.AddState(roomNID, nil, state) +} + +func resolveConflicts(db RoomEventDatabase, combinded, conflicted []types.StateEntry) ([]types.StateEntry, error) { + panic(fmt.Errorf("Not implemented")) +} + +func duplicateStateKeys(a []types.StateEntry) []types.StateEntry { + var result []types.StateEntry + j := 0 + for i := 1; i < len(a); i++ { + if a[j].StateKeyTuple != a[i].StateKeyTuple { + result = append(result, a[j:i]...) + j = i + } + } + if j != len(a)-1 { + result = append(result, a[j:]...) + } + return result +} + +func uniqueNIDs(nids []int64) []int64 { + sort.Sort(int64Sorter(nids)) + return nids[:unique(int64Sorter(nids))] +} + +type stateDataNIDListMap []types.StateDataNIDList + +func (m stateDataNIDListMap) lookup(stateNID int64) (stateDataNIDs []int64, ok bool) { + list := []types.StateDataNIDList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateNID >= stateNID + }) + if i < len(list) && list[i].StateNID == stateNID { + ok = true + stateDataNIDs = list[i].StateDataNIDs + } + return +} + +type stateEntryListMap []types.StateEntryList + +func (m stateEntryListMap) lookup(stateDataNID int64) (stateEntries []types.StateEntry, ok bool) { + list := []types.StateEntryList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateDataNID >= stateDataNID + }) + if i < len(list) && list[i].StateDataNID == stateDataNID { + ok = true + stateEntries = list[i].StateEntries + } + return +} + +type stateEntryByStateKeySorter []types.StateEntry + +func (s stateEntryByStateKeySorter) Len() int { return len(s) } +func (s stateEntryByStateKeySorter) Less(i, j int) bool { + return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) +} +func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stateEntrySorter []types.StateEntry + +func (s stateEntrySorter) Len() int { return len(s) } +func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type int64Sorter []int64 + +func (s int64Sorter) Len() int { return len(s) } +func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } +func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Remove duplicate items from a sorted list. +// Takes the same interface as sort.Sort +// Returns the length of the date without duplicates +// Uses the last occurance of a duplicate. +// O(n). +func unique(data sort.Interface) int { + if data.Len() == 0 { + return 0 + } + length := data.Len() + j := 0 + for i := 1; i < length; i++ { + if data.Less(i-1, i) { + data.Swap(i-1, j) + j++ + } + } + data.Swap(length-1, j) + return j + 1 +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go index b373f8309..c03793ecb 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go @@ -19,8 +19,15 @@ type statements struct { selectRoomNIDStmt *sql.Stmt insertEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt insertEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt *sql.Stmt + insertStateStmt *sql.Stmt + bulkSelectStateDataNIDsStmt *sql.Stmt + insertStateDataStmt *sql.Stmt + selectNextStateDataNIDStmt *sql.Stmt + bulkSelectStateDataEntriesStmt *sql.Stmt } func (s *statements) prepare(db *sql.DB) error { @@ -333,6 +340,13 @@ CREATE TABLE IF NOT EXISTS events ( -- Local numeric ID for the state_key of the event -- This is 0 if the event is not a state event. event_state_key_nid BIGINT NOT NULL, + -- Local numeric ID for the state at the event. + -- This is 0 if we don't know the state at the event. + -- If the state is not 0 this this event is part of the contiguous + -- part of the event graph + -- Since many different events can have the same state we store the + -- state into a separate state table and refer to it by numeric ID. + state_nid bigint NOT NULL DEFAULT 0 -- The textual event id. -- Used to lookup the numeric ID when processing requests. -- Needed for state resolution. @@ -351,7 +365,7 @@ const insertEventSQL = "" + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT ON CONSTRAINT event_id_unique" + " DO UPDATE SET event_id = $1" + - " RETURNING event_nid" + " RETURNING event_nid, state_nid" // Bulk lookup of events by string ID. // Sort by the numeric IDs for event type and state key. @@ -361,6 +375,14 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id = ANY($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +const bulkSelectStateAtEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_nid FROM events" + + " WHERE event_id = ANY($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + +const updateEventStateSQL = "" + + "UPDATE events SET state_nid = $2 WHERE event_nid = $1" + func (s *statements) prepareEvents(db *sql.DB) (err error) { _, err = db.Exec(eventsSchema) if err != nil { @@ -372,6 +394,12 @@ func (s *statements) prepareEvents(db *sql.DB) (err error) { if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil { return } + if s.bulkSelectStateAtEventByIDStmt, err = db.Prepare(bulkSelectStateAtEventByIDSQL); err != nil { + return + } + if s.updateEventStateStmt, err = db.Prepare(updateEventStateSQL); err != nil { + return + } return } @@ -380,11 +408,11 @@ func (s *statements) insertEvent( eventID string, referenceSHA256 []byte, authEventNIDs []int64, -) (eventNID int64, err error) { +) (eventNID, stateNID int64, err error) { err = s.insertEventStmt.QueryRow( roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256, pq.Int64Array(authEventNIDs), - ).Scan(&eventNID) + ).Scan(&eventNID, &stateNID) return } @@ -421,6 +449,39 @@ func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateE return results, err } +func (s *statements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.StateAtEvent, error) { + rows, err := s.bulkSelectStateAtEventByIDStmt.Query(pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() + results := make([]types.StateAtEvent, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventNID, + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.BeforeStateNID, + ); err != nil { + return nil, err + } + if result.BeforeStateNID == 0 { + return nil, fmt.Errorf("storage: missing state for event NID %d", result.EventNID) + } + } + if i != len(eventIDs) { + return nil, fmt.Errorf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)) + } + return results, err +} + +func (s *statements) updateEventState(eventNID, stateNID int64) error { + _, err := s.updateEventStateStmt.Exec(eventNID, stateNID) + return err +} + func (s *statements) prepareEventJSON(db *sql.DB) (err error) { _, err = db.Exec(eventJSONSchema) if err != nil { @@ -494,3 +555,182 @@ func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, er } return results[:i], nil } + +const stateSchema = ` +-- The state of a room before an event. +-- Stored as a list of state_data entries stored in a separate table. +-- The actual state is constructed by combining all the state_data entries +-- referenced by state_data_nids together. If the same state key tuple appears +-- multiple times then the entry from the later state_data clobbers the earlier +-- entries. +-- This encoding format allows us to implement a delta encoding which is useful +-- because room state tends to accumulate small changes over time. Although if +-- the list of deltas becomes too long it becomes more efficient to encode +-- the full state under single state_data_nid. +CREATE SEQUENCE IF NOT EXISTS state_nid_seq; +CREATE TABLE IF NOT EXISTS state ( + -- Local numeric ID for the state. + state_nid bigint PRIMARY KEY DEFAULT nextval('state_nid_seq'), + -- Local numeric ID of the room this state is for. + -- Unused in normal operation, but useful for background work or ad-hoc debugging. + room_nid bigint NOT NULL, + -- List of state_data_nids, stored sorted by state_data_nid. + state_data_nids bigint[] NOT NULL +); +` + +const insertStateSQL = "" + + "INSERT INTO state (room_nid, state_data_nids)" + + " VALUES ($1, $2)" + + " RETURNING state_nid" + +const bulkSelectStateDataNIDsSQL = "" + + "SELECT state_nid, state_data_nids FROM state" + + " WHERE state_nid = ANY($1) ORDER BY state_nid" + +func (s *statements) prepareState(db *sql.DB) (err error) { + _, err = db.Exec(stateSchema) + if err != nil { + return + } + if s.insertStateStmt, err = db.Prepare(insertStateSQL); err != nil { + return + } + if s.bulkSelectStateDataNIDsStmt, err = db.Prepare(bulkSelectStateDataNIDsSQL); err != nil { + return + } + return +} + +func (s *statements) insertState(roomNID int64, stateDataNIDs []int64) (stateNID int64, err error) { + err = s.insertStateStmt.QueryRow(roomNID, pq.Int64Array(stateDataNIDs)).Scan(&stateNID) + return +} + +func (s *statements) bulkSelectStateDataNIDs(stateNIDs []int64) ([]types.StateDataNIDList, error) { + rows, err := s.bulkSelectStateDataNIDsStmt.Query(pq.Int64Array(stateNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() + results := make([]types.StateDataNIDList, len(stateNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var stateDataNids pq.Int64Array + if err := rows.Scan(&result.StateNID, &stateDataNids); err != nil { + return nil, err + } + result.StateDataNIDs = stateDataNids + } + if i != len(stateNIDs) { + return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) + } + return results, nil +} + +const stateDataSchema = ` +-- The state data map. +-- Designed to give enough information to run the state resolution algorithm +-- without hitting the database in the common case. +-- TODO: Is it worth replacing the unique btree index with a covering index so +-- that postgres could lookup the state using an index-only scan? +-- The type and state_key are included in the index to make it easier to +-- lookup a specific (type, state_key) pair for an event. It also makes it easy +-- to read the state for a given state_data_nid ordered by (type, state_key) +-- which in turn makes it easier to merge state data blocks. +CREATE SEQUENCE IF NOT EXISTS state_data_nid_seq; +CREATE TABLE IF NOT EXISTS state_data ( + -- Local numeric ID for this state data. + state_data_nid bigint NOT NULL, + event_type_nid bigint NOT NULL, + event_state_key_nid bigint NOT NULL, + event_nid bigint NOT NULL, + UNIQUE (state_data_nid, event_type_nid, event_state_key_nid) +); +` + +const insertStateDataSQL = "" + + "INSERT INTO state_data (state_data_nid, event_type_nid, event_state_key_nid, event_nid)" + + " VALUES ($1, $2, $3, $4)" + +const selectNextStateDataNIDSQL = "" + + "SELECT nextval('state_data_nid_seq')" + +const bulkSelectStateDataEntriesSQL = "" + + "SELECT state_data_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM state_data WHERE state_data_nid = ANY($1)" + + " ORDER BY state_data_nid, event_type_nid, event_state_key_nid" + +func (s *statements) prepareStateData(db *sql.DB) (err error) { + _, err = db.Exec(stateDataSchema) + if err != nil { + return + } + if s.insertStateDataStmt, err = db.Prepare(insertStateDataSQL); err != nil { + return + } + if s.selectNextStateDataNIDStmt, err = db.Prepare(selectNextStateDataNIDSQL); err != nil { + return + } + + if s.bulkSelectStateDataEntriesStmt, err = db.Prepare(bulkSelectStateDataEntriesSQL); err != nil { + return + } + return +} + +func (s *statements) bulkInsertStateData(stateDataNID int64, entries []types.StateEntry) error { + for _, entry := range entries { + _, err := s.insertStateDataStmt.Exec( + stateDataNID, + entry.EventTypeNID, + entry.EventStateKeyNID, + entry.EventNID, + ) + if err != nil { + return err + } + } + return nil +} + +func (s *statements) selectNextStateDataNID() (stateDataNID int64, err error) { + err = s.selectNextStateDataNIDStmt.QueryRow().Scan(&stateDataNID) + return +} + +func (s *statements) bulkSelectStateDataEntries(stateDataNIDs []int64) ([]types.StateEntryList, error) { + rows, err := s.bulkSelectStateDataEntriesStmt.Query(pq.Int64Array(stateDataNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() + + results := make([]types.StateEntryList, len(stateDataNIDs)) + // current is a pointer to the StateEntryList to append the state entries to. + var current *types.StateEntryList + i := 0 + for rows.Next() { + var stateDataNID int64 + var entry types.StateEntry + if err := rows.Scan( + &stateDataNID, + &entry.EventTypeNID, &entry.EventStateKeyNID, &entry.EventNID, + ); err != nil { + return nil, err + } + if current == nil || stateDataNID != current.StateDataNID { + // The state entry row is for a different state data block to the current one. + // So we start appending to the next entry in the list. + current = &results[i] + current.StateDataNID = stateDataNID + i++ + } + current.StateEntries = append(current.StateEntries, entry) + } + if i != len(stateDataNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateDataNIDs)) + } + return results, nil +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index d42c1f191..7d46a9698 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -38,21 +38,22 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6 } // StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error { +func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) (int64, types.StateAtEvent, error) { var ( roomNID int64 eventTypeNID int64 eventStateKeyNID int64 eventNID int64 + stateNID int64 err error ) if roomNID, err = d.assignRoomNID(event.RoomID()); err != nil { - return err + return 0, types.StateAtEvent{}, err } if eventTypeNID, err = d.assignEventTypeNID(event.Type()); err != nil { - return err + return 0, types.StateAtEvent{}, err } eventStateKey := event.StateKey() @@ -60,11 +61,11 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int // Otherwise set the numeric ID for the state_key to 0. if eventStateKey != nil { if eventStateKeyNID, err = d.assignStateKeyNID(*eventStateKey); err != nil { - return err + return 0, types.StateAtEvent{}, err } } - if eventNID, err = d.statements.insertEvent( + if eventNID, stateNID, err = d.statements.insertEvent( roomNID, eventTypeNID, eventStateKeyNID, @@ -72,10 +73,23 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int event.EventReference().EventSHA256, authEventNIDs, ); err != nil { - return err + return 0, types.StateAtEvent{}, err } - return d.statements.insertEventJSON(eventNID, event.JSON()) + if err = d.statements.insertEventJSON(eventNID, event.JSON()); err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil } func (d *Database) assignRoomNID(roomID string) (int64, error) { @@ -145,3 +159,39 @@ func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) { } return results, nil } + +// AddState implements input.EventDatabase +func (d *Database) AddState(roomNID int64, stateDataNIDs []int64, state []types.StateEntry) (stateNID int64, err error) { + if len(state) > 0 { + var stateDataNID int64 + if stateDataNID, err = d.statements.selectNextStateDataNID(); err != nil { + return + } + if err = d.statements.bulkInsertStateData(stateDataNID, state); err != nil { + return + } + stateDataNIDs = append(stateDataNIDs[:len(stateDataNIDs):len(stateDataNIDs)], stateDataNID) + } + + return d.statements.insertState(roomNID, stateDataNIDs) +} + +// SetState implements input.EventDatabase +func (d *Database) SetState(eventNID, stateNID int64) error { + return d.statements.updateEventState(eventNID, stateNID) +} + +// StateAtEventIDs implements input.EventDatabase +func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) { + return d.statements.bulkSelectStateAtEventByID(eventIDs) +} + +// StateDataNIDs implements input.EventDatabase +func (d *Database) StateDataNIDs(stateNIDs []int64) ([]types.StateDataNIDList, error) { + return d.statements.bulkSelectStateDataNIDs(stateNIDs) +} + +// StateEntries implements input.EventDatabase +func (d *Database) StateEntries(stateDataNIDs []int64) ([]types.StateEntryList, error) { + return d.statements.bulkSelectStateDataEntries(stateDataNIDs) +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/types/types.go b/src/github.com/matrix-org/dendrite/roomserver/types/types.go index 0c43baf89..096547a6f 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/types/types.go +++ b/src/github.com/matrix-org/dendrite/roomserver/types/types.go @@ -47,6 +47,14 @@ func (a StateEntry) LessThan(b StateEntry) bool { return a.EventNID < b.EventNID } +// StateAtEvent is the state before and after a matrix event. +type StateAtEvent struct { + // The state before the event. + BeforeStateNID int64 + // The state entry for the event itself, allows us to calculate the state after the event. + StateEntry +} + // An Event is a gomatrixserverlib.Event with the numeric event ID attached. // It is when performing bulk event lookup in the database. type Event struct { @@ -75,3 +83,15 @@ const ( // EmptyStateKeyNID is the numeric ID for the empty state key. EmptyStateKeyNID = 1 ) + +// StateDataNIDList is used to return the result of bulk StateDataNID lookups from the database. +type StateDataNIDList struct { + StateNID int64 + StateDataNIDs []int64 +} + +// StateEntryList is used to return the result of bulk state entry lookups from the database. +type StateEntryList struct { + StateDataNID int64 + StateEntries []StateEntry +}