diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index 13bb81c7d..e17fb1802 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -18,10 +18,13 @@ package msc2836 import ( "bytes" "context" + "crypto/sha256" "encoding/json" "fmt" "io" "net/http" + "sort" + "strings" "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -277,6 +280,8 @@ func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsRespons } res.Events = make([]*gomatrixserverlib.Event, len(returnEvents)) for i, ev := range returnEvents { + // for each event, extract the children_count | hash and add it as unsigned data. + rc.addChildMetadata(ev) res.Events[i] = ev.Unwrap() } res.Limited = remaining == 0 || walkLimited @@ -358,8 +363,7 @@ func walkThread( return result, limited } -// MSC2836EventRelationships performs an /event_relationships request to a remote server, injecting the resulting events -// into the roomserver as KindOutlier, with auth chains. +// MSC2836EventRelationships performs an /event_relationships request to a remote server func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: eventID, @@ -568,6 +572,8 @@ func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent return queryEventsRes.Events[0] } +// injectResponseToRoomserver injects the events +// into the roomserver as KindOutlier, with auth chains. func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) { var stateEvents []*gomatrixserverlib.Event var messageEvents []*gomatrixserverlib.Event @@ -604,6 +610,38 @@ func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836Event } } +func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) { + children, err := rc.db.ChildrenForParent(rc.ctx, ev.EventID(), constRelType, false) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Warn("Failed to get ChildrenForParent for adding child metadata") + return + } + if len(children) == 0 { + return + } + // sort it lexiographically + sort.Slice(children, func(i, j int) bool { + return children[i].EventID < children[j].EventID + }) + // hash it + var eventIDs strings.Builder + for _, c := range children { + _, _ = eventIDs.WriteString(c.EventID) + } + hashValBytes := sha256.Sum256([]byte(eventIDs.String())) + + err = ev.SetUnsignedField("children_hash", gomatrixserverlib.Base64Bytes(hashValBytes[:])) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children_hash") + } + err = ev.SetUnsignedField("children", map[string]int{ + constRelType: len(children), + }) + if err != nil { + util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children count") + } +} + type walkInfo struct { eventInfo SiblingNumber int diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go index d91f368ca..b56fd234d 100644 --- a/internal/mscs/msc2836/msc2836_test.go +++ b/internal/mscs/msc2836/msc2836_test.go @@ -4,10 +4,14 @@ import ( "bytes" "context" "crypto/ed25519" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "io/ioutil" "net/http" + "sort" + "strings" "testing" "time" @@ -367,6 +371,21 @@ func TestMSC2836(t *testing.T) { })) assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()}) }) + t.Run("includes children and children_hash in unsigned", func(t *testing.T) { + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 3, + })) + // event B has C,D as children + // event C has no children + // event D has 3 children (not included in response) + assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID()}) + assertUnsignedChildren(t, body.Events[0], "m.reference", 2, []string{eventC.EventID(), eventD.EventID()}) + assertUnsignedChildren(t, body.Events[1], "", 0, nil) + assertUnsignedChildren(t, body.Events[2], "m.reference", 3, []string{eventE.EventID(), eventF.EventID(), eventG.EventID()}) + }) } // TODO: TestMSC2836TerminatesLoops (short and long) @@ -457,6 +476,43 @@ func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wan } } +func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relType string, wantCount int, childrenEventIDs []string) { + t.Helper() + unsigned := struct { + Children map[string]int `json:"children"` + Hash string `json:"children_hash"` + }{} + if err := json.Unmarshal(ev.Unsigned, &unsigned); err != nil { + if wantCount == 0 { + return // no children so possible there is no unsigned field at all + } + t.Fatalf("Failed to unmarshal unsigned field: %s", err) + } + // zero checks + if wantCount == 0 { + if len(unsigned.Children) != 0 || unsigned.Hash != "" { + t.Fatalf("want 0 children but got unsigned fields %+v", unsigned) + } + return + } + gotCount := unsigned.Children[relType] + if gotCount != wantCount { + t.Errorf("Got %d count, want %d count for rel_type %s", gotCount, wantCount, relType) + } + // work out the hash + sort.Strings(childrenEventIDs) + var b strings.Builder + for _, s := range childrenEventIDs { + b.WriteString(s) + } + t.Logf("hashing %s", b.String()) + hashValBytes := sha256.Sum256([]byte(b.String())) + wantHash := base64.RawStdEncoding.EncodeToString(hashValBytes[:]) + if wantHash != unsigned.Hash { + t.Errorf("Got unsigned hash %s want hash %s", unsigned.Hash, wantHash) + } +} + type testUserAPI struct { accessTokens map[string]userapi.Device }