Add children and children_hash to unsigned, with tests

This commit is contained in:
Kegan Dougal 2020-11-26 14:53:24 +00:00
parent 91ee096b26
commit d222289821
2 changed files with 96 additions and 2 deletions

View file

@ -18,10 +18,13 @@ package msc2836
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
@ -277,6 +280,8 @@ func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsRespons
}
res.Events = make([]*gomatrixserverlib.Event, len(returnEvents))
for i, ev := range returnEvents {
// for each event, extract the children_count | hash and add it as unsigned data.
rc.addChildMetadata(ev)
res.Events[i] = ev.Unwrap()
}
res.Limited = remaining == 0 || walkLimited
@ -358,8 +363,7 @@ func walkThread(
return result, limited
}
// MSC2836EventRelationships performs an /event_relationships request to a remote server, injecting the resulting events
// into the roomserver as KindOutlier, with auth chains.
// MSC2836EventRelationships performs an /event_relationships request to a remote server
func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
EventID: eventID,
@ -568,6 +572,8 @@ func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent
return queryEventsRes.Events[0]
}
// injectResponseToRoomserver injects the events
// into the roomserver as KindOutlier, with auth chains.
func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) {
var stateEvents []*gomatrixserverlib.Event
var messageEvents []*gomatrixserverlib.Event
@ -604,6 +610,38 @@ func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836Event
}
}
func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) {
children, err := rc.db.ChildrenForParent(rc.ctx, ev.EventID(), constRelType, false)
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Warn("Failed to get ChildrenForParent for adding child metadata")
return
}
if len(children) == 0 {
return
}
// sort it lexiographically
sort.Slice(children, func(i, j int) bool {
return children[i].EventID < children[j].EventID
})
// hash it
var eventIDs strings.Builder
for _, c := range children {
_, _ = eventIDs.WriteString(c.EventID)
}
hashValBytes := sha256.Sum256([]byte(eventIDs.String()))
err = ev.SetUnsignedField("children_hash", gomatrixserverlib.Base64Bytes(hashValBytes[:]))
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children_hash")
}
err = ev.SetUnsignedField("children", map[string]int{
constRelType: len(children),
})
if err != nil {
util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children count")
}
}
type walkInfo struct {
eventInfo
SiblingNumber int

View file

@ -4,10 +4,14 @@ import (
"bytes"
"context"
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"sort"
"strings"
"testing"
"time"
@ -367,6 +371,21 @@ func TestMSC2836(t *testing.T) {
}))
assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()})
})
t.Run("includes children and children_hash in unsigned", func(t *testing.T) {
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": false,
"depth_first": false,
"limit": 3,
}))
// event B has C,D as children
// event C has no children
// event D has 3 children (not included in response)
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID()})
assertUnsignedChildren(t, body.Events[0], "m.reference", 2, []string{eventC.EventID(), eventD.EventID()})
assertUnsignedChildren(t, body.Events[1], "", 0, nil)
assertUnsignedChildren(t, body.Events[2], "m.reference", 3, []string{eventE.EventID(), eventF.EventID(), eventG.EventID()})
})
}
// TODO: TestMSC2836TerminatesLoops (short and long)
@ -457,6 +476,43 @@ func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wan
}
}
func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relType string, wantCount int, childrenEventIDs []string) {
t.Helper()
unsigned := struct {
Children map[string]int `json:"children"`
Hash string `json:"children_hash"`
}{}
if err := json.Unmarshal(ev.Unsigned, &unsigned); err != nil {
if wantCount == 0 {
return // no children so possible there is no unsigned field at all
}
t.Fatalf("Failed to unmarshal unsigned field: %s", err)
}
// zero checks
if wantCount == 0 {
if len(unsigned.Children) != 0 || unsigned.Hash != "" {
t.Fatalf("want 0 children but got unsigned fields %+v", unsigned)
}
return
}
gotCount := unsigned.Children[relType]
if gotCount != wantCount {
t.Errorf("Got %d count, want %d count for rel_type %s", gotCount, wantCount, relType)
}
// work out the hash
sort.Strings(childrenEventIDs)
var b strings.Builder
for _, s := range childrenEventIDs {
b.WriteString(s)
}
t.Logf("hashing %s", b.String())
hashValBytes := sha256.Sum256([]byte(b.String()))
wantHash := base64.RawStdEncoding.EncodeToString(hashValBytes[:])
if wantHash != unsigned.Hash {
t.Errorf("Got unsigned hash %s want hash %s", unsigned.Hash, wantHash)
}
}
type testUserAPI struct {
accessTokens map[string]userapi.Device
}