From fc4eb85379e4251ee77393cebf1a611004df6888 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 9 Feb 2017 16:48:14 +0000 Subject: [PATCH] Check that events pass authentication checks. (#4) * Check that events pass authentication checks. Record the list of events that the event passes authentication checks against. --- .../dendrite/roomserver/api/input.go | 8 +- .../dendrite/roomserver/input/events.go | 216 +++++++++++++++++- .../dendrite/roomserver/input/events_test.go | 112 +++++++++ .../dendrite/roomserver/storage/sql.go | 144 ++++++++++-- .../dendrite/roomserver/storage/storage.go | 32 ++- .../dendrite/roomserver/types/types.go | 67 ++++++ 6 files changed, 560 insertions(+), 19 deletions(-) create mode 100644 src/github.com/matrix-org/dendrite/roomserver/input/events_test.go diff --git a/src/github.com/matrix-org/dendrite/roomserver/api/input.go b/src/github.com/matrix-org/dendrite/roomserver/api/input.go index bfa4a9d58..fd177b19d 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/api/input.go +++ b/src/github.com/matrix-org/dendrite/roomserver/api/input.go @@ -30,7 +30,13 @@ type InputRoomEvent struct { Kind int // The event JSON for the event to add. Event []byte + // List of state event IDs that authenticate this event. + // These are likely derived from the "auth_events" JSON key of the event. + // But can be different because the "auth_events" key can be incomplete or wrong. + // For example many matrix events forget to reference the m.room.create event even though it is needed for auth. + // (since synapse allows this to happen we have to allow it as well.) + AuthEventIDs []string // Optional list of state event IDs forming the state before this event. // These state events must have already been persisted. - State []string + StateEventIDs []string } diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index dfeff9536..8943578db 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -2,12 +2,26 @@ package input import ( "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "sort" ) // A RoomEventDatabase has the storage APIs needed to store a room event. type RoomEventDatabase interface { - StoreEvent(event gomatrixserverlib.Event) error + // Stores a matrix room event in the database + StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error + // Lookup the state entries for a list of string event IDs + // Returns a sorted list of state entries. + // Returns a error if the there is an error talking to the database + // or if the event IDs aren't in the database. + StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) + // Lookup the numeric IDs for a list of string event state keys. + // Returns a map from string state key to numeric ID for the state key. + EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) + // Lookup the Events for a list of numeric event IDs. + // Returns a sorted list of events. + Events(eventNIDs []int64) ([]types.Event, error) } func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { @@ -17,12 +31,16 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { return err } - if err := db.StoreEvent(event); err != nil { + // Check that the event passes authentication checks and work out the numeric IDs for the auth events. + authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs) + if err != nil { return err } - // TODO: - // * Check that the event passes authentication checks. + // Store the event + if err := db.StoreEvent(event, authEventNIDs); err != nil { + return err + } if input.Kind == api.KindOutlier { // For outliers we can stop after we've stored the event itself as it @@ -44,3 +62,193 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { // - The changes to the current state of the room. panic("Not implemented") } + +// checkAuthEvents checks that the event passes authentication checks +// Returns the numeric IDs for the auth events. +func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]int64, error) { + // Grab the numeric IDs for the supplied auth state events from the database. + authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs) + if err != nil { + return nil, err + } + // TODO: check for duplicate state keys here. + + // Work out which of the state events we actually need. + stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event}) + + // Load the actual auth events from the database. + authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries) + if err != nil { + return nil, err + } + + // Check if the event is allowed. + if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + return nil, err + } + + // Return the numeric IDs for the auth events. + result := make([]int64, len(authStateEntries)) + for i := range authStateEntries { + result[i] = authStateEntries[i].EventNID + } + return result, nil +} + +type authEvents struct { + stateKeyNIDMap map[string]int64 + state stateEntryMap + events eventMap +} + +// Create implements gomatrixserverlib.AuthEvents +func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) { + return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil +} + +// PowerLevels implements gomatrixserverlib.AuthEvents +func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) { + return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil +} + +// JoinRules implements gomatrixserverlib.AuthEvents +func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) { + return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil +} + +// Memmber implements gomatrixserverlib.AuthEvents +func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) { + return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil +} + +// ThirdPartyInvite implements gomatrixserverlib.AuthEvents +func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) { + return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil +} + +func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserverlib.Event { + eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID}) + if !ok { + return nil + } + event, ok := ae.events.lookup(eventNID) + if !ok { + return nil + } + return &event.Event +} + +func (ae *authEvents) lookupEvent(typeNID int64, stateKey string) *gomatrixserverlib.Event { + stateKeyNID, ok := ae.stateKeyNIDMap[stateKey] + if !ok { + return nil + } + eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, stateKeyNID}) + if !ok { + return nil + } + event, ok := ae.events.lookup(eventNID) + if !ok { + return nil + } + return &event.Event +} + +// loadAuthEvents loads the events needed for authentication from the supplied room state. +func loadAuthEvents( + db RoomEventDatabase, + needed gomatrixserverlib.StateNeeded, + state []types.StateEntry, +) (result authEvents, err error) { + // Lookup the numeric IDs for the state keys needed for auth. + var neededStateKeys []string + neededStateKeys = append(neededStateKeys, needed.Member...) + neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) + if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(neededStateKeys); err != nil { + return + } + + // Load the events we need. + result.state = state + var eventNIDs []int64 + keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed) + for _, keyTuple := range keyTuplesNeeded { + eventNID, ok := result.state.lookup(keyTuple) + if ok { + eventNIDs = append(eventNIDs, eventNID) + } + } + if result.events, err = db.Events(eventNIDs); err != nil { + return + } + return +} + +// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. +func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { + var keyTuples []types.StateKeyTuple + if stateNeeded.Create { + keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID}) + } + if stateNeeded.PowerLevels { + keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID}) + } + if stateNeeded.JoinRules { + keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID}) + } + for _, member := range stateNeeded.Member { + stateKeyNID, ok := stateKeyNIDMap[member] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID}) + } + } + for _, token := range stateNeeded.ThirdPartyInvite { + stateKeyNID, ok := stateKeyNIDMap[token] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID}) + } + } + return keyTuples +} + +// Map from event type, state key tuple to numeric event ID. +// Implemented using binary search on a sorted array. +type stateEntryMap []types.StateEntry + +// lookup an entry in the event map. +func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size and are controlled by us. + list := []types.StateEntry(m) + i := sort.Search(len(list), func(i int) bool { + return !list[i].StateKeyTuple.LessThan(stateKey) + }) + if i < len(list) && list[i].StateKeyTuple == stateKey { + ok = true + eventNID = list[i].EventNID + } + return +} + +// Map from numeric event ID to event. +// Implemented using binary search on a sorted array. +type eventMap []types.Event + +// lookup an entry in the event map. +func (m eventMap) lookup(eventNID int64) (event *types.Event, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size are controlled by us. + list := []types.Event(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].EventNID >= eventNID + }) + if i < len(list) && list[i].EventNID == eventNID { + ok = true + event = &list[i] + } + return +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events_test.go b/src/github.com/matrix-org/dendrite/roomserver/input/events_test.go new file mode 100644 index 000000000..aba1de092 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events_test.go @@ -0,0 +1,112 @@ +package input + +import ( + "github.com/matrix-org/dendrite/roomserver/types" + "testing" +) + +func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) { + var list []types.StateEntry + for i := int64(0); i < entries; i++ { + list = append(list, types.StateEntry{types.StateKeyTuple{i, i}, i}) + } + + for i := 0; i < b.N; i++ { + entryMap := stateEntryMap(list) + for j := int64(0); j < lookups; j++ { + entryMap.lookup(types.StateKeyTuple{j, j}) + } + } +} + +func BenchmarkStateEntryMap100Lookup10(b *testing.B) { + benchmarkStateEntryMapLookup(100, 10, b) +} + +func BenchmarkStateEntryMap1000Lookup100(b *testing.B) { + benchmarkStateEntryMapLookup(1000, 100, b) +} + +func BenchmarkStateEntryMap100Lookup100(b *testing.B) { + benchmarkStateEntryMapLookup(100, 100, b) +} + +func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) { + benchmarkStateEntryMapLookup(1000, 10000, b) +} + +func TestStateEntryMap(t *testing.T) { + entryMap := stateEntryMap([]types.StateEntry{ + {types.StateKeyTuple{1, 1}, 1}, + {types.StateKeyTuple{1, 3}, 2}, + {types.StateKeyTuple{2, 1}, 3}, + }) + + testCases := []struct { + inputTypeNID int64 + inputStateKey int64 + wantOK bool + wantEventNID int64 + }{ + // Check that tuples that in the array are in the map. + {1, 1, true, 1}, + {1, 3, true, 2}, + {2, 1, true, 3}, + // Check that tuples that aren't in the array aren't in the map. + {0, 0, false, 0}, + {1, 2, false, 0}, + {3, 1, false, 0}, + } + + for _, testCase := range testCases { + keyTuple := types.StateKeyTuple{testCase.inputTypeNID, testCase.inputStateKey} + gotEventNID, gotOK := entryMap.lookup(keyTuple) + if testCase.wantOK != gotOK { + t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK) + } + if testCase.wantEventNID != gotEventNID { + t.Fatalf("stateEntryMap lookup(%v): want eventNID to be %v, got %v", keyTuple, testCase.wantEventNID, gotEventNID) + } + } +} + +func TestEventMap(t *testing.T) { + events := eventMap([]types.Event{ + {EventNID: 1}, + {EventNID: 2}, + {EventNID: 3}, + {EventNID: 5}, + {EventNID: 8}, + }) + + testCases := []struct { + inputEventNID int64 + wantOK bool + wantEvent *types.Event + }{ + // Check that the IDs that are in the array are in the map. + {1, true, &events[0]}, + {2, true, &events[1]}, + {3, true, &events[2]}, + {5, true, &events[3]}, + {8, true, &events[4]}, + // Check that tuples that aren't in the array aren't in the map. + {0, false, nil}, + {4, false, nil}, + {6, false, nil}, + {7, false, nil}, + {9, false, nil}, + } + + for _, testCase := range testCases { + gotEvent, gotOK := events.lookup(testCase.inputEventNID) + if testCase.wantOK != gotOK { + t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK) + } + + if testCase.wantEvent != gotEvent { + t.Fatalf("eventMap lookup(%v): want event to be %v, got %v", testCase.inputEventNID, testCase.wantEvent, gotEvent) + } + } + +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go index 0a3d6b573..b373f8309 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/sql.go @@ -2,20 +2,25 @@ package storage import ( "database/sql" + "fmt" + "github.com/lib/pq" "github.com/matrix-org/dendrite/roomserver/types" ) type statements struct { - selectPartitionOffsetsStmt *sql.Stmt - upsertPartitionOffsetStmt *sql.Stmt - insertEventTypeNIDStmt *sql.Stmt - selectEventTypeNIDStmt *sql.Stmt - insertEventStateKeyNIDStmt *sql.Stmt - selectEventStateKeyNIDStmt *sql.Stmt - insertRoomNIDStmt *sql.Stmt - selectRoomNIDStmt *sql.Stmt - insertEventStmt *sql.Stmt - insertEventJSONStmt *sql.Stmt + selectPartitionOffsetsStmt *sql.Stmt + upsertPartitionOffsetStmt *sql.Stmt + insertEventTypeNIDStmt *sql.Stmt + selectEventTypeNIDStmt *sql.Stmt + insertEventStateKeyNIDStmt *sql.Stmt + selectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyNIDStmt *sql.Stmt + insertRoomNIDStmt *sql.Stmt + selectRoomNIDStmt *sql.Stmt + insertEventStmt *sql.Stmt + bulkSelectStateEventByIDStmt *sql.Stmt + insertEventJSONStmt *sql.Stmt + bulkSelectEventJSONStmt *sql.Stmt } func (s *statements) prepare(db *sql.DB) error { @@ -196,6 +201,9 @@ func (s *statements) prepareEventStateKeys(db *sql.DB) (err error) { if s.selectEventStateKeyNIDStmt, err = db.Prepare(selectEventStateKeyNIDSQL); err != nil { return } + if s.bulkSelectEventStateKeyNIDStmt, err = db.Prepare(bulkSelectEventStateKeyNIDSQL); err != nil { + return + } return } @@ -230,6 +238,12 @@ const insertEventStateKeyNIDSQL = "" + const selectEventStateKeyNIDSQL = "" + "SELECT event_state_key_nid FROM event_state_keys WHERE event_state_key = $1" +// Bulk lookup from string state key to numeric ID for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeyNIDSQL = "" + + "SELECT event_state_key, event_state_key_nid FROM event_state_keys" + + " WHERE event_state_key = ANY($1)" + func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) { err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) return @@ -240,6 +254,25 @@ func (s *statements) selectEventStateKeyNID(eventStateKey string) (eventStateKey return } +func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]int64, error) { + rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys)) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make(map[string]int64, len(eventStateKeys)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[stateKey] = stateKeyNID + } + return result, nil +} + func (s *statements) prepareRooms(db *sql.DB) (err error) { _, err = db.Exec(roomsSchema) if err != nil { @@ -307,17 +340,27 @@ CREATE TABLE IF NOT EXISTS events ( event_id TEXT NOT NULL CONSTRAINT event_id_unique UNIQUE, -- The sha256 reference hash for the event. -- Needed for setting reference hashes when sending new events. - reference_sha256 BYTEA NOT NULL + reference_sha256 BYTEA NOT NULL, + -- A list of numeric IDs for events that can authenticate this event. + auth_event_nids BIGINT[] NOT NULL, ); ` const insertEventSQL = "" + - "INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256)" + - " VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT ON CONSTRAINT event_id_unique" + " DO UPDATE SET event_id = $1" + " RETURNING event_nid" +// Bulk lookup of events by string ID. +// Sort by the numeric IDs for event type and state key. +// This means we can use binary search to lookup entries by type and state key. +const bulkSelectStateEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM events" + + " WHERE event_id = ANY($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + func (s *statements) prepareEvents(db *sql.DB) (err error) { _, err = db.Exec(eventsSchema) if err != nil { @@ -326,6 +369,9 @@ func (s *statements) prepareEvents(db *sql.DB) (err error) { if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { return } + if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil { + return + } return } @@ -333,13 +379,48 @@ func (s *statements) insertEvent( roomNID, eventTypeNID, eventStateKeyNID int64, eventID string, referenceSHA256 []byte, + authEventNIDs []int64, ) (eventNID int64, err error) { err = s.insertEventStmt.QueryRow( roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256, + pq.Int64Array(authEventNIDs), ).Scan(&eventNID) return } +func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) { + rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventNID, + &result.EventTypeNID, + &result.EventStateKeyNID, + ); err != nil { + return nil, err + } + } + if i != len(eventIDs) { + // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. + // We don't know which ones were missing because we don't return the string IDs in the query. + // However it should be possible debug this by replaying queries or entries from the input kafka logs. + // If this turns out to be impossible and we do need the debug information here, it would be better + // to do it as a separate query rather than slowing down/complicating the common case. + return nil, fmt.Errorf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)) + } + return results, err +} + func (s *statements) prepareEventJSON(db *sql.DB) (err error) { _, err = db.Exec(eventJSONSchema) if err != nil { @@ -348,6 +429,9 @@ func (s *statements) prepareEventJSON(db *sql.DB) (err error) { if s.insertEventJSONStmt, err = db.Prepare(insertEventJSONSQL); err != nil { return } + if s.bulkSelectEventJSONStmt, err = db.Prepare(bulkSelectEventJSONSQL); err != nil { + return + } return } @@ -372,7 +456,41 @@ const insertEventJSONSQL = "" + "INSERT INTO event_json (event_nid, event_json) VALUES ($1, $2)" + " ON CONFLICT DO NOTHING" +// Bulk event JSON lookup by numeric event ID. +// Sort by the numeric event ID. +// This means that we can use binary search to lookup by numeric event ID. +const bulkSelectEventJSONSQL = "" + + "SELECT event_nid, event_json FROM event_json" + + " WHERE event_nid = ANY($1)" + + " ORDER BY event_nid ASC" + func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error { _, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON) return err } + +type eventJSONPair struct { + EventNID int64 + EventJSON []byte +} + +func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, error) { + rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(eventNIDs)) + if err != nil { + return nil, err + } + defer rows.Close() + + // We know that we will only get as many results as event NIDs + // because of the unique constraint on event NIDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than NIDs so we adjust the length of the slice before returning it. + results := make([]eventJSONPair, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + if err := rows.Scan(&results[i].EventNID, &results[i].EventJSON); err != nil { + return nil, err + } + } + return results[:i], nil +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index 1d6d7a327..d42c1f191 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -38,7 +38,7 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6 } // StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent(event gomatrixserverlib.Event) error { +func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error { var ( roomNID int64 eventTypeNID int64 @@ -70,6 +70,7 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event) error { eventStateKeyNID, event.EventID(), event.EventReference().EventSHA256, + authEventNIDs, ); err != nil { return err } @@ -115,3 +116,32 @@ func (d *Database) assignStateKeyNID(eventStateKey string) (int64, error) { } return eventStateKeyNID, nil } + +// StateEntriesForEventIDs implements input.EventDatabase +func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) { + return d.statements.bulkSelectStateEventByID(eventIDs) +} + +// EventStateKeyNIDs implements input.EventDatabase +func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) { + return d.statements.bulkSelectEventStateKeyNID(eventStateKeys) +} + +// Events implements input.EventDatabase +func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) { + eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs) + if err != nil { + return nil, err + } + results := make([]types.Event, len(eventJSONs)) + for i, eventJSON := range eventJSONs { + result := &results[i] + result.EventNID = eventJSON.EventNID + // TODO: Use NewEventFromTrustedJSON for efficiency + result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON) + if err != nil { + return nil, err + } + } + return results, nil +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/types/types.go b/src/github.com/matrix-org/dendrite/roomserver/types/types.go index 3c6dd5bb3..0c43baf89 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/types/types.go +++ b/src/github.com/matrix-org/dendrite/roomserver/types/types.go @@ -1,6 +1,10 @@ // Package types provides the types that are used internally within the roomserver. package types +import ( + "github.com/matrix-org/gomatrixserverlib" +) + // A PartitionOffset is the offset into a partition of the input log. type PartitionOffset struct { // The ID of the partition. @@ -8,3 +12,66 @@ type PartitionOffset struct { // The offset into the partition. Offset int64 } + +// A StateKeyTuple is a pair of a numeric event type and a numeric state key. +// It is used to lookup state entries. +type StateKeyTuple struct { + // The numeric ID for the event type. + EventTypeNID int64 + // The numeric ID for the state key. + EventStateKeyNID int64 +} + +// LessThan returns true if this state key is less than the other state key. +// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. +func (a StateKeyTuple) LessThan(b StateKeyTuple) bool { + if a.EventTypeNID != b.EventTypeNID { + return a.EventTypeNID < b.EventTypeNID + } + return a.EventStateKeyNID < b.EventStateKeyNID +} + +// A StateEntry is an entry in the room state of a matrix room. +type StateEntry struct { + StateKeyTuple + // The numeric ID for the event. + EventNID int64 +} + +// LessThan returns true if this state entry is less than the other state entry. +// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. +func (a StateEntry) LessThan(b StateEntry) bool { + if a.StateKeyTuple != b.StateKeyTuple { + return a.StateKeyTuple.LessThan(b.StateKeyTuple) + } + return a.EventNID < b.EventNID +} + +// An Event is a gomatrixserverlib.Event with the numeric event ID attached. +// It is when performing bulk event lookup in the database. +type Event struct { + EventNID int64 + gomatrixserverlib.Event +} + +const ( + // MRoomCreateNID is the numeric ID for the "m.room.create" event type. + MRoomCreateNID = 1 + // MRoomPowerLevelsNID is the numeric ID for the "m.room.power_levels" event type. + MRoomPowerLevelsNID = 2 + // MRoomJoinRulesNID is the numeric ID for the "m.room.join_rules" event type. + MRoomJoinRulesNID = 3 + // MRoomThirdPartyInviteNID is the numeric ID for the "m.room.third_party_invite" event type. + MRoomThirdPartyInviteNID = 4 + // MRoomMemberNID is the numeric ID for the "m.room.member" event type. + MRoomMemberNID = 5 + // MRoomRedactionNID is the numeric ID for the "m.room.redaction" event type. + MRoomRedactionNID = 6 + // MRoomHistoryVisibilityNID is the numeric ID for the "m.room.history_visibility" event type. + MRoomHistoryVisibilityNID = 7 +) + +const ( + // EmptyStateKeyNID is the numeric ID for the empty state key. + EmptyStateKeyNID = 1 +)