diff --git a/src/github.com/matrix-org/dendrite/roomserver/input/events.go b/src/github.com/matrix-org/dendrite/roomserver/input/events.go index ccd115d51..5d6cc057d 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/input/events.go +++ b/src/github.com/matrix-org/dendrite/roomserver/input/events.go @@ -62,6 +62,7 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { // checkAuthEvents checks that the event passes authentication checks // Returns the numeric IDs for the auth events. func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]int64, error) { + // Grab the numeric IDs for the supplied auth state events from the database. authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs) if err != nil { return nil, err @@ -70,17 +71,21 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv return nil, fmt.Errorf("input: Some of the auth event IDs were missing from the database") } + // Work out which of the state events we actaully need. stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event}) + // Load the actual auth events from the database. authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries) if err != nil { return nil, err } + // Check if the event is allowed. if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { return nil, err } + // Return the numeric IDs for the auth events. result := make([]int64, len(authStateEntries)) for i := range authStateEntries { result[i] = authStateEntries[i].EventNID @@ -142,6 +147,7 @@ func (ae *authEvents) lookupEvent(typeNID int64, stateKey string) *gomatrixserve return &event.Event } +// loadAuthEvents loads the events needed for authentication from the supplied room state. func loadAuthEvents( db RoomEventDatabase, needed gomatrixserverlib.StateNeeded, @@ -174,6 +180,7 @@ func loadAuthEvents( return } +// stateKeysNeeded works out which numeric state keys we need to authenticate some events. func stateKeysNeeded(stateNIDMap idMap, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKey { var keys []types.StateKey if stateNeeded.Create { @@ -210,18 +217,31 @@ func newIDMap(ids []types.IDPair) idMap { return idMap(result) } +// lookup an entry in the id map. func (m idMap) lookup(id string) (nid int64, ok bool) { + // Use a hash map here. + // We could use binary search here like we do for the maps below as it + // would be faster for small lists. + // However the benefits of binary search aren't as strong here and it's + // possible that we could encounter sets of pathological strings since + // the state keys are ultimately controlled by user input. nid, ok = map[string]int64(m)[id] return } type stateEntryMap []types.StateEntry +// newStateEntryMap creates a map from a sorted list of state entries. func newStateEntryMap(stateEntries []types.StateEntry) stateEntryMap { return stateEntryMap(stateEntries) } +// lookup an entry in the event map. func (m stateEntryMap) lookup(stateKey types.StateKey) (eventNID int64, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size and are controlled by us. list := []types.StateEntry(m) i := sort.Search(len(list), func(i int) bool { return !list[i].StateKey.LessThan(stateKey) @@ -235,7 +255,17 @@ func (m stateEntryMap) lookup(stateKey types.StateKey) (eventNID int64, ok bool) type eventMap []types.Event +// newEventMap creates a map from a sorted list of events. +func newEventMap(events []types.Event) eventMap { + return eventMap(events) +} + +// lookup an entry in the event map. func (m eventMap) lookup(eventNID int64) (event *types.Event, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size are controlled by us. list := []types.Event(m) i := sort.Search(len(list), func(i int) bool { return list[i].EventNID >= eventNID