diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 6c02722e9..548598dd7 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -107,7 +107,6 @@ func getState( return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} } - prevEventIDs := []string{eventID} authEventIDs := getIDsFromEventRef(event.AuthEvents()) var response api.QueryStateAndAuthChainResponse @@ -115,7 +114,7 @@ func getState( ctx, &api.QueryStateAndAuthChainRequest{ RoomID: roomID, - PrevEventIDs: prevEventIDs, + PrevEventIDs: []string{eventID}, AuthEventIDs: authEventIDs, }, &response, diff --git a/roomserver/api/query.go b/roomserver/api/query.go index e6fbf5483..5f024d266 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -204,6 +204,7 @@ type QueryStateAndAuthChainRequest struct { // The list of auth events for the event. Used to calculate the auth chain AuthEventIDs []string `json:"auth_event_ids"` // Should state resolution be ran on the result events? + // TODO: check call sites and remove if we always want to do state res ResolveState bool `json:"resolve_state"` } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 6427630c3..3f68e0747 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -18,7 +18,6 @@ package state import ( "context" - "errors" "fmt" "sort" "time" @@ -694,14 +693,25 @@ func ResolveConflictsAdhoc( ) ([]gomatrixserverlib.Event, error) { type stateKeyTuple struct { Type string - StateKey *string + StateKey string } + // Prepare our data structures. eventMap := make(map[stateKeyTuple][]gomatrixserverlib.Event) var conflicted, notConflicted, resolved []gomatrixserverlib.Event + // Run through all of the events that we were given and sort them + // into a map, sorted by (event_type, state_key) tuple. This means + // that we can easily spot events that are "conflicted", e.g. + // there are duplicate values for the same tuple key. for _, event := range events { - tuple := stateKeyTuple{event.Type(), event.StateKey()} + if event.StateKey() == nil { + // Ignore events that are not state events. + continue + } + // Append the events if there is already a conflicted list for + // this tuple key, create it if not. + tuple := stateKeyTuple{event.Type(), *event.StateKey()} if _, ok := eventMap[tuple]; ok { eventMap[tuple] = append(eventMap[tuple], event) } else { @@ -709,6 +719,10 @@ func ResolveConflictsAdhoc( } } + // Split out the events in the map into conflicted and unconflicted + // buckets. The conflicted events will be ran through state res, + // whereas unconfliced events will always going to appear in the + // final resolved state. for _, list := range eventMap { if len(list) > 1 { conflicted = append(conflicted, list...) @@ -717,21 +731,29 @@ func ResolveConflictsAdhoc( } } + // Work out which state resolution algorithm we want to run for + // the room version. stateResAlgo, err := version.StateResAlgorithm() if err != nil { return nil, err } switch stateResAlgo { case gomatrixserverlib.StateResV1: + // Currently state res v1 doesn't handle unconflicted events + // for us, like state res v2 does, so we will need to add the + // unconflicted events into the state ourselves. + // TODO: Fix state res v1 so this is handled for the caller. resolved = gomatrixserverlib.ResolveStateConflicts(conflicted, authEvents) resolved = append(resolved, notConflicted...) case gomatrixserverlib.StateResV2: // TODO: auth difference here? resolved = gomatrixserverlib.ResolveStateConflictsV2(conflicted, notConflicted, authEvents, authEvents) default: - return nil, errors.New("unsupported state resolution algorithm") + return nil, fmt.Errorf("unsupported state resolution algorithm %v", stateResAlgo) } + // Return the final resolved state events, including both the + // resolved set of conflicted events, and the unconflicted events. return resolved, nil } @@ -749,7 +771,7 @@ func (v StateResolution) resolveConflicts( case gomatrixserverlib.StateResV2: return v.resolveConflictsV2(ctx, notConflicted, conflicted) } - return nil, errors.New("unsupported state resolution algorithm") + return nil, fmt.Errorf("unsupported state resolution algorithm %v", stateResAlgo) } // resolveConflicts resolves a list of conflicted state entries. It takes two lists.