From 584e9fb11e22ff4df2e4432766d91df00ff7cf3a Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Thu, 19 Nov 2020 11:03:36 +0000 Subject: [PATCH] Fix tests; review comments --- internal/hooks/hooks.go | 12 +- internal/mscs/msc2836/msc2836.go | 69 ++++------ internal/mscs/msc2836/msc2836_test.go | 183 ++++++++++++++------------ roomserver/internal/input/input.go | 4 +- 4 files changed, 131 insertions(+), 137 deletions(-) diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index cf7d5c569..223282a25 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -19,20 +19,20 @@ package hooks import "sync" const ( - // KindNewEvent is a hook which is called with *gomatrixserverlib.HeaderedEvent + // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent // It is run when a new event is persisted in the roomserver. // Usage: - // hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) { ... }) - KindNewEvent = "new_event" - // KindModifyNewEvent is a hook which is called with *gomatrixserverlib.HeaderedEvent + // hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... }) + KindNewEventPersisted = "new_event_persisted" + // KindNewEventReceived is a hook which is called with *gomatrixserverlib.HeaderedEvent // It is run before a new event is processed by the roomserver. This hook can be used // to modify the event before it is persisted by adding data to `unsigned`. // Usage: - // hooks.Attach(hooks.KindModifyNewEvent, func(headeredEvent interface{}) { + // hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { // ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent) // _ = ev.SetUnsignedField("key", "val") // }) - KindModifyNewEvent = "modify_new_event" + KindNewEventReceived = "new_event_received" ) var ( diff --git a/internal/mscs/msc2836/msc2836.go b/internal/mscs/msc2836/msc2836.go index b67200fba..865bc3111 100644 --- a/internal/mscs/msc2836/msc2836.go +++ b/internal/mscs/msc2836/msc2836.go @@ -46,10 +46,10 @@ type EventRelationshipRequest struct { MaxDepth int `json:"max_depth"` MaxBreadth int `json:"max_breadth"` Limit int `json:"limit"` - DepthFirst *bool `json:"depth_first"` - RecentFirst *bool `json:"recent_first"` - IncludeParent *bool `json:"include_parent"` - IncludeChildren *bool `json:"include_children"` + DepthFirst bool `json:"depth_first"` + RecentFirst bool `json:"recent_first"` + IncludeParent bool `json:"include_parent"` + IncludeChildren bool `json:"include_children"` Direction string `json:"direction"` Batch string `json:"batch"` AutoJoin bool `json:"auto_join"` @@ -57,41 +57,22 @@ type EventRelationshipRequest struct { func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) { var relation EventRelationshipRequest + relation.Defaults() if err := json.NewDecoder(body).Decode(&relation); err != nil { return nil, err } - // Sanity check request and set defaults. - relation.applyDefaults() return &relation, nil } -func (r *EventRelationshipRequest) applyDefaults() { - if r.Limit > 100 || r.Limit < 1 { - r.Limit = 100 - } - if r.MaxBreadth == 0 { - r.MaxBreadth = 10 - } - if r.MaxDepth == 0 { - r.MaxDepth = 3 - } - t := true - f := false - if r.DepthFirst == nil { - r.DepthFirst = &f - } - if r.RecentFirst == nil { - r.RecentFirst = &t - } - if r.IncludeParent == nil { - r.IncludeParent = &f - } - if r.IncludeChildren == nil { - r.IncludeChildren = &f - } - if r.Direction != "up" { - r.Direction = "down" - } +func (r *EventRelationshipRequest) Defaults() { + r.Limit = 100 + r.MaxBreadth = 10 + r.MaxDepth = 3 + r.DepthFirst = false + r.RecentFirst = true + r.IncludeParent = false + r.IncludeChildren = false + r.Direction = "down" } type EventRelationshipResponse struct { @@ -111,7 +92,7 @@ func Enable( return fmt.Errorf("Cannot enable MSC2836: %w", err) } hooks.Enable() - hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) { + hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) hookErr := db.StoreRelation(context.Background(), he) if hookErr != nil { @@ -120,7 +101,7 @@ func Enable( ) } }) - hooks.Attach(hooks.KindModifyNewEvent, func(headeredEvent interface{}) { + hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) { he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) ctx := context.Background() // we only inject metadata for events our server sends @@ -271,16 +252,16 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) { // Retrieve the event. Add it to response array. returnEvents = append(returnEvents, event) - if *rc.req.IncludeParent { + if rc.req.IncludeParent { if parentEvent := rc.includeParent(event); parentEvent != nil { returnEvents = append(returnEvents, parentEvent) } } - if *rc.req.IncludeChildren { + if rc.req.IncludeChildren { remaining := rc.req.Limit - len(returnEvents) if remaining > 0 { - children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, *rc.req.RecentFirst) + children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, rc.req.RecentFirst) if resErr != nil { return nil, resErr } @@ -303,7 +284,7 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) { } res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents)) for i, ev := range returnEvents { - res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(*ev, gomatrixserverlib.FormatAll) + res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll) } res.Limited = remaining == 0 || walkLimited return &res, nil @@ -454,7 +435,7 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser") return nil, false } - return &event, queryMembershipRes.IsInRoom + return event, queryMembershipRes.IsInRoom } type walkInfo struct { @@ -472,7 +453,7 @@ type walker struct { // WalkFrom the event ID given func (w *walker) WalkFrom(eventID string) (limited bool, err error) { - children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, *w.req.RecentFirst) + children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst) if err != nil { util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") return false, err @@ -486,7 +467,7 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) { return true, nil } // find the children's children - children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, *w.req.RecentFirst) + children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst) if err != nil { util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") return false, err @@ -509,7 +490,7 @@ func (w *walker) addChildren(toWalk []walkInfo, children []eventInfo, depthOfChi return toWalk } - if *w.req.DepthFirst { + if w.req.DepthFirst { // the slice is a stack so push them in reverse order so we pop them in the correct order // e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3 for i := len(children) - 1; i >= 0; i-- { @@ -538,7 +519,7 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) { return nil, nil } var child walkInfo - if *w.req.DepthFirst { + if w.req.DepthFirst { // toWalk is a stack so pop the child off child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1] return &child, toWalk diff --git a/internal/mscs/msc2836/msc2836_test.go b/internal/mscs/msc2836/msc2836_test.go index 80ff98ffe..cbf8b726e 100644 --- a/internal/mscs/msc2836/msc2836_test.go +++ b/internal/mscs/msc2836/msc2836_test.go @@ -26,8 +26,6 @@ var ( client = &http.Client{ Timeout: 10 * time.Second, } - constTrue = true - constFalse = false ) // Basic sanity check of MSC2836 logic. Injects a thread that looks like: @@ -184,9 +182,9 @@ func TestMSC2836(t *testing.T) { defer cancel() t.Run("returns 403 on invalid event IDs", func(t *testing.T) { - _ = postRelationships(t, 403, "alice", &msc2836.EventRelationshipRequest{ - EventID: "$invalid", - }) + _ = postRelationships(t, 403, "alice", newReq(t, map[string]interface{}{ + "event_id": "$invalid", + })) }) t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) { nopUserAPI.accessTokens["frank"] = userapi.Device{ @@ -194,11 +192,11 @@ func TestMSC2836(t *testing.T) { DisplayName: "Frank Not In Room", UserID: "@frank:localhost", } - _ = postRelationships(t, 403, "frank", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - Limit: 1, - IncludeParent: &constTrue, - }) + _ = postRelationships(t, 403, "frank", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "limit": 1, + "include_parent": true, + })) }) t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) { nopUserAPI.accessTokens["frank2"] = userapi.Device{ @@ -208,44 +206,44 @@ func TestMSC2836(t *testing.T) { } // Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB} - body := postRelationships(t, 200, "frank2", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - Limit: 1, - IncludeParent: &constTrue, - }) + body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "limit": 1, + "include_parent": true, + })) assertContains(t, body, []string{eventB.EventID()}) }) t.Run("returns the parent if include_parent is true", func(t *testing.T) { - body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - IncludeParent: &constTrue, - Limit: 2, - }) + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "include_parent": true, + "limit": 2, + })) assertContains(t, body, []string{eventB.EventID(), eventA.EventID()}) }) t.Run("returns the children in the right order if include_children is true", func(t *testing.T) { - body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventD.EventID(), - IncludeChildren: &constTrue, - RecentFirst: &constTrue, - Limit: 4, - }) + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventD.EventID(), + "include_children": true, + "recent_first": true, + "limit": 4, + })) assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) - body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventD.EventID(), - IncludeChildren: &constTrue, - RecentFirst: &constFalse, - Limit: 4, - }) + body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventD.EventID(), + "include_children": true, + "recent_first": false, + "limit": 4, + })) assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) }) t.Run("walks the graph depth first", func(t *testing.T) { - body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - RecentFirst: &constFalse, - DepthFirst: &constTrue, - Limit: 6, - }) + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": true, + "limit": 6, + })) // Oldest first so: // A // | @@ -257,12 +255,12 @@ func TestMSC2836(t *testing.T) { // | // 5H assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()}) - body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - RecentFirst: &constTrue, - DepthFirst: &constTrue, - Limit: 6, - }) + body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": true, + "depth_first": true, + "limit": 6, + })) // Recent first so: // A // | @@ -276,12 +274,12 @@ func TestMSC2836(t *testing.T) { assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()}) }) t.Run("walks the graph breadth first", func(t *testing.T) { - body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - RecentFirst: &constFalse, - DepthFirst: &constFalse, - Limit: 6, - }) + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 6, + })) // Oldest first so: // A // | @@ -293,12 +291,12 @@ func TestMSC2836(t *testing.T) { // | // H assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) - body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - RecentFirst: &constTrue, - DepthFirst: &constFalse, - Limit: 6, - }) + body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": true, + "depth_first": false, + "limit": 6, + })) // Recent first so: // A // | @@ -312,43 +310,43 @@ func TestMSC2836(t *testing.T) { assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) }) t.Run("caps via max_breadth", func(t *testing.T) { - body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - RecentFirst: &constFalse, - DepthFirst: &constFalse, - MaxBreadth: 2, - Limit: 10, - }) + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "max_breadth": 2, + "limit": 10, + })) // Event G gets omitted because of max_breadth assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()}) }) t.Run("caps via max_depth", func(t *testing.T) { - body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - RecentFirst: &constFalse, - DepthFirst: &constFalse, - MaxDepth: 2, - Limit: 10, - }) + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "max_depth": 2, + "limit": 10, + })) // Event H gets omitted because of max_depth assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) }) t.Run("terminates when reaching the limit", func(t *testing.T) { - body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - RecentFirst: &constFalse, - DepthFirst: &constFalse, - Limit: 4, - }) + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 4, + })) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()}) }) t.Run("returns all events with a high enough limit", func(t *testing.T) { - body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ - EventID: eventB.EventID(), - RecentFirst: &constFalse, - DepthFirst: &constFalse, - Limit: 400, - }) + body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{ + "event_id": eventB.EventID(), + "recent_first": false, + "depth_first": false, + "limit": 400, + })) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()}) }) } @@ -357,6 +355,19 @@ func TestMSC2836(t *testing.T) { // TODO: TestMSC2836UnknownEventsSkipped // TODO: TestMSC2836SkipEventIfNotInRoom +func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2836.EventRelationshipRequest { + t.Helper() + b, err := json.Marshal(jsonBody) + if err != nil { + t.Fatalf("Failed to marshal request: %s", err) + } + r, err := msc2836.NewEventRelationshipRequest(bytes.NewBuffer(b)) + if err != nil { + t.Fatalf("Failed to NewEventRelationshipRequest: %s", err) + } + return r +} + func runServer(t *testing.T, router *mux.Router) func() { t.Helper() externalServ := &http.Server{ @@ -376,6 +387,8 @@ func runServer(t *testing.T, router *mux.Router) func() { func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse { t.Helper() + var r msc2836.EventRelationshipRequest + r.Defaults() data, err := json.Marshal(req) if err != nil { t.Fatalf("failed to marshal request: %s", err) @@ -484,7 +497,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver for _, eventID := range req.EventIDs { ev := r.events[eventID] if ev != nil { - res.Events = append(res.Events, *ev) + res.Events = append(res.Events, ev) } } return nil @@ -521,7 +534,7 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve t.Fatalf("failed to enable MSC2836: %s", err) } for _, ev := range events { - hooks.Run(hooks.KindNewEvent, ev) + hooks.Run(hooks.KindNewEventPersisted, ev) } return base.PublicClientAPIMux } @@ -557,5 +570,5 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib t.Fatalf("mustCreateEvent: failed to sign event: %s", err) } h := signedEvent.Headered(roomVer) - return &h + return h } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 4983a4434..79dc2fe14 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -62,10 +62,10 @@ func (w *inputWorker) start() { for { select { case task := <-w.input: - hooks.Run(hooks.KindModifyNewEvent, &task.event.Event) + hooks.Run(hooks.KindNewEventReceived, &task.event.Event) _, task.err = w.r.processRoomEvent(task.ctx, task.event) if task.err == nil { - hooks.Run(hooks.KindNewEvent, &task.event.Event) + hooks.Run(hooks.KindNewEventPersisted, &task.event.Event) } task.wg.Done() case <-time.After(time.Second * 5):