diff --git a/go.mod b/go.mod index 580dc3568..c9a7e09b0 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220530084946-3a4b148706bc + github.com/matrix-org/gomatrixserverlib v0.0.0-20220531163017-35e1cabf12ee github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.13 diff --git a/go.sum b/go.sum index 75c22a7a7..4460b3905 100644 --- a/go.sum +++ b/go.sum @@ -418,8 +418,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220530084946-3a4b148706bc h1:58tT3VznINdRWimb3yYb8QWmTAHX9AAuyOwzdmrp9q4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220530084946-3a4b148706bc/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220531163017-35e1cabf12ee h1:56sxEWrwB3eOmwjP2S4JsrQf29uBUaf+8WrbQJmjaGE= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220531163017-35e1cabf12ee/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48 h1:W0sjjC6yjskHX4mb0nk3p0fXAlbU5bAFUFeEtlrPASE= github.com/matrix-org/pinecone v0.0.0-20220408153826-2999ea29ed48/go.mod h1:ulJzsVOTssIVp1j/m5eI//4VpAGDkMt5NrRuAVX7wpc= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index a1b094871..59b3fcb12 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -206,7 +206,7 @@ func (u *latestEventsUpdater) latestState() error { // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the // hard work. - if u.event.StateKey() == nil { + if !u.stateAtEvent.IsStateEvent() { stateChanged := false oldStateNIDs := make([]types.StateSnapshotNID, 0, len(u.oldLatest)) newStateNIDs := make([]types.StateSnapshotNID, 0, len(u.latest)) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 187b996cd..95abdcb36 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -39,6 +39,7 @@ type StateResolutionStorage interface { StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) + EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) } type StateResolution struct { @@ -659,15 +660,13 @@ func (v *StateResolution) calculateStateAfterManyEvents( } // Collect all the entries with the same type and key together. - // We don't care about the order here because the conflict resolution - // algorithm doesn't depend on the order of the prev events. - // Remove duplicate entires. + // This is done so findDuplicateStateKeys can work in groups. + // We remove duplicates (same type, state key and event NID) too. combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] // Find the conflicts - conflicts := findDuplicateStateKeys(combined) - - if len(conflicts) > 0 { + if conflicts := findDuplicateStateKeys(combined); len(conflicts) > 0 { + conflictMap := stateEntryMap(conflicts) conflictLength = len(conflicts) // 5) There are conflicting state events, for each conflict workout @@ -676,7 +675,7 @@ func (v *StateResolution) calculateStateAfterManyEvents( // Work out which entries aren't conflicted. var notConflicted []types.StateEntry for _, entry := range combined { - if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { + if _, ok := conflictMap.lookup(entry.StateKeyTuple); !ok { notConflicted = append(notConflicted, entry) } } @@ -689,7 +688,7 @@ func (v *StateResolution) calculateStateAfterManyEvents( return } algorithm = "full_state_with_conflicts" - state = resolved[:util.SortAndUnique(stateEntrySorter(resolved))] + state = resolved } else { algorithm = "full_state_no_conflicts" // 6) There weren't any conflicts @@ -818,39 +817,19 @@ func (v *StateResolution) resolveConflictsV2( authDifference := make([]*gomatrixserverlib.Event, 0, estimate) // For each conflicted event, let's try and get the needed auth events. - neededStateKeys := make([]string, 16) - authEntries := make([]types.StateEntry, 16) for _, conflictedEvent := range conflictedEvents { // Work out which auth events we need to load. key := conflictedEvent.EventID() - needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{conflictedEvent}) - - // Find the numeric IDs for the necessary state keys. - neededStateKeys = neededStateKeys[:0] - neededStateKeys = append(neededStateKeys, needed.Member...) - neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) - stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys) - if err != nil { - return nil, err - } - - // Load the necessary auth events. - tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) - authEntries = authEntries[:0] - for _, tuple := range tuplesNeeded { - if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { - authEntries = append(authEntries, types.StateEntry{ - StateKeyTuple: tuple, - EventNID: eventNID, - }) - } - } // Store the newly found auth events in the auth set for this event. - authSets[key], _, err = v.loadStateEvents(ctx, authEntries) + var authEventMap map[string]types.StateEntry + authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent) if err != nil { return nil, err } + for k, v := range authEventMap { + eventIDMap[k] = v + } // Only add auth events into the authEvents slice once, otherwise the // check for the auth difference can become expensive and produce @@ -909,7 +888,7 @@ func (v *StateResolution) resolveConflictsV2( for _, resolvedEvent := range resolvedEvents { entry, ok := eventIDMap[resolvedEvent.EventID()] if !ok { - panic(fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID())) + return nil, fmt.Errorf("missing state entry for event ID %q", resolvedEvent.EventID()) } notConflicted = append(notConflicted, entry) } @@ -996,6 +975,84 @@ func (v *StateResolution) loadStateEvents( return result, eventIDMap, nil } +// loadAuthEvents loads all of the auth events for a given event recursively, +// along with a map that contains state entries for all of the auth events. +func (v *StateResolution) loadAuthEvents( + ctx context.Context, event *gomatrixserverlib.Event, +) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { + eventMap := map[string]struct{}{} + var lookup []string + var authEvents []types.Event + queue := event.AuthEventIDs() + for i := 0; i < len(queue); i++ { + lookup = lookup[:0] + for _, authEventID := range queue { + if _, ok := eventMap[authEventID]; ok { + continue + } + lookup = append(lookup, authEventID) + } + if len(lookup) == 0 { + break + } + events, err := v.db.EventsFromIDs(ctx, lookup) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) + } + add := map[string]struct{}{} + for _, event := range events { + eventMap[event.EventID()] = struct{}{} + authEvents = append(authEvents, event) + for _, authEventID := range event.AuthEventIDs() { + if _, ok := eventMap[authEventID]; ok { + continue + } + add[authEventID] = struct{}{} + } + for authEventID := range add { + queue = append(queue, authEventID) + } + } + } + authEventTypes := map[string]struct{}{} + authEventStateKeys := map[string]struct{}{} + for _, authEvent := range authEvents { + authEventTypes[authEvent.Type()] = struct{}{} + authEventStateKeys[*authEvent.StateKey()] = struct{}{} + } + lookupAuthEventTypes := make([]string, 0, len(authEventTypes)) + lookupAuthEventStateKeys := make([]string, 0, len(authEventStateKeys)) + for eventType := range authEventTypes { + lookupAuthEventTypes = append(lookupAuthEventTypes, eventType) + } + for eventStateKey := range authEventStateKeys { + lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey) + } + eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err) + } + eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err) + } + stateEntryMap := map[string]types.StateEntry{} + for _, authEvent := range authEvents { + stateEntryMap[authEvent.EventID()] = types.StateEntry{ + EventNID: authEvent.EventNID, + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypes[authEvent.Type()], + EventStateKeyNID: eventStateKeys[*authEvent.StateKey()], + }, + } + } + nakedEvents := make([]*gomatrixserverlib.Event, 0, len(authEvents)) + for _, authEvent := range authEvents { + nakedEvents = append(nakedEvents, authEvent.Event) + } + return nakedEvents, stateEntryMap, nil +} + // findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. // Returns a sorted list of those state entries. func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index d4a2ee3b9..8f4e011bf 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -192,6 +192,10 @@ func (u *RoomUpdater) StateAtEventIDs( return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) } +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) +} + func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) }