mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 00:03:09 -06:00
Add children and children_hash to unsigned, with tests
This commit is contained in:
parent
91ee096b26
commit
d222289821
|
|
@ -18,10 +18,13 @@ package msc2836
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
|
@ -277,6 +280,8 @@ func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsRespons
|
||||||
}
|
}
|
||||||
res.Events = make([]*gomatrixserverlib.Event, len(returnEvents))
|
res.Events = make([]*gomatrixserverlib.Event, len(returnEvents))
|
||||||
for i, ev := range returnEvents {
|
for i, ev := range returnEvents {
|
||||||
|
// for each event, extract the children_count | hash and add it as unsigned data.
|
||||||
|
rc.addChildMetadata(ev)
|
||||||
res.Events[i] = ev.Unwrap()
|
res.Events[i] = ev.Unwrap()
|
||||||
}
|
}
|
||||||
res.Limited = remaining == 0 || walkLimited
|
res.Limited = remaining == 0 || walkLimited
|
||||||
|
|
@ -358,8 +363,7 @@ func walkThread(
|
||||||
return result, limited
|
return result, limited
|
||||||
}
|
}
|
||||||
|
|
||||||
// MSC2836EventRelationships performs an /event_relationships request to a remote server, injecting the resulting events
|
// MSC2836EventRelationships performs an /event_relationships request to a remote server
|
||||||
// into the roomserver as KindOutlier, with auth chains.
|
|
||||||
func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
|
func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) {
|
||||||
res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
|
res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{
|
||||||
EventID: eventID,
|
EventID: eventID,
|
||||||
|
|
@ -568,6 +572,8 @@ func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent
|
||||||
return queryEventsRes.Events[0]
|
return queryEventsRes.Events[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// injectResponseToRoomserver injects the events
|
||||||
|
// into the roomserver as KindOutlier, with auth chains.
|
||||||
func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) {
|
func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) {
|
||||||
var stateEvents []*gomatrixserverlib.Event
|
var stateEvents []*gomatrixserverlib.Event
|
||||||
var messageEvents []*gomatrixserverlib.Event
|
var messageEvents []*gomatrixserverlib.Event
|
||||||
|
|
@ -604,6 +610,38 @@ func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836Event
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rc *reqCtx) addChildMetadata(ev *gomatrixserverlib.HeaderedEvent) {
|
||||||
|
children, err := rc.db.ChildrenForParent(rc.ctx, ev.EventID(), constRelType, false)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(rc.ctx).WithError(err).Warn("Failed to get ChildrenForParent for adding child metadata")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(children) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// sort it lexiographically
|
||||||
|
sort.Slice(children, func(i, j int) bool {
|
||||||
|
return children[i].EventID < children[j].EventID
|
||||||
|
})
|
||||||
|
// hash it
|
||||||
|
var eventIDs strings.Builder
|
||||||
|
for _, c := range children {
|
||||||
|
_, _ = eventIDs.WriteString(c.EventID)
|
||||||
|
}
|
||||||
|
hashValBytes := sha256.Sum256([]byte(eventIDs.String()))
|
||||||
|
|
||||||
|
err = ev.SetUnsignedField("children_hash", gomatrixserverlib.Base64Bytes(hashValBytes[:]))
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children_hash")
|
||||||
|
}
|
||||||
|
err = ev.SetUnsignedField("children", map[string]int{
|
||||||
|
constRelType: len(children),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(rc.ctx).WithError(err).Warn("Failed to set children count")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type walkInfo struct {
|
type walkInfo struct {
|
||||||
eventInfo
|
eventInfo
|
||||||
SiblingNumber int
|
SiblingNumber int
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,14 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -367,6 +371,21 @@ func TestMSC2836(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()})
|
assertContains(t, body, []string{eventF.EventID(), eventD.EventID(), eventB.EventID(), eventA.EventID()})
|
||||||
})
|
})
|
||||||
|
t.Run("includes children and children_hash in unsigned", func(t *testing.T) {
|
||||||
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
|
"event_id": eventB.EventID(),
|
||||||
|
"recent_first": false,
|
||||||
|
"depth_first": false,
|
||||||
|
"limit": 3,
|
||||||
|
}))
|
||||||
|
// event B has C,D as children
|
||||||
|
// event C has no children
|
||||||
|
// event D has 3 children (not included in response)
|
||||||
|
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID()})
|
||||||
|
assertUnsignedChildren(t, body.Events[0], "m.reference", 2, []string{eventC.EventID(), eventD.EventID()})
|
||||||
|
assertUnsignedChildren(t, body.Events[1], "", 0, nil)
|
||||||
|
assertUnsignedChildren(t, body.Events[2], "m.reference", 3, []string{eventE.EventID(), eventF.EventID(), eventG.EventID()})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: TestMSC2836TerminatesLoops (short and long)
|
// TODO: TestMSC2836TerminatesLoops (short and long)
|
||||||
|
|
@ -457,6 +476,43 @@ func assertContains(t *testing.T, result *msc2836.EventRelationshipResponse, wan
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func assertUnsignedChildren(t *testing.T, ev gomatrixserverlib.ClientEvent, relType string, wantCount int, childrenEventIDs []string) {
|
||||||
|
t.Helper()
|
||||||
|
unsigned := struct {
|
||||||
|
Children map[string]int `json:"children"`
|
||||||
|
Hash string `json:"children_hash"`
|
||||||
|
}{}
|
||||||
|
if err := json.Unmarshal(ev.Unsigned, &unsigned); err != nil {
|
||||||
|
if wantCount == 0 {
|
||||||
|
return // no children so possible there is no unsigned field at all
|
||||||
|
}
|
||||||
|
t.Fatalf("Failed to unmarshal unsigned field: %s", err)
|
||||||
|
}
|
||||||
|
// zero checks
|
||||||
|
if wantCount == 0 {
|
||||||
|
if len(unsigned.Children) != 0 || unsigned.Hash != "" {
|
||||||
|
t.Fatalf("want 0 children but got unsigned fields %+v", unsigned)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gotCount := unsigned.Children[relType]
|
||||||
|
if gotCount != wantCount {
|
||||||
|
t.Errorf("Got %d count, want %d count for rel_type %s", gotCount, wantCount, relType)
|
||||||
|
}
|
||||||
|
// work out the hash
|
||||||
|
sort.Strings(childrenEventIDs)
|
||||||
|
var b strings.Builder
|
||||||
|
for _, s := range childrenEventIDs {
|
||||||
|
b.WriteString(s)
|
||||||
|
}
|
||||||
|
t.Logf("hashing %s", b.String())
|
||||||
|
hashValBytes := sha256.Sum256([]byte(b.String()))
|
||||||
|
wantHash := base64.RawStdEncoding.EncodeToString(hashValBytes[:])
|
||||||
|
if wantHash != unsigned.Hash {
|
||||||
|
t.Errorf("Got unsigned hash %s want hash %s", unsigned.Hash, wantHash)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type testUserAPI struct {
|
type testUserAPI struct {
|
||||||
accessTokens map[string]userapi.Device
|
accessTokens map[string]userapi.Device
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue