diff --git a/roomserver/input/events.go b/roomserver/input/events.go index 034b06c11..ab2bbe1cd 100644 --- a/roomserver/input/events.go +++ b/roomserver/input/events.go @@ -156,11 +156,8 @@ func calculateAndSetState( stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, ) error { - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, db) - if err != nil { - return err - } + var err error + roomState := state.NewStateResolution(db) if input.HasState { // We've been told what the state at the event is so we don't need to calculate it. diff --git a/roomserver/input/latest_events.go b/roomserver/input/latest_events.go index 4d75daae9..525a6f518 100644 --- a/roomserver/input/latest_events.go +++ b/roomserver/input/latest_events.go @@ -178,11 +178,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, u.db) - if err != nil { - return err - } + roomState := state.NewStateResolution(u.db) latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) for i := range u.latest { diff --git a/roomserver/query/query.go b/roomserver/query/query.go index e0f385883..2638919ad 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -106,11 +106,14 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) if err != nil { - return err + response.RoomExists = false + return nil } + + roomState := state.NewStateResolution(r.DB) + response.QueryLatestEventsAndStateRequest = *request roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { @@ -121,11 +124,6 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( } response.RoomExists = true - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) - if err != nil { - return err - } - var currentStateSnapshotNID types.StateSnapshotNID response.LatestEvents, currentStateSnapshotNID, response.Depth, err = r.DB.LatestEventIDs(ctx, roomNID) @@ -159,11 +157,14 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, ) error { - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) + roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) if err != nil { - return err + response.RoomExists = false + return nil } + + roomState := state.NewStateResolution(r.DB) + response.QueryStateAfterEventsRequest = *request roomNID, err := r.DB.RoomNID(ctx, request.RoomID) if err != nil { @@ -174,11 +175,6 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( } response.RoomExists = true - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) - if err != nil { - return err - } - prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) if err != nil { switch err.(type) { @@ -192,7 +188,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( // Look up the currrent state for the requested tuples. stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( - ctx, prevStates, request.StateToFetch, + ctx, roomNID, prevStates, request.StateToFetch, ) if err != nil { return err @@ -358,11 +354,7 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( ctx context.Context, eventNID types.EventNID, joinedOnly bool, ) ([]types.Event, error) { - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) - if err != nil { - return []types.Event{}, err - } + roomState := state.NewStateResolution(r.DB) events := []types.Event{} // Lookup the event NID eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID}) @@ -464,12 +456,7 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent( func (r *RoomserverQueryAPI) checkServerAllowedToSeeEvent( ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, ) (bool, error) { - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) - if err != nil { - return false, err - } - + roomState := state.NewStateResolution(r.DB) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { return false, err @@ -689,12 +676,7 @@ func (r *RoomserverQueryAPI) QueryStateAndAuthChain( } func (r *RoomserverQueryAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { - // TODO: get the correct room version - roomState, err := state.GetStateResolutionAlgorithm(state.StateResolutionAlgorithmV1, r.DB) - if err != nil { - return nil, err - } - + roomState := state.NewStateResolution(r.DB) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { diff --git a/roomserver/state/database/database.go b/roomserver/state/database/database.go index ede6c5ec3..80f1b14f4 100644 --- a/roomserver/state/database/database.go +++ b/roomserver/state/database/database.go @@ -20,6 +20,7 @@ import ( "context" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" ) // A RoomStateDatabase has the storage APIs needed to load state from the database @@ -61,4 +62,6 @@ type RoomStateDatabase interface { Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + // Look up a room version from the room NID. + GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) } diff --git a/roomserver/state/shared/shared.go b/roomserver/state/shared/shared.go new file mode 100644 index 000000000..a29b5e403 --- /dev/null +++ b/roomserver/state/shared/shared.go @@ -0,0 +1 @@ +package shared diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 687a120e3..b8e3e18a1 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -19,63 +19,1069 @@ package state import ( "context" "errors" + "fmt" + "sort" + "time" "github.com/matrix-org/dendrite/roomserver/state/database" - v1 "github.com/matrix-org/dendrite/roomserver/state/v1" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) -type StateResolutionVersion int +type StateResolution struct { + db database.RoomStateDatabase +} -const ( - StateResolutionAlgorithmV1 StateResolutionVersion = iota + 1 - StateResolutionAlgorithmV2 -) - -func GetStateResolutionAlgorithm( - version StateResolutionVersion, db database.RoomStateDatabase, -) (StateResolutionImpl, error) { - switch version { - case StateResolutionAlgorithmV1: - return v1.Prepare(db), nil - default: - return nil, errors.New("unsupported room version") +func NewStateResolution(db database.RoomStateDatabase) StateResolution { + return StateResolution{ + db: db, } } -type StateResolutionImpl interface { - LoadStateAtSnapshot( - ctx context.Context, stateNID types.StateSnapshotNID, - ) ([]types.StateEntry, error) - LoadStateAtEvent( - ctx context.Context, eventID string, - ) ([]types.StateEntry, error) - LoadCombinedStateAfterEvents( - ctx context.Context, prevStates []types.StateAtEvent, - ) ([]types.StateEntry, error) - DifferenceBetweeenStateSnapshots( - ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, - ) (removed, added []types.StateEntry, err error) - LoadStateAtSnapshotForStringTuples( - ctx context.Context, - stateNID types.StateSnapshotNID, - stateKeyTuples []gomatrixserverlib.StateKeyTuple, - ) ([]types.StateEntry, error) - LoadStateAfterEventsForStringTuples( - ctx context.Context, - prevStates []types.StateAtEvent, - stateKeyTuples []gomatrixserverlib.StateKeyTuple, - ) ([]types.StateEntry, error) - CalculateAndStoreStateBeforeEvent( - ctx context.Context, - event gomatrixserverlib.Event, - roomNID types.RoomNID, - ) (types.StateSnapshotNID, error) - CalculateAndStoreStateAfterEvents( - ctx context.Context, - roomNID types.RoomNID, - prevStates []types.StateAtEvent, - ) (types.StateSnapshotNID, error) +// LoadStateAtSnapshot loads the full state of a room at a particular snapshot. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func (v StateResolution) LoadStateAtSnapshot( + ctx context.Context, stateNID types.StateSnapshotNID, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := v.db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +// LoadStateAtEvent loads the full state of a room before a particular event. +func (v StateResolution) LoadStateAtEvent( + ctx context.Context, eventID string, +) ([]types.StateEntry, error) { + snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) + if err != nil { + return nil, err + } + + stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID) + if err != nil { + return nil, err + } + + return stateEntries, nil +} + +// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events +// and combines those snapshots together into a single list. At this point it is +// possible to run into duplicate (type, state key) tuples. +func (v StateResolution) LoadCombinedStateAfterEvents( + ctx context.Context, prevStates []types.StateAtEvent, +) ([]types.StateEntry, error) { + stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) + for i, state := range prevStates { + stateNIDs[i] = state.BeforeStateSnapshotNID + } + // Fetch the state snapshots for the state before the each prev event from the database. + // Deduplicate the IDs before passing them to the database. + // There could be duplicates because the events could be state events where + // the snapshot of the room state before them was the same. + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs)) + if err != nil { + return nil, err + } + + var stateBlockNIDs []types.StateBlockNID + for _, list := range stateBlockNIDLists { + stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) + } + // Fetch the state entries that will be combined to create the snapshots. + // Deduplicate the IDs before passing them to the database. + // There could be duplicates because a block of state entries could be reused by + // multiple snapshots. + stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) + if err != nil { + return nil, err + } + stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine the entries from all the snapshots of state after each prev event into a single list. + var combined []types.StateEntry + for _, prevState := range prevStates { + // Grab the list of state data NIDs for this snapshot. + stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) + } + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + if prevState.IsStateEvent() { + // If the prev event was a state event then add an entry for the event itself + // so that we get the state after the event rather than the state before. + fullState = append(fullState, prevState.StateEntry) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + // Add the full state for this StateSnapshotNID. + combined = append(combined, fullState...) + } + return combined, nil +} + +// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. +func (v StateResolution) DifferenceBetweeenStateSnapshots( + ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, +) (removed, added []types.StateEntry, err error) { + if oldStateNID == newStateNID { + // If the snapshot NIDs are the same then nothing has changed + return nil, nil, nil + } + + var oldEntries []types.StateEntry + var newEntries []types.StateEntry + if oldStateNID != 0 { + oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID) + if err != nil { + return nil, nil, err + } + } + if newStateNID != 0 { + newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID) + if err != nil { + return nil, nil, err + } + } + + var oldI int + var newI int + for { + switch { + case oldI == len(oldEntries): + // We've reached the end of the old entries. + // The rest of the new list must have been newly added. + added = append(added, newEntries[newI:]...) + return + case newI == len(newEntries): + // We've reached the end of the new entries. + // The rest of the old list must be have been removed. + removed = append(removed, oldEntries[oldI:]...) + return + case oldEntries[oldI] == newEntries[newI]: + // The entry is in both lists so skip over it. + oldI++ + newI++ + case oldEntries[oldI].LessThan(newEntries[newI]): + // The lists are sorted so the old entry being less than the new entry means that it only appears in the old list. + removed = append(removed, oldEntries[oldI]) + oldI++ + default: + // Reaching the default case implies that the new entry is less than the old entry. + // Since the lists are sorted this means that it only appears in the new list. + added = append(added, newEntries[newI]) + newI++ + } + } +} + +// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot. +// This is used when we only want to load a subset of the room state at a snapshot. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func (v StateResolution) LoadStateAtSnapshotForStringTuples( + ctx context.Context, + stateNID types.StateSnapshotNID, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateEntry, error) { + numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) + if err != nil { + return nil, err + } + return v.loadStateAtSnapshotForNumericTuples(ctx, stateNID, numericTuples) +} + +// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs +// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded. +// Returns an error if there was a problem talking to the database. +func (v StateResolution) stringTuplesToNumericTuples( + ctx context.Context, + stringTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateKeyTuple, error) { + eventTypes := make([]string, len(stringTuples)) + stateKeys := make([]string, len(stringTuples)) + for i := range stringTuples { + eventTypes[i] = stringTuples[i].EventType + stateKeys[i] = stringTuples[i].StateKey + } + eventTypes = util.UniqueStrings(eventTypes) + eventTypeMap, err := v.db.EventTypeNIDs(ctx, eventTypes) + if err != nil { + return nil, err + } + stateKeys = util.UniqueStrings(stateKeys) + stateKeyMap, err := v.db.EventStateKeyNIDs(ctx, stateKeys) + if err != nil { + return nil, err + } + + var result []types.StateKeyTuple + for _, stringTuple := range stringTuples { + var numericTuple types.StateKeyTuple + var ok1, ok2 bool + numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType] + numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.StateKey] + // Discard the tuple if there wasn't a numeric ID for either the event type or the state key. + if ok1 && ok2 { + result = append(result, numericTuple) + } + } + + return result, nil +} + +// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot. +// This is used when we only want to load a subset of the room state at a snapshot. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event or the current state of a room. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func (v StateResolution) loadStateAtSnapshotForNumericTuples( + ctx context.Context, + stateNID types.StateSnapshotNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := v.db.StateEntriesForTuples( + ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples, + ) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // If the block is missing from the map it means that none of its entries matched a requested tuple. + // This can happen if the block doesn't contain an update for one of the requested tuples. + // If none of the requested tuples are in the block then it can be safely skipped. + continue + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +// LoadStateAfterEventsForStringTuples loads the state for a list of event type +// and state key pairs after list of events. +// This is used when we only want to load a subset of the room state after a list of events. +// If there is no entry for a given event type and state key pair then it will be discarded. +// This is typically the state before an event. +// Returns a sorted list of state entries or an error if there was a problem talking to the database. +func (v StateResolution) LoadStateAfterEventsForStringTuples( + ctx context.Context, roomNID types.RoomNID, + prevStates []types.StateAtEvent, + stateKeyTuples []gomatrixserverlib.StateKeyTuple, +) ([]types.StateEntry, error) { + numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) + if err != nil { + return nil, err + } + return v.loadStateAfterEventsForNumericTuples(ctx, roomNID, prevStates, numericTuples) +} + +func (v StateResolution) loadStateAfterEventsForNumericTuples( + ctx context.Context, roomNID types.RoomNID, + prevStates []types.StateAtEvent, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) + if err != nil { + return nil, err + } + + if len(prevStates) == 1 { + // Fast path for a single event. + prevState := prevStates[0] + var result []types.StateEntry + result, err = v.loadStateAtSnapshotForNumericTuples( + ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples, + ) + if err != nil { + return nil, err + } + if prevState.IsStateEvent() { + // The result is current the state before the requested event. + // We want the state after the requested event. + // If the requested event was a state event then we need to + // update that key in the result. + // If the requested event wasn't a state event then the state after + // it is the same as the state before it. + set := false + for i := range result { + if result[i].StateKeyTuple == prevState.StateKeyTuple { + result[i] = prevState.StateEntry + set = true + } + } + if !set { // no previous state exists for this event: add new state + result = append(result, prevState.StateEntry) + } + } + return result, nil + } + + // Slow path for more that one event. + // Load the entire state so that we can do conflict resolution if we need to. + // TODO: The are some optimistations we could do here: + // 1) We only need to do conflict resolution if there is a conflict in the + // requested tuples so we might try loading just those tuples and then + // checking for conflicts. + // 2) When there is a conflict we still only need to load the state + // needed to do conflict resolution which would save us having to load + // the full state. + + // TODO: Add metrics for this as it could take a long time for big rooms + // with large conflicts. + fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) + if err != nil { + return nil, err + } + + // Sort the full state so we can use it as a map. + sort.Sort(stateEntrySorter(fullState)) + + // Filter the full state down to the required tuples. + var result []types.StateEntry + for _, tuple := range stateKeyTuples { + eventNID, ok := stateEntryMap(fullState).lookup(tuple) + if ok { + result = append(result, types.StateEntry{ + StateKeyTuple: tuple, + EventNID: eventNID, + }) + } + } + sort.Sort(stateEntrySorter(result)) + return result, nil +} + +var calculateStateDurations = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_duration_microseconds", + Help: "How long it takes to calculate the state after a list of events", + }, + // Takes two labels: + // algorithm: + // The algorithm used to calculate the state or the step it failed on if it failed. + // Labels starting with "_" are used to indicate when the algorithm fails halfway. + // outcome: + // Whether the state was successfully calculated. + // + // The possible values for algorithm are: + // empty_state -> The list of events was empty so the state is empty. + // no_change -> The state hasn't changed. + // single_delta -> There was a single event added to the state in a way that can be encoded as a single delta + // full_state_no_conflicts -> We created a new copy of the full room state, but didn't enounter any conflicts + // while doing so. + // full_state_with_conflicts -> We created a new copy of the full room state and had to resolve conflicts to do so. + // _load_state_block_nids -> Failed loading the state block nids for a single previous state. + // _load_combined_state -> Failed to load the combined state. + // _resolve_conflicts -> Failed to resolve conflicts. + []string{"algorithm", "outcome"}, +) + +var calculateStatePrevEventLength = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_prev_event_length", + Help: "The length of the list of events to calculate the state after", + }, + []string{"algorithm", "outcome"}, +) + +var calculateStateFullStateLength = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_full_state_length", + Help: "The length of the full room state.", + }, + []string{"algorithm", "outcome"}, +) + +var calculateStateConflictLength = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "dendrite", + Subsystem: "roomserver", + Name: "calculate_state_conflict_state_length", + Help: "The length of the conflicted room state.", + }, + []string{"algorithm", "outcome"}, +) + +type calculateStateMetrics struct { + algorithm string + startTime time.Time + prevEventLength int + fullStateLength int + conflictLength int +} + +func (c *calculateStateMetrics) stop(stateNID types.StateSnapshotNID, err error) (types.StateSnapshotNID, error) { + var outcome string + if err == nil { + outcome = "success" + } else { + outcome = "failure" + } + endTime := time.Now() + calculateStateDurations.WithLabelValues(c.algorithm, outcome).Observe( + float64(endTime.Sub(c.startTime).Nanoseconds()) / 1000., + ) + calculateStatePrevEventLength.WithLabelValues(c.algorithm, outcome).Observe( + float64(c.prevEventLength), + ) + calculateStateFullStateLength.WithLabelValues(c.algorithm, outcome).Observe( + float64(c.fullStateLength), + ) + calculateStateConflictLength.WithLabelValues(c.algorithm, outcome).Observe( + float64(c.conflictLength), + ) + return stateNID, err +} + +func init() { + prometheus.MustRegister( + calculateStateDurations, calculateStatePrevEventLength, + calculateStateFullStateLength, calculateStateConflictLength, + ) +} + +// CalculateAndStoreStateBeforeEvent calculates a snapshot of the state of a room before an event. +// Stores the snapshot of the state in the database. +// Returns a numeric ID for the snapshot of the state before the event. +func (v StateResolution) CalculateAndStoreStateBeforeEvent( + ctx context.Context, + event gomatrixserverlib.Event, + roomNID types.RoomNID, +) (types.StateSnapshotNID, error) { + // Load the state at the prev events. + prevEventRefs := event.PrevEvents() + prevEventIDs := make([]string, len(prevEventRefs)) + for i := range prevEventRefs { + prevEventIDs[i] = prevEventRefs[i].EventID + } + + prevStates, err := v.db.StateAtEventIDs(ctx, prevEventIDs) + if err != nil { + return 0, err + } + + // The state before this event will be the state after the events that came before it. + return v.CalculateAndStoreStateAfterEvents(ctx, roomNID, prevStates) +} + +// CalculateAndStoreStateAfterEvents finds the room state after the given events. +// Stores the resulting state in the database and returns a numeric ID for that snapshot. +func (v StateResolution) CalculateAndStoreStateAfterEvents( + ctx context.Context, + roomNID types.RoomNID, + prevStates []types.StateAtEvent, +) (types.StateSnapshotNID, error) { + metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} + + if len(prevStates) == 0 { + // 2) There weren't any prev_events for this event so the state is + // empty. + metrics.algorithm = "empty_state" + return metrics.stop(v.db.AddState(ctx, roomNID, nil, nil)) + } + + if len(prevStates) == 1 { + prevState := prevStates[0] + if prevState.EventStateKeyNID == 0 { + // 3) None of the previous events were state events and they all + // have the same state, so this event has exactly the same state + // as the previous events. + // This should be the common case. + metrics.algorithm = "no_change" + return metrics.stop(prevState.BeforeStateSnapshotNID, nil) + } + // The previous event was a state event so we need to store a copy + // of the previous state updated with that event. + stateBlockNIDLists, err := v.db.StateBlockNIDs( + ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}, + ) + if err != nil { + metrics.algorithm = "_load_state_blocks" + return metrics.stop(0, err) + } + stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs + if len(stateBlockNIDs) < maxStateBlockNIDs { + // 4) The number of state data blocks is small enough that we can just + // add the state event as a block of size one to the end of the blocks. + metrics.algorithm = "single_delta" + return metrics.stop(v.db.AddState( + ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + )) + } + // If there are too many deltas then we need to calculate the full state + // So fall through to calculateAndStoreStateAfterManyEvents + } + + return v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) +} + +// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. +// Increasing this number means that we can encode more of the state changes as simple deltas which means that +// we need fewer entries in the state data table. However making this number bigger will increase the size of +// the rows in the state table itself and will require more index lookups when retrieving a snapshot. +// TODO: Tune this to get the right balance between size and lookup performance. +const maxStateBlockNIDs = 64 + +// calculateAndStoreStateAfterManyEvents finds the room state after the given events. +// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event. +// Stores the resulting state and returns a numeric ID for the snapshot. +func (v StateResolution) calculateAndStoreStateAfterManyEvents( + ctx context.Context, + roomNID types.RoomNID, + prevStates []types.StateAtEvent, + metrics calculateStateMetrics, +) (types.StateSnapshotNID, error) { + roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) + if err != nil { + return metrics.stop(0, err) + } + + state, algorithm, conflictLength, err := + v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) + metrics.algorithm = algorithm + if err != nil { + return metrics.stop(0, err) + } + + // TODO: Check if we can encode the new state as a delta against the + // previous state. + metrics.conflictLength = conflictLength + metrics.fullStateLength = len(state) + return metrics.stop(v.db.AddState(ctx, roomNID, nil, state)) +} + +func (v StateResolution) calculateStateAfterManyEvents( + ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, + prevStates []types.StateAtEvent, +) (state []types.StateEntry, algorithm string, conflictLength int, err error) { + var combined []types.StateEntry + // Conflict resolution. + // First stage: load the state after each of the prev events. + combined, err = v.LoadCombinedStateAfterEvents(ctx, prevStates) + if err != nil { + algorithm = "_load_combined_state" + return + } + + // Collect all the entries with the same type and key together. + // We don't care about the order here because the conflict resolution + // algorithm doesn't depend on the order of the prev events. + // Remove duplicate entires. + combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] + + // Find the conflicts + conflicts := findDuplicateStateKeys(combined) + + if len(conflicts) > 0 { + conflictLength = len(conflicts) + + // 5) There are conflicting state events, for each conflict workout + // what the appropriate state event is. + + // Work out which entries aren't conflicted. + var notConflicted []types.StateEntry + for _, entry := range combined { + if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { + notConflicted = append(notConflicted, entry) + } + } + + var resolved []types.StateEntry + resolved, err = v.resolveConflicts(ctx, roomVersion, notConflicted, conflicts) + if err != nil { + algorithm = "_resolve_conflicts" + return + } + algorithm = "full_state_with_conflicts" + state = resolved + } else { + algorithm = "full_state_no_conflicts" + // 6) There weren't any conflicts + state = combined + } + return +} + +func (v StateResolution) resolveConflicts( + ctx context.Context, version gomatrixserverlib.RoomVersion, + notConflicted, conflicted []types.StateEntry, +) ([]types.StateEntry, error) { + stateResAlgo, err := version.StateResAlgorithm() + if err != nil { + return nil, err + } + switch stateResAlgo { + case gomatrixserverlib.StateResV1: + return v.resolveConflictsV1(ctx, notConflicted, conflicted) + case gomatrixserverlib.StateResV2: + return v.resolveConflictsV2(ctx, notConflicted, conflicted) + } + return nil, errors.New("unsupported state resolution algorithm") +} + +// resolveConflicts resolves a list of conflicted state entries. It takes two lists. +// The first is a list of all state entries that are not conflicted. +// The second is a list of all state entries that are conflicted +// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple. +// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. +// The returned list is sorted by state key tuple. +// Returns an error if there was a problem talking to the database. +func (v StateResolution) resolveConflictsV1( + ctx context.Context, + notConflicted, conflicted []types.StateEntry, +) ([]types.StateEntry, error) { + + // Load the conflicted events + conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) + if err != nil { + return nil, err + } + + // Work out which auth events we need to load. + needed := gomatrixserverlib.StateNeededForAuth(conflictedEvents) + + // Find the numeric IDs for the necessary state keys. + var neededStateKeys []string + neededStateKeys = append(neededStateKeys, needed.Member...) + neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) + stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys) + if err != nil { + return nil, err + } + + // Load the necessary auth events. + tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) + var authEntries []types.StateEntry + for _, tuple := range tuplesNeeded { + if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { + authEntries = append(authEntries, types.StateEntry{ + StateKeyTuple: tuple, + EventNID: eventNID, + }) + } + } + authEvents, _, err := v.loadStateEvents(ctx, authEntries) + if err != nil { + return nil, err + } + + // Resolve the conflicts. + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) + + // Map from the full events back to numeric state entries. + for _, resolvedEvent := range resolvedEvents { + entry, ok := eventIDMap[resolvedEvent.EventID()] + if !ok { + panic(fmt.Errorf("Missing state entry for event ID %q", resolvedEvent.EventID())) + } + notConflicted = append(notConflicted, entry) + } + + // Sort the result so it can be searched. + sort.Sort(stateEntrySorter(notConflicted)) + return notConflicted, nil +} + +// resolveConflicts resolves a list of conflicted state entries. It takes two lists. +// The first is a list of all state entries that are not conflicted. +// The second is a list of all state entries that are conflicted +// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple. +// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. +// The returned list is sorted by state key tuple. +// Returns an error if there was a problem talking to the database. +// nolint:gocyclo +func (v StateResolution) resolveConflictsV2( + ctx context.Context, + notConflicted, conflicted []types.StateEntry, +) ([]types.StateEntry, error) { + eventIDMap := make(map[string]types.StateEntry) + + // Load the conflicted events + conflictedEvents, conflictedEventMap, err := v.loadStateEvents(ctx, conflicted) + if err != nil { + return nil, err + } + for k, v := range conflictedEventMap { + eventIDMap[k] = v + } + + // Load the non-conflicted events + nonConflictedEvents, nonConflictedEventMap, err := v.loadStateEvents(ctx, notConflicted) + if err != nil { + return nil, err + } + for k, v := range nonConflictedEventMap { + eventIDMap[k] = v + } + + // For each conflicted event, we will add a new set of auth events. Auth + // events may be duplicated across these sets but that's OK. + authSets := make(map[string][]gomatrixserverlib.Event) + var authEvents []gomatrixserverlib.Event + var authDifference []gomatrixserverlib.Event + + // For each conflicted event, let's try and get the needed auth events. + for _, conflictedEvent := range conflictedEvents { + // Work out which auth events we need to load. + key := conflictedEvent.EventID() + needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{conflictedEvent}) + + // Find the numeric IDs for the necessary state keys. + var neededStateKeys []string + neededStateKeys = append(neededStateKeys, needed.Member...) + neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) + stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys) + if err != nil { + return nil, err + } + + // Load the necessary auth events. + tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) + var authEntries []types.StateEntry + for _, tuple := range tuplesNeeded { + if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { + authEntries = append(authEntries, types.StateEntry{ + StateKeyTuple: tuple, + EventNID: eventNID, + }) + } + } + + // Store the newly found auth events in the auth set for this event. + authSets[key], _, err = v.loadStateEvents(ctx, authEntries) + if err != nil { + return nil, err + } + authEvents = append(authEvents, authSets[key]...) + } + + // This function helps us to work out whether an event exists in one of the + // auth sets. + isInAuthList := func(k string, event gomatrixserverlib.Event) bool { + for _, e := range authSets[k] { + if e.EventID() == event.EventID() { + return true + } + } + return false + } + + // This function works out if an event exists in all of the auth sets. + isInAllAuthLists := func(event gomatrixserverlib.Event) bool { + found := true + for k := range authSets { + found = found && isInAuthList(k, event) + } + return found + } + + // Look through all of the auth events that we've been given and work out if + // there are any events which don't appear in all of the auth sets. If they + // don't then we add them to the auth difference. + for _, event := range authEvents { + if !isInAllAuthLists(event) { + authDifference = append(authDifference, event) + } + } + + // Resolve the conflicts. + resolvedEvents := gomatrixserverlib.ResolveStateConflictsV2( + conflictedEvents, + nonConflictedEvents, + authEvents, + authDifference, + ) + + // Map from the full events back to numeric state entries. + for _, resolvedEvent := range resolvedEvents { + entry, ok := eventIDMap[resolvedEvent.EventID()] + if !ok { + panic(fmt.Errorf("Missing state entry for event ID %q", resolvedEvent.EventID())) + } + notConflicted = append(notConflicted, entry) + } + + // Sort the result so it can be searched. + sort.Sort(stateEntrySorter(notConflicted)) + return notConflicted, nil +} + +// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. +func (v StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { + var keyTuples []types.StateKeyTuple + if stateNeeded.Create { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomCreateNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + if stateNeeded.PowerLevels { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomPowerLevelsNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + if stateNeeded.JoinRules { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomJoinRulesNID, + EventStateKeyNID: types.EmptyStateKeyNID, + }) + } + for _, member := range stateNeeded.Member { + stateKeyNID, ok := stateKeyNIDMap[member] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }) + } + } + for _, token := range stateNeeded.ThirdPartyInvite { + stateKeyNID, ok := stateKeyNIDMap[token] + if ok { + keyTuples = append(keyTuples, types.StateKeyTuple{ + EventTypeNID: types.MRoomThirdPartyInviteNID, + EventStateKeyNID: stateKeyNID, + }) + } + } + return keyTuples +} + +// loadStateEvents loads the matrix events for a list of state entries. +// Returns a list of state events in no particular order and a map from string event ID back to state entry. +// The map can be used to recover which numeric state entry a given event is for. +// Returns an error if there was a problem talking to the database. +func (v StateResolution) loadStateEvents( + ctx context.Context, entries []types.StateEntry, +) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { + eventNIDs := make([]types.EventNID, len(entries)) + for i := range entries { + eventNIDs[i] = entries[i].EventNID + } + events, err := v.db.Events(ctx, eventNIDs) + if err != nil { + return nil, nil, err + } + eventIDMap := map[string]types.StateEntry{} + result := make([]gomatrixserverlib.Event, len(entries)) + for i := range entries { + event, ok := eventMap(events).lookup(entries[i].EventNID) + if !ok { + panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID)) + } + result[i] = event.Event + eventIDMap[event.Event.EventID()] = entries[i] + } + return result, eventIDMap, nil +} + +// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. +// Returns a sorted list of those state entries. +func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { + var result []types.StateEntry + // j is the starting index of a block of entries with the same state key tuple. + j := 0 + for i := 1; i < len(a); i++ { + // Check if the state key tuple matches the start of the block + if a[j].StateKeyTuple != a[i].StateKeyTuple { + // If the state key tuple is different then we've reached the end of a block of duplicates. + // Check if the size of the block is bigger than one. + // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result + if j+1 != i { + // Add the block to the result. + result = append(result, a[j:i]...) + } + // Start a new block for the next state key tuple. + j = i + } + } + // Check if the last block with the same state key tuple had more than one event in it. + if j+1 != len(a) { + result = append(result, a[j:]...) + } + return result +} + +type stateEntrySorter []types.StateEntry + +func (s stateEntrySorter) Len() int { return len(s) } +func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stateBlockNIDListMap []types.StateBlockNIDList + +func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { + list := []types.StateBlockNIDList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateSnapshotNID >= stateNID + }) + if i < len(list) && list[i].StateSnapshotNID == stateNID { + ok = true + stateBlockNIDs = list[i].StateBlockNIDs + } + return +} + +type stateEntryListMap []types.StateEntryList + +func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { + list := []types.StateEntryList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateBlockNID >= stateBlockNID + }) + if i < len(list) && list[i].StateBlockNID == stateBlockNID { + ok = true + stateEntries = list[i].StateEntries + } + return +} + +type stateEntryByStateKeySorter []types.StateEntry + +func (s stateEntryByStateKeySorter) Len() int { return len(s) } +func (s stateEntryByStateKeySorter) Less(i, j int) bool { + return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) +} +func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stateNIDSorter []types.StateSnapshotNID + +func (s stateNIDSorter) Len() int { return len(s) } +func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { + return nids[:util.SortAndUnique(stateNIDSorter(nids))] +} + +type stateBlockNIDSorter []types.StateBlockNID + +func (s stateBlockNIDSorter) Len() int { return len(s) } +func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } +func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { + return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))] +} + +// Map from event type, state key tuple to numeric event ID. +// Implemented using binary search on a sorted array. +type stateEntryMap []types.StateEntry + +// lookup an entry in the event map. +func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size and are controlled by us. + list := []types.StateEntry(m) + i := sort.Search(len(list), func(i int) bool { + return !list[i].StateKeyTuple.LessThan(stateKey) + }) + if i < len(list) && list[i].StateKeyTuple == stateKey { + ok = true + eventNID = list[i].EventNID + } + return +} + +// Map from numeric event ID to event. +// Implemented using binary search on a sorted array. +type eventMap []types.Event + +// lookup an entry in the event map. +func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { + // Since the list is sorted we can implement this using binary search. + // This is faster than using a hash map. + // We don't have to worry about pathological cases because the keys are fixed + // size are controlled by us. + list := []types.Event(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].EventNID >= eventNID + }) + if i < len(list) && list[i].EventNID == eventNID { + ok = true + event = &list[i] + } + return } diff --git a/roomserver/state/v1/state_test.go b/roomserver/state/state_test.go similarity index 99% rename from roomserver/state/v1/state_test.go rename to roomserver/state/state_test.go index 4dc7e52ec..c57056786 100644 --- a/roomserver/state/v1/state_test.go +++ b/roomserver/state/state_test.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package v1 +package state import ( "testing" diff --git a/roomserver/state/v1/state.go b/roomserver/state/v1/state.go deleted file mode 100644 index 3eb601925..000000000 --- a/roomserver/state/v1/state.go +++ /dev/null @@ -1,932 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package state provides functions for reading state from the database. -// The functions for writing state to the database are the input package. -package v1 - -import ( - "context" - "fmt" - "sort" - "time" - - "github.com/matrix-org/dendrite/roomserver/state/database" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" -) - -type StateResolutionV1 struct { - db database.RoomStateDatabase -} - -func Prepare(db database.RoomStateDatabase) StateResolutionV1 { - return StateResolutionV1{ - db: db, - } -} - -// LoadStateAtSnapshot loads the full state of a room at a particular snapshot. -// This is typically the state before an event or the current state of a room. -// Returns a sorted list of state entries or an error if there was a problem talking to the database. -func (v StateResolutionV1) LoadStateAtSnapshot( - ctx context.Context, stateNID types.StateSnapshotNID, -) ([]types.StateEntry, error) { - stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) - if err != nil { - return nil, err - } - // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. - stateBlockNIDList := stateBlockNIDLists[0] - - stateEntryLists, err := v.db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) - if err != nil { - return nil, err - } - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) - } - fullState = append(fullState, entries...) - } - - // Stable sort so that the most recent entry for each state key stays - // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) - // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] - return fullState, nil -} - -// LoadStateAtEvent loads the full state of a room at a particular event. -func (v StateResolutionV1) LoadStateAtEvent( - ctx context.Context, eventID string, -) ([]types.StateEntry, error) { - snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) - if err != nil { - return nil, err - } - - stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID) - if err != nil { - return nil, err - } - - return stateEntries, nil -} - -// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events -// and combines those snapshots together into a single list. -func (v StateResolutionV1) LoadCombinedStateAfterEvents( - ctx context.Context, prevStates []types.StateAtEvent, -) ([]types.StateEntry, error) { - stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) - for i, state := range prevStates { - stateNIDs[i] = state.BeforeStateSnapshotNID - } - // Fetch the state snapshots for the state before the each prev event from the database. - // Deduplicate the IDs before passing them to the database. - // There could be duplicates because the events could be state events where - // the snapshot of the room state before them was the same. - stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs)) - if err != nil { - return nil, err - } - - var stateBlockNIDs []types.StateBlockNID - for _, list := range stateBlockNIDLists { - stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) - } - // Fetch the state entries that will be combined to create the snapshots. - // Deduplicate the IDs before passing them to the database. - // There could be duplicates because a block of state entries could be reused by - // multiple snapshots. - stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) - if err != nil { - return nil, err - } - stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine the entries from all the snapshots of state after each prev event into a single list. - var combined []types.StateEntry - for _, prevState := range prevStates { - // Grab the list of state data NIDs for this snapshot. - stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) - } - - // Combine all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) - } - fullState = append(fullState, entries...) - } - if prevState.IsStateEvent() { - // If the prev event was a state event then add an entry for the event itself - // so that we get the state after the event rather than the state before. - fullState = append(fullState, prevState.StateEntry) - } - - // Stable sort so that the most recent entry for each state key stays - // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) - // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] - // Add the full state for this StateSnapshotNID. - combined = append(combined, fullState...) - } - return combined, nil -} - -// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots. -func (v StateResolutionV1) DifferenceBetweeenStateSnapshots( - ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID, -) (removed, added []types.StateEntry, err error) { - if oldStateNID == newStateNID { - // If the snapshot NIDs are the same then nothing has changed - return nil, nil, nil - } - - var oldEntries []types.StateEntry - var newEntries []types.StateEntry - if oldStateNID != 0 { - oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID) - if err != nil { - return nil, nil, err - } - } - if newStateNID != 0 { - newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID) - if err != nil { - return nil, nil, err - } - } - - var oldI int - var newI int - for { - switch { - case oldI == len(oldEntries): - // We've reached the end of the old entries. - // The rest of the new list must have been newly added. - added = append(added, newEntries[newI:]...) - return - case newI == len(newEntries): - // We've reached the end of the new entries. - // The rest of the old list must be have been removed. - removed = append(removed, oldEntries[oldI:]...) - return - case oldEntries[oldI] == newEntries[newI]: - // The entry is in both lists so skip over it. - oldI++ - newI++ - case oldEntries[oldI].LessThan(newEntries[newI]): - // The lists are sorted so the old entry being less than the new entry means that it only appears in the old list. - removed = append(removed, oldEntries[oldI]) - oldI++ - default: - // Reaching the default case implies that the new entry is less than the old entry. - // Since the lists are sorted this means that it only appears in the new list. - added = append(added, newEntries[newI]) - newI++ - } - } -} - -// LoadStateAtSnapshotForStringTuples loads the state for a list of event type and state key pairs at a snapshot. -// This is used when we only want to load a subset of the room state at a snapshot. -// If there is no entry for a given event type and state key pair then it will be discarded. -// This is typically the state before an event or the current state of a room. -// Returns a sorted list of state entries or an error if there was a problem talking to the database. -func (v StateResolutionV1) LoadStateAtSnapshotForStringTuples( - ctx context.Context, - stateNID types.StateSnapshotNID, - stateKeyTuples []gomatrixserverlib.StateKeyTuple, -) ([]types.StateEntry, error) { - numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) - if err != nil { - return nil, err - } - return v.loadStateAtSnapshotForNumericTuples(ctx, stateNID, numericTuples) -} - -// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs -// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded. -// Returns an error if there was a problem talking to the database. -func (v StateResolutionV1) stringTuplesToNumericTuples( - ctx context.Context, - stringTuples []gomatrixserverlib.StateKeyTuple, -) ([]types.StateKeyTuple, error) { - eventTypes := make([]string, len(stringTuples)) - stateKeys := make([]string, len(stringTuples)) - for i := range stringTuples { - eventTypes[i] = stringTuples[i].EventType - stateKeys[i] = stringTuples[i].StateKey - } - eventTypes = util.UniqueStrings(eventTypes) - eventTypeMap, err := v.db.EventTypeNIDs(ctx, eventTypes) - if err != nil { - return nil, err - } - stateKeys = util.UniqueStrings(stateKeys) - stateKeyMap, err := v.db.EventStateKeyNIDs(ctx, stateKeys) - if err != nil { - return nil, err - } - - var result []types.StateKeyTuple - for _, stringTuple := range stringTuples { - var numericTuple types.StateKeyTuple - var ok1, ok2 bool - numericTuple.EventTypeNID, ok1 = eventTypeMap[stringTuple.EventType] - numericTuple.EventStateKeyNID, ok2 = stateKeyMap[stringTuple.StateKey] - // Discard the tuple if there wasn't a numeric ID for either the event type or the state key. - if ok1 && ok2 { - result = append(result, numericTuple) - } - } - - return result, nil -} - -// loadStateAtSnapshotForNumericTuples loads the state for a list of event type and state key pairs at a snapshot. -// This is used when we only want to load a subset of the room state at a snapshot. -// If there is no entry for a given event type and state key pair then it will be discarded. -// This is typically the state before an event or the current state of a room. -// Returns a sorted list of state entries or an error if there was a problem talking to the database. -func (v StateResolutionV1) loadStateAtSnapshotForNumericTuples( - ctx context.Context, - stateNID types.StateSnapshotNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntry, error) { - stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) - if err != nil { - return nil, err - } - // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. - stateBlockNIDList := stateBlockNIDLists[0] - - stateEntryLists, err := v.db.StateEntriesForTuples( - ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples, - ) - if err != nil { - return nil, err - } - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // If the block is missing from the map it means that none of its entries matched a requested tuple. - // This can happen if the block doesn't contain an update for one of the requested tuples. - // If none of the requested tuples are in the block then it can be safely skipped. - continue - } - fullState = append(fullState, entries...) - } - - // Stable sort so that the most recent entry for each state key stays - // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) - // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] - return fullState, nil -} - -// LoadStateAfterEventsForStringTuples loads the state for a list of event type -// and state key pairs after list of events. -// This is used when we only want to load a subset of the room state after a list of events. -// If there is no entry for a given event type and state key pair then it will be discarded. -// This is typically the state before an event. -// Returns a sorted list of state entries or an error if there was a problem talking to the database. -func (v StateResolutionV1) LoadStateAfterEventsForStringTuples( - ctx context.Context, - prevStates []types.StateAtEvent, - stateKeyTuples []gomatrixserverlib.StateKeyTuple, -) ([]types.StateEntry, error) { - numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) - if err != nil { - return nil, err - } - return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples) -} - -func (v StateResolutionV1) loadStateAfterEventsForNumericTuples( - ctx context.Context, - prevStates []types.StateAtEvent, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntry, error) { - if len(prevStates) == 1 { - // Fast path for a single event. - prevState := prevStates[0] - result, err := v.loadStateAtSnapshotForNumericTuples( - ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples, - ) - if err != nil { - return nil, err - } - if prevState.IsStateEvent() { - // The result is current the state before the requested event. - // We want the state after the requested event. - // If the requested event was a state event then we need to - // update that key in the result. - // If the requested event wasn't a state event then the state after - // it is the same as the state before it. - set := false - for i := range result { - if result[i].StateKeyTuple == prevState.StateKeyTuple { - result[i] = prevState.StateEntry - set = true - } - } - if !set { // no previous state exists for this event: add new state - result = append(result, prevState.StateEntry) - } - } - return result, nil - } - - // Slow path for more that one event. - // Load the entire state so that we can do conflict resolution if we need to. - // TODO: The are some optimistations we could do here: - // 1) We only need to do conflict resolution if there is a conflict in the - // requested tuples so we might try loading just those tuples and then - // checking for conflicts. - // 2) When there is a conflict we still only need to load the state - // needed to do conflict resolution which would save us having to load - // the full state. - - // TODO: Add metrics for this as it could take a long time for big rooms - // with large conflicts. - fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, prevStates) - if err != nil { - return nil, err - } - - // Sort the full state so we can use it as a map. - sort.Sort(stateEntrySorter(fullState)) - - // Filter the full state down to the required tuples. - var result []types.StateEntry - for _, tuple := range stateKeyTuples { - eventNID, ok := stateEntryMap(fullState).lookup(tuple) - if ok { - result = append(result, types.StateEntry{ - StateKeyTuple: tuple, - EventNID: eventNID, - }) - } - } - sort.Sort(stateEntrySorter(result)) - return result, nil -} - -var calculateStateDurations = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "dendrite", - Subsystem: "roomserver", - Name: "calculate_state_duration_microseconds", - Help: "How long it takes to calculate the state after a list of events", - }, - // Takes two labels: - // algorithm: - // The algorithm used to calculate the state or the step it failed on if it failed. - // Labels starting with "_" are used to indicate when the algorithm fails halfway. - // outcome: - // Whether the state was successfully calculated. - // - // The possible values for algorithm are: - // empty_state -> The list of events was empty so the state is empty. - // no_change -> The state hasn't changed. - // single_delta -> There was a single event added to the state in a way that can be encoded as a single delta - // full_state_no_conflicts -> We created a new copy of the full room state, but didn't enounter any conflicts - // while doing so. - // full_state_with_conflicts -> We created a new copy of the full room state and had to resolve conflicts to do so. - // _load_state_block_nids -> Failed loading the state block nids for a single previous state. - // _load_combined_state -> Failed to load the combined state. - // _resolve_conflicts -> Failed to resolve conflicts. - []string{"algorithm", "outcome"}, -) - -var calculateStatePrevEventLength = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "dendrite", - Subsystem: "roomserver", - Name: "calculate_state_prev_event_length", - Help: "The length of the list of events to calculate the state after", - }, - []string{"algorithm", "outcome"}, -) - -var calculateStateFullStateLength = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "dendrite", - Subsystem: "roomserver", - Name: "calculate_state_full_state_length", - Help: "The length of the full room state.", - }, - []string{"algorithm", "outcome"}, -) - -var calculateStateConflictLength = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "dendrite", - Subsystem: "roomserver", - Name: "calculate_state_conflict_state_length", - Help: "The length of the conflicted room state.", - }, - []string{"algorithm", "outcome"}, -) - -type calculateStateMetrics struct { - algorithm string - startTime time.Time - prevEventLength int - fullStateLength int - conflictLength int -} - -func (c *calculateStateMetrics) stop(stateNID types.StateSnapshotNID, err error) (types.StateSnapshotNID, error) { - var outcome string - if err == nil { - outcome = "success" - } else { - outcome = "failure" - } - endTime := time.Now() - calculateStateDurations.WithLabelValues(c.algorithm, outcome).Observe( - float64(endTime.Sub(c.startTime).Nanoseconds()) / 1000., - ) - calculateStatePrevEventLength.WithLabelValues(c.algorithm, outcome).Observe( - float64(c.prevEventLength), - ) - calculateStateFullStateLength.WithLabelValues(c.algorithm, outcome).Observe( - float64(c.fullStateLength), - ) - calculateStateConflictLength.WithLabelValues(c.algorithm, outcome).Observe( - float64(c.conflictLength), - ) - return stateNID, err -} - -func init() { - prometheus.MustRegister( - calculateStateDurations, calculateStatePrevEventLength, - calculateStateFullStateLength, calculateStateConflictLength, - ) -} - -// CalculateAndStoreStateBeforeEvent calculates a snapshot of the state of a room before an event. -// Stores the snapshot of the state in the database. -// Returns a numeric ID for the snapshot of the state before the event. -func (v StateResolutionV1) CalculateAndStoreStateBeforeEvent( - ctx context.Context, - event gomatrixserverlib.Event, - roomNID types.RoomNID, -) (types.StateSnapshotNID, error) { - // Load the state at the prev events. - prevEventRefs := event.PrevEvents() - prevEventIDs := make([]string, len(prevEventRefs)) - for i := range prevEventRefs { - prevEventIDs[i] = prevEventRefs[i].EventID - } - - prevStates, err := v.db.StateAtEventIDs(ctx, prevEventIDs) - if err != nil { - return 0, err - } - - // The state before this event will be the state after the events that came before it. - return v.CalculateAndStoreStateAfterEvents(ctx, roomNID, prevStates) -} - -// CalculateAndStoreStateAfterEvents finds the room state after the given events. -// Stores the resulting state in the database and returns a numeric ID for that snapshot. -func (v StateResolutionV1) CalculateAndStoreStateAfterEvents( - ctx context.Context, - roomNID types.RoomNID, - prevStates []types.StateAtEvent, -) (types.StateSnapshotNID, error) { - metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} - - if len(prevStates) == 0 { - // 2) There weren't any prev_events for this event so the state is - // empty. - metrics.algorithm = "empty_state" - return metrics.stop(v.db.AddState(ctx, roomNID, nil, nil)) - } - - if len(prevStates) == 1 { - prevState := prevStates[0] - if prevState.EventStateKeyNID == 0 { - // 3) None of the previous events were state events and they all - // have the same state, so this event has exactly the same state - // as the previous events. - // This should be the common case. - metrics.algorithm = "no_change" - return metrics.stop(prevState.BeforeStateSnapshotNID, nil) - } - // The previous event was a state event so we need to store a copy - // of the previous state updated with that event. - stateBlockNIDLists, err := v.db.StateBlockNIDs( - ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}, - ) - if err != nil { - metrics.algorithm = "_load_state_blocks" - return metrics.stop(0, err) - } - stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs - if len(stateBlockNIDs) < maxStateBlockNIDs { - // 4) The number of state data blocks is small enough that we can just - // add the state event as a block of size one to the end of the blocks. - metrics.algorithm = "single_delta" - return metrics.stop(v.db.AddState( - ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, - )) - } - // If there are too many deltas then we need to calculate the full state - // So fall through to calculateAndStoreStateAfterManyEvents - } - - return v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) -} - -// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. -// Increasing this number means that we can encode more of the state changes as simple deltas which means that -// we need fewer entries in the state data table. However making this number bigger will increase the size of -// the rows in the state table itself and will require more index lookups when retrieving a snapshot. -// TODO: Tune this to get the right balance between size and lookup performance. -const maxStateBlockNIDs = 64 - -// calculateAndStoreStateAfterManyEvents finds the room state after the given events. -// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event. -// Stores the resulting state and returns a numeric ID for the snapshot. -func (v StateResolutionV1) calculateAndStoreStateAfterManyEvents( - ctx context.Context, - roomNID types.RoomNID, - prevStates []types.StateAtEvent, - metrics calculateStateMetrics, -) (types.StateSnapshotNID, error) { - - state, algorithm, conflictLength, err := - v.calculateStateAfterManyEvents(ctx, prevStates) - metrics.algorithm = algorithm - if err != nil { - return metrics.stop(0, err) - } - - // TODO: Check if we can encode the new state as a delta against the - // previous state. - metrics.conflictLength = conflictLength - metrics.fullStateLength = len(state) - return metrics.stop(v.db.AddState(ctx, roomNID, nil, state)) -} - -func (v StateResolutionV1) calculateStateAfterManyEvents( - ctx context.Context, prevStates []types.StateAtEvent, -) (state []types.StateEntry, algorithm string, conflictLength int, err error) { - var combined []types.StateEntry - // Conflict resolution. - // First stage: load the state after each of the prev events. - combined, err = v.LoadCombinedStateAfterEvents(ctx, prevStates) - if err != nil { - algorithm = "_load_combined_state" - return - } - - // Collect all the entries with the same type and key together. - // We don't care about the order here because the conflict resolution - // algorithm doesn't depend on the order of the prev events. - // Remove duplicate entires. - combined = combined[:util.SortAndUnique(stateEntrySorter(combined))] - - // Find the conflicts - conflicts := findDuplicateStateKeys(combined) - - if len(conflicts) > 0 { - conflictLength = len(conflicts) - - // 5) There are conflicting state events, for each conflict workout - // what the appropriate state event is. - - // Work out which entries aren't conflicted. - var notConflicted []types.StateEntry - for _, entry := range combined { - if _, ok := stateEntryMap(conflicts).lookup(entry.StateKeyTuple); !ok { - notConflicted = append(notConflicted, entry) - } - } - - var resolved []types.StateEntry - resolved, err = v.resolveConflicts(ctx, notConflicted, conflicts) - if err != nil { - algorithm = "_resolve_conflicts" - return - } - algorithm = "full_state_with_conflicts" - state = resolved - } else { - algorithm = "full_state_no_conflicts" - // 6) There weren't any conflicts - state = combined - } - return -} - -// resolveConflicts resolves a list of conflicted state entries. It takes two lists. -// The first is a list of all state entries that are not conflicted. -// The second is a list of all state entries that are conflicted -// A state entry is conflicted when there is more than one numeric event ID for the same state key tuple. -// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts. -// The returned list is sorted by state key tuple. -// Returns an error if there was a problem talking to the database. -func (v StateResolutionV1) resolveConflicts( - ctx context.Context, - notConflicted, conflicted []types.StateEntry, -) ([]types.StateEntry, error) { - - // Load the conflicted events - conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) - if err != nil { - return nil, err - } - - // Work out which auth events we need to load. - needed := gomatrixserverlib.StateNeededForAuth(conflictedEvents) - - // Find the numeric IDs for the necessary state keys. - var neededStateKeys []string - neededStateKeys = append(neededStateKeys, needed.Member...) - neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...) - stateKeyNIDMap, err := v.db.EventStateKeyNIDs(ctx, neededStateKeys) - if err != nil { - return nil, err - } - - // Load the necessary auth events. - tuplesNeeded := v.stateKeyTuplesNeeded(stateKeyNIDMap, needed) - var authEntries []types.StateEntry - for _, tuple := range tuplesNeeded { - if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok { - authEntries = append(authEntries, types.StateEntry{ - StateKeyTuple: tuple, - EventNID: eventNID, - }) - } - } - authEvents, _, err := v.loadStateEvents(ctx, authEntries) - if err != nil { - return nil, err - } - - // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) - - // Map from the full events back to numeric state entries. - for _, resolvedEvent := range resolvedEvents { - entry, ok := eventIDMap[resolvedEvent.EventID()] - if !ok { - panic(fmt.Errorf("Missing state entry for event ID %q", resolvedEvent.EventID())) - } - notConflicted = append(notConflicted, entry) - } - - // Sort the result so it can be searched. - sort.Sort(stateEntrySorter(notConflicted)) - return notConflicted, nil -} - -// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events. -func (v StateResolutionV1) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple { - var keyTuples []types.StateKeyTuple - if stateNeeded.Create { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomCreateNID, - EventStateKeyNID: types.EmptyStateKeyNID, - }) - } - if stateNeeded.PowerLevels { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomPowerLevelsNID, - EventStateKeyNID: types.EmptyStateKeyNID, - }) - } - if stateNeeded.JoinRules { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomJoinRulesNID, - EventStateKeyNID: types.EmptyStateKeyNID, - }) - } - for _, member := range stateNeeded.Member { - stateKeyNID, ok := stateKeyNIDMap[member] - if ok { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomMemberNID, - EventStateKeyNID: stateKeyNID, - }) - } - } - for _, token := range stateNeeded.ThirdPartyInvite { - stateKeyNID, ok := stateKeyNIDMap[token] - if ok { - keyTuples = append(keyTuples, types.StateKeyTuple{ - EventTypeNID: types.MRoomThirdPartyInviteNID, - EventStateKeyNID: stateKeyNID, - }) - } - } - return keyTuples -} - -// loadStateEvents loads the matrix events for a list of state entries. -// Returns a list of state events in no particular order and a map from string event ID back to state entry. -// The map can be used to recover which numeric state entry a given event is for. -// Returns an error if there was a problem talking to the database. -func (v StateResolutionV1) loadStateEvents( - ctx context.Context, entries []types.StateEntry, -) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) { - eventNIDs := make([]types.EventNID, len(entries)) - for i := range entries { - eventNIDs[i] = entries[i].EventNID - } - events, err := v.db.Events(ctx, eventNIDs) - if err != nil { - return nil, nil, err - } - eventIDMap := map[string]types.StateEntry{} - result := make([]gomatrixserverlib.Event, len(entries)) - for i := range entries { - event, ok := eventMap(events).lookup(entries[i].EventNID) - if !ok { - panic(fmt.Errorf("Corrupt DB: Missing event numeric ID %d", entries[i].EventNID)) - } - result[i] = event.Event - eventIDMap[event.Event.EventID()] = entries[i] - } - return result, eventIDMap, nil -} - -// findDuplicateStateKeys finds the state entries where the state key tuple appears more than once in a sorted list. -// Returns a sorted list of those state entries. -func findDuplicateStateKeys(a []types.StateEntry) []types.StateEntry { - var result []types.StateEntry - // j is the starting index of a block of entries with the same state key tuple. - j := 0 - for i := 1; i < len(a); i++ { - // Check if the state key tuple matches the start of the block - if a[j].StateKeyTuple != a[i].StateKeyTuple { - // If the state key tuple is different then we've reached the end of a block of duplicates. - // Check if the size of the block is bigger than one. - // If the size is one then there was only a single entry with that state key tuple so we don't add it to the result - if j+1 != i { - // Add the block to the result. - result = append(result, a[j:i]...) - } - // Start a new block for the next state key tuple. - j = i - } - } - // Check if the last block with the same state key tuple had more than one event in it. - if j+1 != len(a) { - result = append(result, a[j:]...) - } - return result -} - -type stateEntrySorter []types.StateEntry - -func (s stateEntrySorter) Len() int { return len(s) } -func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } -func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type stateBlockNIDListMap []types.StateBlockNIDList - -func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { - list := []types.StateBlockNIDList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateSnapshotNID >= stateNID - }) - if i < len(list) && list[i].StateSnapshotNID == stateNID { - ok = true - stateBlockNIDs = list[i].StateBlockNIDs - } - return -} - -type stateEntryListMap []types.StateEntryList - -func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { - list := []types.StateEntryList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateBlockNID >= stateBlockNID - }) - if i < len(list) && list[i].StateBlockNID == stateBlockNID { - ok = true - stateEntries = list[i].StateEntries - } - return -} - -type stateEntryByStateKeySorter []types.StateEntry - -func (s stateEntryByStateKeySorter) Len() int { return len(s) } -func (s stateEntryByStateKeySorter) Less(i, j int) bool { - return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) -} -func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type stateNIDSorter []types.StateSnapshotNID - -func (s stateNIDSorter) Len() int { return len(s) } -func (s stateNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s stateNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -func uniqueStateSnapshotNIDs(nids []types.StateSnapshotNID) []types.StateSnapshotNID { - return nids[:util.SortAndUnique(stateNIDSorter(nids))] -} - -type stateBlockNIDSorter []types.StateBlockNID - -func (s stateBlockNIDSorter) Len() int { return len(s) } -func (s stateBlockNIDSorter) Less(i, j int) bool { return s[i] < s[j] } -func (s stateBlockNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -func uniqueStateBlockNIDs(nids []types.StateBlockNID) []types.StateBlockNID { - return nids[:util.SortAndUnique(stateBlockNIDSorter(nids))] -} - -// Map from event type, state key tuple to numeric event ID. -// Implemented using binary search on a sorted array. -type stateEntryMap []types.StateEntry - -// lookup an entry in the event map. -func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) { - // Since the list is sorted we can implement this using binary search. - // This is faster than using a hash map. - // We don't have to worry about pathological cases because the keys are fixed - // size and are controlled by us. - list := []types.StateEntry(m) - i := sort.Search(len(list), func(i int) bool { - return !list[i].StateKeyTuple.LessThan(stateKey) - }) - if i < len(list) && list[i].StateKeyTuple == stateKey { - ok = true - eventNID = list[i].EventNID - } - return -} - -// Map from numeric event ID to event. -// Implemented using binary search on a sorted array. -type eventMap []types.Event - -// lookup an entry in the event map. -func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { - // Since the list is sorted we can implement this using binary search. - // This is faster than using a hash map. - // We don't have to worry about pathological cases because the keys are fixed - // size are controlled by us. - list := []types.Event(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].EventNID >= eventNID - }) - if i < len(list) && list[i].EventNID == eventNID { - ok = true - event = &list[i] - } - return -} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 6bb96f1de..0fd9d5b53 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -68,6 +68,9 @@ const updateLatestEventNIDsSQL = "" + const selectRoomVersionForRoomIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_id = $1" +const selectRoomVersionForRoomNIDSQL = "" + + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -75,6 +78,7 @@ type roomStatements struct { selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomIDStmt *sql.Stmt + selectRoomVersionForRoomNIDStmt *sql.Stmt } func (s *roomStatements) prepare(db *sql.DB) (err error) { @@ -89,6 +93,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, + {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, }.prepare(db) } @@ -173,3 +178,12 @@ func (s *roomStatements) selectRoomVersionForRoomID( err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) return roomVersion, err } + +func (s *roomStatements) selectRoomVersionForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + var roomVersion gomatrixserverlib.RoomVersion + stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + return roomVersion, err +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 9bb6de9d7..084b83ce6 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -747,6 +747,14 @@ func (d *Database) GetRoomVersionForRoom( ) } +func (d *Database) GetRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + return d.statements.selectRoomVersionForRoomNID( + ctx, nil, roomNID, + ) +} + type transaction struct { ctx context.Context txn *sql.Tx diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 49fa07ea8..512b98137 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -57,6 +57,9 @@ const updateLatestEventNIDsSQL = "" + const selectRoomVersionForRoomIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_id = $1" +const selectRoomVersionForRoomNIDSQL = "" + + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -64,6 +67,7 @@ type roomStatements struct { selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomIDStmt *sql.Stmt + selectRoomVersionForRoomNIDStmt *sql.Stmt } func (s *roomStatements) prepare(db *sql.DB) (err error) { @@ -78,6 +82,7 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) { {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, + {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, }.prepare(db) } @@ -165,3 +170,12 @@ func (s *roomStatements) selectRoomVersionForRoomID( err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) return roomVersion, err } + +func (s *roomStatements) selectRoomVersionForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + var roomVersion gomatrixserverlib.RoomVersion + stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + return roomVersion, err +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index ae09a88ad..28d608ca5 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -901,6 +901,14 @@ func (d *Database) GetRoomVersionForRoom( ) } +func (d *Database) GetRoomVersionForRoomNID( + ctx context.Context, roomNID types.RoomNID, +) (gomatrixserverlib.RoomVersion, error) { + return d.statements.selectRoomVersionForRoomNID( + ctx, nil, roomNID, + ) +} + type transaction struct { ctx context.Context txn *sql.Tx