mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-26 00:03:09 -06:00
Fix tests; review comments
This commit is contained in:
parent
c388f6cf4b
commit
584e9fb11e
|
|
@ -19,20 +19,20 @@ package hooks
|
||||||
import "sync"
|
import "sync"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// KindNewEvent is a hook which is called with *gomatrixserverlib.HeaderedEvent
|
// KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent
|
||||||
// It is run when a new event is persisted in the roomserver.
|
// It is run when a new event is persisted in the roomserver.
|
||||||
// Usage:
|
// Usage:
|
||||||
// hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) { ... })
|
// hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... })
|
||||||
KindNewEvent = "new_event"
|
KindNewEventPersisted = "new_event_persisted"
|
||||||
// KindModifyNewEvent is a hook which is called with *gomatrixserverlib.HeaderedEvent
|
// KindNewEventReceived is a hook which is called with *gomatrixserverlib.HeaderedEvent
|
||||||
// It is run before a new event is processed by the roomserver. This hook can be used
|
// It is run before a new event is processed by the roomserver. This hook can be used
|
||||||
// to modify the event before it is persisted by adding data to `unsigned`.
|
// to modify the event before it is persisted by adding data to `unsigned`.
|
||||||
// Usage:
|
// Usage:
|
||||||
// hooks.Attach(hooks.KindModifyNewEvent, func(headeredEvent interface{}) {
|
// hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
|
||||||
// ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
|
// ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
|
||||||
// _ = ev.SetUnsignedField("key", "val")
|
// _ = ev.SetUnsignedField("key", "val")
|
||||||
// })
|
// })
|
||||||
KindModifyNewEvent = "modify_new_event"
|
KindNewEventReceived = "new_event_received"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -46,10 +46,10 @@ type EventRelationshipRequest struct {
|
||||||
MaxDepth int `json:"max_depth"`
|
MaxDepth int `json:"max_depth"`
|
||||||
MaxBreadth int `json:"max_breadth"`
|
MaxBreadth int `json:"max_breadth"`
|
||||||
Limit int `json:"limit"`
|
Limit int `json:"limit"`
|
||||||
DepthFirst *bool `json:"depth_first"`
|
DepthFirst bool `json:"depth_first"`
|
||||||
RecentFirst *bool `json:"recent_first"`
|
RecentFirst bool `json:"recent_first"`
|
||||||
IncludeParent *bool `json:"include_parent"`
|
IncludeParent bool `json:"include_parent"`
|
||||||
IncludeChildren *bool `json:"include_children"`
|
IncludeChildren bool `json:"include_children"`
|
||||||
Direction string `json:"direction"`
|
Direction string `json:"direction"`
|
||||||
Batch string `json:"batch"`
|
Batch string `json:"batch"`
|
||||||
AutoJoin bool `json:"auto_join"`
|
AutoJoin bool `json:"auto_join"`
|
||||||
|
|
@ -57,42 +57,23 @@ type EventRelationshipRequest struct {
|
||||||
|
|
||||||
func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) {
|
func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) {
|
||||||
var relation EventRelationshipRequest
|
var relation EventRelationshipRequest
|
||||||
|
relation.Defaults()
|
||||||
if err := json.NewDecoder(body).Decode(&relation); err != nil {
|
if err := json.NewDecoder(body).Decode(&relation); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// Sanity check request and set defaults.
|
|
||||||
relation.applyDefaults()
|
|
||||||
return &relation, nil
|
return &relation, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *EventRelationshipRequest) applyDefaults() {
|
func (r *EventRelationshipRequest) Defaults() {
|
||||||
if r.Limit > 100 || r.Limit < 1 {
|
|
||||||
r.Limit = 100
|
r.Limit = 100
|
||||||
}
|
|
||||||
if r.MaxBreadth == 0 {
|
|
||||||
r.MaxBreadth = 10
|
r.MaxBreadth = 10
|
||||||
}
|
|
||||||
if r.MaxDepth == 0 {
|
|
||||||
r.MaxDepth = 3
|
r.MaxDepth = 3
|
||||||
}
|
r.DepthFirst = false
|
||||||
t := true
|
r.RecentFirst = true
|
||||||
f := false
|
r.IncludeParent = false
|
||||||
if r.DepthFirst == nil {
|
r.IncludeChildren = false
|
||||||
r.DepthFirst = &f
|
|
||||||
}
|
|
||||||
if r.RecentFirst == nil {
|
|
||||||
r.RecentFirst = &t
|
|
||||||
}
|
|
||||||
if r.IncludeParent == nil {
|
|
||||||
r.IncludeParent = &f
|
|
||||||
}
|
|
||||||
if r.IncludeChildren == nil {
|
|
||||||
r.IncludeChildren = &f
|
|
||||||
}
|
|
||||||
if r.Direction != "up" {
|
|
||||||
r.Direction = "down"
|
r.Direction = "down"
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
type EventRelationshipResponse struct {
|
type EventRelationshipResponse struct {
|
||||||
Events []gomatrixserverlib.ClientEvent `json:"events"`
|
Events []gomatrixserverlib.ClientEvent `json:"events"`
|
||||||
|
|
@ -111,7 +92,7 @@ func Enable(
|
||||||
return fmt.Errorf("Cannot enable MSC2836: %w", err)
|
return fmt.Errorf("Cannot enable MSC2836: %w", err)
|
||||||
}
|
}
|
||||||
hooks.Enable()
|
hooks.Enable()
|
||||||
hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) {
|
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
|
||||||
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
|
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
|
||||||
hookErr := db.StoreRelation(context.Background(), he)
|
hookErr := db.StoreRelation(context.Background(), he)
|
||||||
if hookErr != nil {
|
if hookErr != nil {
|
||||||
|
|
@ -120,7 +101,7 @@ func Enable(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
hooks.Attach(hooks.KindModifyNewEvent, func(headeredEvent interface{}) {
|
hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
|
||||||
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
|
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
// we only inject metadata for events our server sends
|
// we only inject metadata for events our server sends
|
||||||
|
|
@ -271,16 +252,16 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
|
||||||
// Retrieve the event. Add it to response array.
|
// Retrieve the event. Add it to response array.
|
||||||
returnEvents = append(returnEvents, event)
|
returnEvents = append(returnEvents, event)
|
||||||
|
|
||||||
if *rc.req.IncludeParent {
|
if rc.req.IncludeParent {
|
||||||
if parentEvent := rc.includeParent(event); parentEvent != nil {
|
if parentEvent := rc.includeParent(event); parentEvent != nil {
|
||||||
returnEvents = append(returnEvents, parentEvent)
|
returnEvents = append(returnEvents, parentEvent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if *rc.req.IncludeChildren {
|
if rc.req.IncludeChildren {
|
||||||
remaining := rc.req.Limit - len(returnEvents)
|
remaining := rc.req.Limit - len(returnEvents)
|
||||||
if remaining > 0 {
|
if remaining > 0 {
|
||||||
children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, *rc.req.RecentFirst)
|
children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, rc.req.RecentFirst)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return nil, resErr
|
return nil, resErr
|
||||||
}
|
}
|
||||||
|
|
@ -303,7 +284,7 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
|
||||||
}
|
}
|
||||||
res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents))
|
res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents))
|
||||||
for i, ev := range returnEvents {
|
for i, ev := range returnEvents {
|
||||||
res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(*ev, gomatrixserverlib.FormatAll)
|
res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll)
|
||||||
}
|
}
|
||||||
res.Limited = remaining == 0 || walkLimited
|
res.Limited = remaining == 0 || walkLimited
|
||||||
return &res, nil
|
return &res, nil
|
||||||
|
|
@ -454,7 +435,7 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA
|
||||||
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
|
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
return &event, queryMembershipRes.IsInRoom
|
return event, queryMembershipRes.IsInRoom
|
||||||
}
|
}
|
||||||
|
|
||||||
type walkInfo struct {
|
type walkInfo struct {
|
||||||
|
|
@ -472,7 +453,7 @@ type walker struct {
|
||||||
|
|
||||||
// WalkFrom the event ID given
|
// WalkFrom the event ID given
|
||||||
func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
|
func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
|
||||||
children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, *w.req.RecentFirst)
|
children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
|
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
|
||||||
return false, err
|
return false, err
|
||||||
|
|
@ -486,7 +467,7 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
// find the children's children
|
// find the children's children
|
||||||
children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, *w.req.RecentFirst)
|
children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
|
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
|
||||||
return false, err
|
return false, err
|
||||||
|
|
@ -509,7 +490,7 @@ func (w *walker) addChildren(toWalk []walkInfo, children []eventInfo, depthOfChi
|
||||||
return toWalk
|
return toWalk
|
||||||
}
|
}
|
||||||
|
|
||||||
if *w.req.DepthFirst {
|
if w.req.DepthFirst {
|
||||||
// the slice is a stack so push them in reverse order so we pop them in the correct order
|
// the slice is a stack so push them in reverse order so we pop them in the correct order
|
||||||
// e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3
|
// e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3
|
||||||
for i := len(children) - 1; i >= 0; i-- {
|
for i := len(children) - 1; i >= 0; i-- {
|
||||||
|
|
@ -538,7 +519,7 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
var child walkInfo
|
var child walkInfo
|
||||||
if *w.req.DepthFirst {
|
if w.req.DepthFirst {
|
||||||
// toWalk is a stack so pop the child off
|
// toWalk is a stack so pop the child off
|
||||||
child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1]
|
child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1]
|
||||||
return &child, toWalk
|
return &child, toWalk
|
||||||
|
|
|
||||||
|
|
@ -26,8 +26,6 @@ var (
|
||||||
client = &http.Client{
|
client = &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}
|
}
|
||||||
constTrue = true
|
|
||||||
constFalse = false
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Basic sanity check of MSC2836 logic. Injects a thread that looks like:
|
// Basic sanity check of MSC2836 logic. Injects a thread that looks like:
|
||||||
|
|
@ -184,9 +182,9 @@ func TestMSC2836(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
t.Run("returns 403 on invalid event IDs", func(t *testing.T) {
|
t.Run("returns 403 on invalid event IDs", func(t *testing.T) {
|
||||||
_ = postRelationships(t, 403, "alice", &msc2836.EventRelationshipRequest{
|
_ = postRelationships(t, 403, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: "$invalid",
|
"event_id": "$invalid",
|
||||||
})
|
}))
|
||||||
})
|
})
|
||||||
t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) {
|
t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) {
|
||||||
nopUserAPI.accessTokens["frank"] = userapi.Device{
|
nopUserAPI.accessTokens["frank"] = userapi.Device{
|
||||||
|
|
@ -194,11 +192,11 @@ func TestMSC2836(t *testing.T) {
|
||||||
DisplayName: "Frank Not In Room",
|
DisplayName: "Frank Not In Room",
|
||||||
UserID: "@frank:localhost",
|
UserID: "@frank:localhost",
|
||||||
}
|
}
|
||||||
_ = postRelationships(t, 403, "frank", &msc2836.EventRelationshipRequest{
|
_ = postRelationships(t, 403, "frank", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
Limit: 1,
|
"limit": 1,
|
||||||
IncludeParent: &constTrue,
|
"include_parent": true,
|
||||||
})
|
}))
|
||||||
})
|
})
|
||||||
t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) {
|
t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) {
|
||||||
nopUserAPI.accessTokens["frank2"] = userapi.Device{
|
nopUserAPI.accessTokens["frank2"] = userapi.Device{
|
||||||
|
|
@ -208,44 +206,44 @@ func TestMSC2836(t *testing.T) {
|
||||||
}
|
}
|
||||||
// Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB
|
// Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB
|
||||||
nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB}
|
nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB}
|
||||||
body := postRelationships(t, 200, "frank2", &msc2836.EventRelationshipRequest{
|
body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
Limit: 1,
|
"limit": 1,
|
||||||
IncludeParent: &constTrue,
|
"include_parent": true,
|
||||||
})
|
}))
|
||||||
assertContains(t, body, []string{eventB.EventID()})
|
assertContains(t, body, []string{eventB.EventID()})
|
||||||
})
|
})
|
||||||
t.Run("returns the parent if include_parent is true", func(t *testing.T) {
|
t.Run("returns the parent if include_parent is true", func(t *testing.T) {
|
||||||
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
IncludeParent: &constTrue,
|
"include_parent": true,
|
||||||
Limit: 2,
|
"limit": 2,
|
||||||
})
|
}))
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventA.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventA.EventID()})
|
||||||
})
|
})
|
||||||
t.Run("returns the children in the right order if include_children is true", func(t *testing.T) {
|
t.Run("returns the children in the right order if include_children is true", func(t *testing.T) {
|
||||||
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventD.EventID(),
|
"event_id": eventD.EventID(),
|
||||||
IncludeChildren: &constTrue,
|
"include_children": true,
|
||||||
RecentFirst: &constTrue,
|
"recent_first": true,
|
||||||
Limit: 4,
|
"limit": 4,
|
||||||
})
|
}))
|
||||||
assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
|
assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
|
||||||
body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventD.EventID(),
|
"event_id": eventD.EventID(),
|
||||||
IncludeChildren: &constTrue,
|
"include_children": true,
|
||||||
RecentFirst: &constFalse,
|
"recent_first": false,
|
||||||
Limit: 4,
|
"limit": 4,
|
||||||
})
|
}))
|
||||||
assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
|
assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
|
||||||
})
|
})
|
||||||
t.Run("walks the graph depth first", func(t *testing.T) {
|
t.Run("walks the graph depth first", func(t *testing.T) {
|
||||||
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
RecentFirst: &constFalse,
|
"recent_first": false,
|
||||||
DepthFirst: &constTrue,
|
"depth_first": true,
|
||||||
Limit: 6,
|
"limit": 6,
|
||||||
})
|
}))
|
||||||
// Oldest first so:
|
// Oldest first so:
|
||||||
// A
|
// A
|
||||||
// |
|
// |
|
||||||
|
|
@ -257,12 +255,12 @@ func TestMSC2836(t *testing.T) {
|
||||||
// |
|
// |
|
||||||
// 5H
|
// 5H
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()})
|
||||||
body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
RecentFirst: &constTrue,
|
"recent_first": true,
|
||||||
DepthFirst: &constTrue,
|
"depth_first": true,
|
||||||
Limit: 6,
|
"limit": 6,
|
||||||
})
|
}))
|
||||||
// Recent first so:
|
// Recent first so:
|
||||||
// A
|
// A
|
||||||
// |
|
// |
|
||||||
|
|
@ -276,12 +274,12 @@ func TestMSC2836(t *testing.T) {
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()})
|
||||||
})
|
})
|
||||||
t.Run("walks the graph breadth first", func(t *testing.T) {
|
t.Run("walks the graph breadth first", func(t *testing.T) {
|
||||||
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
RecentFirst: &constFalse,
|
"recent_first": false,
|
||||||
DepthFirst: &constFalse,
|
"depth_first": false,
|
||||||
Limit: 6,
|
"limit": 6,
|
||||||
})
|
}))
|
||||||
// Oldest first so:
|
// Oldest first so:
|
||||||
// A
|
// A
|
||||||
// |
|
// |
|
||||||
|
|
@ -293,12 +291,12 @@ func TestMSC2836(t *testing.T) {
|
||||||
// |
|
// |
|
||||||
// H
|
// H
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
|
||||||
body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
RecentFirst: &constTrue,
|
"recent_first": true,
|
||||||
DepthFirst: &constFalse,
|
"depth_first": false,
|
||||||
Limit: 6,
|
"limit": 6,
|
||||||
})
|
}))
|
||||||
// Recent first so:
|
// Recent first so:
|
||||||
// A
|
// A
|
||||||
// |
|
// |
|
||||||
|
|
@ -312,43 +310,43 @@ func TestMSC2836(t *testing.T) {
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
|
||||||
})
|
})
|
||||||
t.Run("caps via max_breadth", func(t *testing.T) {
|
t.Run("caps via max_breadth", func(t *testing.T) {
|
||||||
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
RecentFirst: &constFalse,
|
"recent_first": false,
|
||||||
DepthFirst: &constFalse,
|
"depth_first": false,
|
||||||
MaxBreadth: 2,
|
"max_breadth": 2,
|
||||||
Limit: 10,
|
"limit": 10,
|
||||||
})
|
}))
|
||||||
// Event G gets omitted because of max_breadth
|
// Event G gets omitted because of max_breadth
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()})
|
||||||
})
|
})
|
||||||
t.Run("caps via max_depth", func(t *testing.T) {
|
t.Run("caps via max_depth", func(t *testing.T) {
|
||||||
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
RecentFirst: &constFalse,
|
"recent_first": false,
|
||||||
DepthFirst: &constFalse,
|
"depth_first": false,
|
||||||
MaxDepth: 2,
|
"max_depth": 2,
|
||||||
Limit: 10,
|
"limit": 10,
|
||||||
})
|
}))
|
||||||
// Event H gets omitted because of max_depth
|
// Event H gets omitted because of max_depth
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
|
||||||
})
|
})
|
||||||
t.Run("terminates when reaching the limit", func(t *testing.T) {
|
t.Run("terminates when reaching the limit", func(t *testing.T) {
|
||||||
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
RecentFirst: &constFalse,
|
"recent_first": false,
|
||||||
DepthFirst: &constFalse,
|
"depth_first": false,
|
||||||
Limit: 4,
|
"limit": 4,
|
||||||
})
|
}))
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()})
|
||||||
})
|
})
|
||||||
t.Run("returns all events with a high enough limit", func(t *testing.T) {
|
t.Run("returns all events with a high enough limit", func(t *testing.T) {
|
||||||
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
|
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
|
||||||
EventID: eventB.EventID(),
|
"event_id": eventB.EventID(),
|
||||||
RecentFirst: &constFalse,
|
"recent_first": false,
|
||||||
DepthFirst: &constFalse,
|
"depth_first": false,
|
||||||
Limit: 400,
|
"limit": 400,
|
||||||
})
|
}))
|
||||||
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
|
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -357,6 +355,19 @@ func TestMSC2836(t *testing.T) {
|
||||||
// TODO: TestMSC2836UnknownEventsSkipped
|
// TODO: TestMSC2836UnknownEventsSkipped
|
||||||
// TODO: TestMSC2836SkipEventIfNotInRoom
|
// TODO: TestMSC2836SkipEventIfNotInRoom
|
||||||
|
|
||||||
|
func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2836.EventRelationshipRequest {
|
||||||
|
t.Helper()
|
||||||
|
b, err := json.Marshal(jsonBody)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal request: %s", err)
|
||||||
|
}
|
||||||
|
r, err := msc2836.NewEventRelationshipRequest(bytes.NewBuffer(b))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to NewEventRelationshipRequest: %s", err)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
func runServer(t *testing.T, router *mux.Router) func() {
|
func runServer(t *testing.T, router *mux.Router) func() {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
externalServ := &http.Server{
|
externalServ := &http.Server{
|
||||||
|
|
@ -376,6 +387,8 @@ func runServer(t *testing.T, router *mux.Router) func() {
|
||||||
|
|
||||||
func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse {
|
func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
var r msc2836.EventRelationshipRequest
|
||||||
|
r.Defaults()
|
||||||
data, err := json.Marshal(req)
|
data, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to marshal request: %s", err)
|
t.Fatalf("failed to marshal request: %s", err)
|
||||||
|
|
@ -484,7 +497,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver
|
||||||
for _, eventID := range req.EventIDs {
|
for _, eventID := range req.EventIDs {
|
||||||
ev := r.events[eventID]
|
ev := r.events[eventID]
|
||||||
if ev != nil {
|
if ev != nil {
|
||||||
res.Events = append(res.Events, *ev)
|
res.Events = append(res.Events, ev)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -521,7 +534,7 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve
|
||||||
t.Fatalf("failed to enable MSC2836: %s", err)
|
t.Fatalf("failed to enable MSC2836: %s", err)
|
||||||
}
|
}
|
||||||
for _, ev := range events {
|
for _, ev := range events {
|
||||||
hooks.Run(hooks.KindNewEvent, ev)
|
hooks.Run(hooks.KindNewEventPersisted, ev)
|
||||||
}
|
}
|
||||||
return base.PublicClientAPIMux
|
return base.PublicClientAPIMux
|
||||||
}
|
}
|
||||||
|
|
@ -557,5 +570,5 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib
|
||||||
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
|
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
|
||||||
}
|
}
|
||||||
h := signedEvent.Headered(roomVer)
|
h := signedEvent.Headered(roomVer)
|
||||||
return &h
|
return h
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -62,10 +62,10 @@ func (w *inputWorker) start() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case task := <-w.input:
|
case task := <-w.input:
|
||||||
hooks.Run(hooks.KindModifyNewEvent, &task.event.Event)
|
hooks.Run(hooks.KindNewEventReceived, &task.event.Event)
|
||||||
_, task.err = w.r.processRoomEvent(task.ctx, task.event)
|
_, task.err = w.r.processRoomEvent(task.ctx, task.event)
|
||||||
if task.err == nil {
|
if task.err == nil {
|
||||||
hooks.Run(hooks.KindNewEvent, &task.event.Event)
|
hooks.Run(hooks.KindNewEventPersisted, &task.event.Event)
|
||||||
}
|
}
|
||||||
task.wg.Done()
|
task.wg.Done()
|
||||||
case <-time.After(time.Second * 5):
|
case <-time.After(time.Second * 5):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue