Fix tests; review comments

This commit is contained in:
Kegan Dougal 2020-11-19 11:03:36 +00:00
parent c388f6cf4b
commit 584e9fb11e
4 changed files with 131 additions and 137 deletions

View file

@ -19,20 +19,20 @@ package hooks
import "sync"
const (
// KindNewEvent is a hook which is called with *gomatrixserverlib.HeaderedEvent
// KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent
// It is run when a new event is persisted in the roomserver.
// Usage:
// hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) { ... })
KindNewEvent = "new_event"
// KindModifyNewEvent is a hook which is called with *gomatrixserverlib.HeaderedEvent
// hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... })
KindNewEventPersisted = "new_event_persisted"
// KindNewEventReceived is a hook which is called with *gomatrixserverlib.HeaderedEvent
// It is run before a new event is processed by the roomserver. This hook can be used
// to modify the event before it is persisted by adding data to `unsigned`.
// Usage:
// hooks.Attach(hooks.KindModifyNewEvent, func(headeredEvent interface{}) {
// hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
// ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
// _ = ev.SetUnsignedField("key", "val")
// })
KindModifyNewEvent = "modify_new_event"
KindNewEventReceived = "new_event_received"
)
var (

View file

@ -46,10 +46,10 @@ type EventRelationshipRequest struct {
MaxDepth int `json:"max_depth"`
MaxBreadth int `json:"max_breadth"`
Limit int `json:"limit"`
DepthFirst *bool `json:"depth_first"`
RecentFirst *bool `json:"recent_first"`
IncludeParent *bool `json:"include_parent"`
IncludeChildren *bool `json:"include_children"`
DepthFirst bool `json:"depth_first"`
RecentFirst bool `json:"recent_first"`
IncludeParent bool `json:"include_parent"`
IncludeChildren bool `json:"include_children"`
Direction string `json:"direction"`
Batch string `json:"batch"`
AutoJoin bool `json:"auto_join"`
@ -57,42 +57,23 @@ type EventRelationshipRequest struct {
func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) {
var relation EventRelationshipRequest
relation.Defaults()
if err := json.NewDecoder(body).Decode(&relation); err != nil {
return nil, err
}
// Sanity check request and set defaults.
relation.applyDefaults()
return &relation, nil
}
func (r *EventRelationshipRequest) applyDefaults() {
if r.Limit > 100 || r.Limit < 1 {
func (r *EventRelationshipRequest) Defaults() {
r.Limit = 100
}
if r.MaxBreadth == 0 {
r.MaxBreadth = 10
}
if r.MaxDepth == 0 {
r.MaxDepth = 3
}
t := true
f := false
if r.DepthFirst == nil {
r.DepthFirst = &f
}
if r.RecentFirst == nil {
r.RecentFirst = &t
}
if r.IncludeParent == nil {
r.IncludeParent = &f
}
if r.IncludeChildren == nil {
r.IncludeChildren = &f
}
if r.Direction != "up" {
r.DepthFirst = false
r.RecentFirst = true
r.IncludeParent = false
r.IncludeChildren = false
r.Direction = "down"
}
}
type EventRelationshipResponse struct {
Events []gomatrixserverlib.ClientEvent `json:"events"`
@ -111,7 +92,7 @@ func Enable(
return fmt.Errorf("Cannot enable MSC2836: %w", err)
}
hooks.Enable()
hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) {
hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
hookErr := db.StoreRelation(context.Background(), he)
if hookErr != nil {
@ -120,7 +101,7 @@ func Enable(
)
}
})
hooks.Attach(hooks.KindModifyNewEvent, func(headeredEvent interface{}) {
hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
ctx := context.Background()
// we only inject metadata for events our server sends
@ -271,16 +252,16 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
// Retrieve the event. Add it to response array.
returnEvents = append(returnEvents, event)
if *rc.req.IncludeParent {
if rc.req.IncludeParent {
if parentEvent := rc.includeParent(event); parentEvent != nil {
returnEvents = append(returnEvents, parentEvent)
}
}
if *rc.req.IncludeChildren {
if rc.req.IncludeChildren {
remaining := rc.req.Limit - len(returnEvents)
if remaining > 0 {
children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, *rc.req.RecentFirst)
children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, rc.req.RecentFirst)
if resErr != nil {
return nil, resErr
}
@ -303,7 +284,7 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
}
res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents))
for i, ev := range returnEvents {
res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(*ev, gomatrixserverlib.FormatAll)
res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll)
}
res.Limited = remaining == 0 || walkLimited
return &res, nil
@ -454,7 +435,7 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
return nil, false
}
return &event, queryMembershipRes.IsInRoom
return event, queryMembershipRes.IsInRoom
}
type walkInfo struct {
@ -472,7 +453,7 @@ type walker struct {
// WalkFrom the event ID given
func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, *w.req.RecentFirst)
children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
return false, err
@ -486,7 +467,7 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
return true, nil
}
// find the children's children
children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, *w.req.RecentFirst)
children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst)
if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
return false, err
@ -509,7 +490,7 @@ func (w *walker) addChildren(toWalk []walkInfo, children []eventInfo, depthOfChi
return toWalk
}
if *w.req.DepthFirst {
if w.req.DepthFirst {
// the slice is a stack so push them in reverse order so we pop them in the correct order
// e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3
for i := len(children) - 1; i >= 0; i-- {
@ -538,7 +519,7 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) {
return nil, nil
}
var child walkInfo
if *w.req.DepthFirst {
if w.req.DepthFirst {
// toWalk is a stack so pop the child off
child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1]
return &child, toWalk

View file

@ -26,8 +26,6 @@ var (
client = &http.Client{
Timeout: 10 * time.Second,
}
constTrue = true
constFalse = false
)
// Basic sanity check of MSC2836 logic. Injects a thread that looks like:
@ -184,9 +182,9 @@ func TestMSC2836(t *testing.T) {
defer cancel()
t.Run("returns 403 on invalid event IDs", func(t *testing.T) {
_ = postRelationships(t, 403, "alice", &msc2836.EventRelationshipRequest{
EventID: "$invalid",
})
_ = postRelationships(t, 403, "alice", newReq(t, map[string]interface{}{
"event_id": "$invalid",
}))
})
t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) {
nopUserAPI.accessTokens["frank"] = userapi.Device{
@ -194,11 +192,11 @@ func TestMSC2836(t *testing.T) {
DisplayName: "Frank Not In Room",
UserID: "@frank:localhost",
}
_ = postRelationships(t, 403, "frank", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
Limit: 1,
IncludeParent: &constTrue,
})
_ = postRelationships(t, 403, "frank", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"limit": 1,
"include_parent": true,
}))
})
t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) {
nopUserAPI.accessTokens["frank2"] = userapi.Device{
@ -208,44 +206,44 @@ func TestMSC2836(t *testing.T) {
}
// Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB
nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB}
body := postRelationships(t, 200, "frank2", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
Limit: 1,
IncludeParent: &constTrue,
})
body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"limit": 1,
"include_parent": true,
}))
assertContains(t, body, []string{eventB.EventID()})
})
t.Run("returns the parent if include_parent is true", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
IncludeParent: &constTrue,
Limit: 2,
})
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"include_parent": true,
"limit": 2,
}))
assertContains(t, body, []string{eventB.EventID(), eventA.EventID()})
})
t.Run("returns the children in the right order if include_children is true", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventD.EventID(),
IncludeChildren: &constTrue,
RecentFirst: &constTrue,
Limit: 4,
})
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventD.EventID(),
"include_children": true,
"recent_first": true,
"limit": 4,
}))
assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventD.EventID(),
IncludeChildren: &constTrue,
RecentFirst: &constFalse,
Limit: 4,
})
body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventD.EventID(),
"include_children": true,
"recent_first": false,
"limit": 4,
}))
assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
})
t.Run("walks the graph depth first", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
RecentFirst: &constFalse,
DepthFirst: &constTrue,
Limit: 6,
})
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": false,
"depth_first": true,
"limit": 6,
}))
// Oldest first so:
// A
// |
@ -257,12 +255,12 @@ func TestMSC2836(t *testing.T) {
// |
// 5H
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()})
body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
RecentFirst: &constTrue,
DepthFirst: &constTrue,
Limit: 6,
})
body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": true,
"depth_first": true,
"limit": 6,
}))
// Recent first so:
// A
// |
@ -276,12 +274,12 @@ func TestMSC2836(t *testing.T) {
assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()})
})
t.Run("walks the graph breadth first", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
RecentFirst: &constFalse,
DepthFirst: &constFalse,
Limit: 6,
})
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": false,
"depth_first": false,
"limit": 6,
}))
// Oldest first so:
// A
// |
@ -293,12 +291,12 @@ func TestMSC2836(t *testing.T) {
// |
// H
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
RecentFirst: &constTrue,
DepthFirst: &constFalse,
Limit: 6,
})
body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": true,
"depth_first": false,
"limit": 6,
}))
// Recent first so:
// A
// |
@ -312,43 +310,43 @@ func TestMSC2836(t *testing.T) {
assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
})
t.Run("caps via max_breadth", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
RecentFirst: &constFalse,
DepthFirst: &constFalse,
MaxBreadth: 2,
Limit: 10,
})
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": false,
"depth_first": false,
"max_breadth": 2,
"limit": 10,
}))
// Event G gets omitted because of max_breadth
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()})
})
t.Run("caps via max_depth", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
RecentFirst: &constFalse,
DepthFirst: &constFalse,
MaxDepth: 2,
Limit: 10,
})
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": false,
"depth_first": false,
"max_depth": 2,
"limit": 10,
}))
// Event H gets omitted because of max_depth
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
})
t.Run("terminates when reaching the limit", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
RecentFirst: &constFalse,
DepthFirst: &constFalse,
Limit: 4,
})
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": false,
"depth_first": false,
"limit": 4,
}))
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()})
})
t.Run("returns all events with a high enough limit", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{
EventID: eventB.EventID(),
RecentFirst: &constFalse,
DepthFirst: &constFalse,
Limit: 400,
})
body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
"event_id": eventB.EventID(),
"recent_first": false,
"depth_first": false,
"limit": 400,
}))
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
})
}
@ -357,6 +355,19 @@ func TestMSC2836(t *testing.T) {
// TODO: TestMSC2836UnknownEventsSkipped
// TODO: TestMSC2836SkipEventIfNotInRoom
func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2836.EventRelationshipRequest {
t.Helper()
b, err := json.Marshal(jsonBody)
if err != nil {
t.Fatalf("Failed to marshal request: %s", err)
}
r, err := msc2836.NewEventRelationshipRequest(bytes.NewBuffer(b))
if err != nil {
t.Fatalf("Failed to NewEventRelationshipRequest: %s", err)
}
return r
}
func runServer(t *testing.T, router *mux.Router) func() {
t.Helper()
externalServ := &http.Server{
@ -376,6 +387,8 @@ func runServer(t *testing.T, router *mux.Router) func() {
func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse {
t.Helper()
var r msc2836.EventRelationshipRequest
r.Defaults()
data, err := json.Marshal(req)
if err != nil {
t.Fatalf("failed to marshal request: %s", err)
@ -484,7 +497,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver
for _, eventID := range req.EventIDs {
ev := r.events[eventID]
if ev != nil {
res.Events = append(res.Events, *ev)
res.Events = append(res.Events, ev)
}
}
return nil
@ -521,7 +534,7 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve
t.Fatalf("failed to enable MSC2836: %s", err)
}
for _, ev := range events {
hooks.Run(hooks.KindNewEvent, ev)
hooks.Run(hooks.KindNewEventPersisted, ev)
}
return base.PublicClientAPIMux
}
@ -557,5 +570,5 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
}
h := signedEvent.Headered(roomVer)
return &h
return h
}

View file

@ -62,10 +62,10 @@ func (w *inputWorker) start() {
for {
select {
case task := <-w.input:
hooks.Run(hooks.KindModifyNewEvent, &task.event.Event)
hooks.Run(hooks.KindNewEventReceived, &task.event.Event)
_, task.err = w.r.processRoomEvent(task.ctx, task.event)
if task.err == nil {
hooks.Run(hooks.KindNewEvent, &task.event.Event)
hooks.Run(hooks.KindNewEventPersisted, &task.event.Event)
}
task.wg.Done()
case <-time.After(time.Second * 5):