Fix tests; review comments

This commit is contained in:
Kegan Dougal 2020-11-19 11:03:36 +00:00
parent c388f6cf4b
commit 584e9fb11e
4 changed files with 131 additions and 137 deletions

View file

@ -19,20 +19,20 @@ package hooks
import "sync" import "sync"
const ( const (
// KindNewEvent is a hook which is called with *gomatrixserverlib.HeaderedEvent // KindNewEventPersisted is a hook which is called with *gomatrixserverlib.HeaderedEvent
// It is run when a new event is persisted in the roomserver. // It is run when a new event is persisted in the roomserver.
// Usage: // Usage:
// hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) { ... }) // hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) { ... })
KindNewEvent = "new_event" KindNewEventPersisted = "new_event_persisted"
// KindModifyNewEvent is a hook which is called with *gomatrixserverlib.HeaderedEvent // KindNewEventReceived is a hook which is called with *gomatrixserverlib.HeaderedEvent
// It is run before a new event is processed by the roomserver. This hook can be used // It is run before a new event is processed by the roomserver. This hook can be used
// to modify the event before it is persisted by adding data to `unsigned`. // to modify the event before it is persisted by adding data to `unsigned`.
// Usage: // Usage:
// hooks.Attach(hooks.KindModifyNewEvent, func(headeredEvent interface{}) { // hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
// ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent) // ev := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
// _ = ev.SetUnsignedField("key", "val") // _ = ev.SetUnsignedField("key", "val")
// }) // })
KindModifyNewEvent = "modify_new_event" KindNewEventReceived = "new_event_received"
) )
var ( var (

View file

@ -46,10 +46,10 @@ type EventRelationshipRequest struct {
MaxDepth int `json:"max_depth"` MaxDepth int `json:"max_depth"`
MaxBreadth int `json:"max_breadth"` MaxBreadth int `json:"max_breadth"`
Limit int `json:"limit"` Limit int `json:"limit"`
DepthFirst *bool `json:"depth_first"` DepthFirst bool `json:"depth_first"`
RecentFirst *bool `json:"recent_first"` RecentFirst bool `json:"recent_first"`
IncludeParent *bool `json:"include_parent"` IncludeParent bool `json:"include_parent"`
IncludeChildren *bool `json:"include_children"` IncludeChildren bool `json:"include_children"`
Direction string `json:"direction"` Direction string `json:"direction"`
Batch string `json:"batch"` Batch string `json:"batch"`
AutoJoin bool `json:"auto_join"` AutoJoin bool `json:"auto_join"`
@ -57,42 +57,23 @@ type EventRelationshipRequest struct {
func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) { func NewEventRelationshipRequest(body io.Reader) (*EventRelationshipRequest, error) {
var relation EventRelationshipRequest var relation EventRelationshipRequest
relation.Defaults()
if err := json.NewDecoder(body).Decode(&relation); err != nil { if err := json.NewDecoder(body).Decode(&relation); err != nil {
return nil, err return nil, err
} }
// Sanity check request and set defaults.
relation.applyDefaults()
return &relation, nil return &relation, nil
} }
func (r *EventRelationshipRequest) applyDefaults() { func (r *EventRelationshipRequest) Defaults() {
if r.Limit > 100 || r.Limit < 1 {
r.Limit = 100 r.Limit = 100
}
if r.MaxBreadth == 0 {
r.MaxBreadth = 10 r.MaxBreadth = 10
}
if r.MaxDepth == 0 {
r.MaxDepth = 3 r.MaxDepth = 3
} r.DepthFirst = false
t := true r.RecentFirst = true
f := false r.IncludeParent = false
if r.DepthFirst == nil { r.IncludeChildren = false
r.DepthFirst = &f
}
if r.RecentFirst == nil {
r.RecentFirst = &t
}
if r.IncludeParent == nil {
r.IncludeParent = &f
}
if r.IncludeChildren == nil {
r.IncludeChildren = &f
}
if r.Direction != "up" {
r.Direction = "down" r.Direction = "down"
} }
}
type EventRelationshipResponse struct { type EventRelationshipResponse struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
@ -111,7 +92,7 @@ func Enable(
return fmt.Errorf("Cannot enable MSC2836: %w", err) return fmt.Errorf("Cannot enable MSC2836: %w", err)
} }
hooks.Enable() hooks.Enable()
hooks.Attach(hooks.KindNewEvent, func(headeredEvent interface{}) { hooks.Attach(hooks.KindNewEventPersisted, func(headeredEvent interface{}) {
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
hookErr := db.StoreRelation(context.Background(), he) hookErr := db.StoreRelation(context.Background(), he)
if hookErr != nil { if hookErr != nil {
@ -120,7 +101,7 @@ func Enable(
) )
} }
}) })
hooks.Attach(hooks.KindModifyNewEvent, func(headeredEvent interface{}) { hooks.Attach(hooks.KindNewEventReceived, func(headeredEvent interface{}) {
he := headeredEvent.(*gomatrixserverlib.HeaderedEvent) he := headeredEvent.(*gomatrixserverlib.HeaderedEvent)
ctx := context.Background() ctx := context.Background()
// we only inject metadata for events our server sends // we only inject metadata for events our server sends
@ -271,16 +252,16 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
// Retrieve the event. Add it to response array. // Retrieve the event. Add it to response array.
returnEvents = append(returnEvents, event) returnEvents = append(returnEvents, event)
if *rc.req.IncludeParent { if rc.req.IncludeParent {
if parentEvent := rc.includeParent(event); parentEvent != nil { if parentEvent := rc.includeParent(event); parentEvent != nil {
returnEvents = append(returnEvents, parentEvent) returnEvents = append(returnEvents, parentEvent)
} }
} }
if *rc.req.IncludeChildren { if rc.req.IncludeChildren {
remaining := rc.req.Limit - len(returnEvents) remaining := rc.req.Limit - len(returnEvents)
if remaining > 0 { if remaining > 0 {
children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, *rc.req.RecentFirst) children, resErr := rc.includeChildren(rc.db, event.EventID(), remaining, rc.req.RecentFirst)
if resErr != nil { if resErr != nil {
return nil, resErr return nil, resErr
} }
@ -303,7 +284,7 @@ func (rc *reqCtx) process() (*EventRelationshipResponse, *util.JSONResponse) {
} }
res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents)) res.Events = make([]gomatrixserverlib.ClientEvent, len(returnEvents))
for i, ev := range returnEvents { for i, ev := range returnEvents {
res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(*ev, gomatrixserverlib.FormatAll) res.Events[i] = gomatrixserverlib.HeaderedToClientEvent(ev, gomatrixserverlib.FormatAll)
} }
res.Limited = remaining == 0 || walkLimited res.Limited = remaining == 0 || walkLimited
return &res, nil return &res, nil
@ -454,7 +435,7 @@ func getEventIfVisible(ctx context.Context, rsAPI roomserver.RoomserverInternalA
util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser") util.GetLogger(ctx).WithError(err).Error("getEventIfVisible: failed to QueryMembershipForUser")
return nil, false return nil, false
} }
return &event, queryMembershipRes.IsInRoom return event, queryMembershipRes.IsInRoom
} }
type walkInfo struct { type walkInfo struct {
@ -472,7 +453,7 @@ type walker struct {
// WalkFrom the event ID given // WalkFrom the event ID given
func (w *walker) WalkFrom(eventID string) (limited bool, err error) { func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, *w.req.RecentFirst) children, err := w.db.ChildrenForParent(w.ctx, eventID, constRelType, w.req.RecentFirst)
if err != nil { if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
return false, err return false, err
@ -486,7 +467,7 @@ func (w *walker) WalkFrom(eventID string) (limited bool, err error) {
return true, nil return true, nil
} }
// find the children's children // find the children's children
children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, *w.req.RecentFirst) children, err = w.db.ChildrenForParent(w.ctx, next.EventID, constRelType, w.req.RecentFirst)
if err != nil { if err != nil {
util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk") util.GetLogger(w.ctx).WithError(err).Error("WalkFrom() ChildrenForParent failed, cannot walk")
return false, err return false, err
@ -509,7 +490,7 @@ func (w *walker) addChildren(toWalk []walkInfo, children []eventInfo, depthOfChi
return toWalk return toWalk
} }
if *w.req.DepthFirst { if w.req.DepthFirst {
// the slice is a stack so push them in reverse order so we pop them in the correct order // the slice is a stack so push them in reverse order so we pop them in the correct order
// e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3 // e.g [3,2,1] => [3,2] , 1 => [3] , 2 => [] , 3
for i := len(children) - 1; i >= 0; i-- { for i := len(children) - 1; i >= 0; i-- {
@ -538,7 +519,7 @@ func (w *walker) nextChild(toWalk []walkInfo) (*walkInfo, []walkInfo) {
return nil, nil return nil, nil
} }
var child walkInfo var child walkInfo
if *w.req.DepthFirst { if w.req.DepthFirst {
// toWalk is a stack so pop the child off // toWalk is a stack so pop the child off
child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1] child, toWalk = toWalk[len(toWalk)-1], toWalk[:len(toWalk)-1]
return &child, toWalk return &child, toWalk

View file

@ -26,8 +26,6 @@ var (
client = &http.Client{ client = &http.Client{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
} }
constTrue = true
constFalse = false
) )
// Basic sanity check of MSC2836 logic. Injects a thread that looks like: // Basic sanity check of MSC2836 logic. Injects a thread that looks like:
@ -184,9 +182,9 @@ func TestMSC2836(t *testing.T) {
defer cancel() defer cancel()
t.Run("returns 403 on invalid event IDs", func(t *testing.T) { t.Run("returns 403 on invalid event IDs", func(t *testing.T) {
_ = postRelationships(t, 403, "alice", &msc2836.EventRelationshipRequest{ _ = postRelationships(t, 403, "alice", newReq(t, map[string]interface{}{
EventID: "$invalid", "event_id": "$invalid",
}) }))
}) })
t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) { t.Run("returns 403 if not joined to the room of specified event in request", func(t *testing.T) {
nopUserAPI.accessTokens["frank"] = userapi.Device{ nopUserAPI.accessTokens["frank"] = userapi.Device{
@ -194,11 +192,11 @@ func TestMSC2836(t *testing.T) {
DisplayName: "Frank Not In Room", DisplayName: "Frank Not In Room",
UserID: "@frank:localhost", UserID: "@frank:localhost",
} }
_ = postRelationships(t, 403, "frank", &msc2836.EventRelationshipRequest{ _ = postRelationships(t, 403, "frank", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
Limit: 1, "limit": 1,
IncludeParent: &constTrue, "include_parent": true,
}) }))
}) })
t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) { t.Run("omits parent if not joined to the room of parent of event", func(t *testing.T) {
nopUserAPI.accessTokens["frank2"] = userapi.Device{ nopUserAPI.accessTokens["frank2"] = userapi.Device{
@ -208,44 +206,44 @@ func TestMSC2836(t *testing.T) {
} }
// Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB // Event B is in roomB, Event A is in roomA, so make frank2 joined to roomB
nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB} nopRsAPI.userToJoinedRooms["@frank2:localhost"] = []string{roomIDB}
body := postRelationships(t, 200, "frank2", &msc2836.EventRelationshipRequest{ body := postRelationships(t, 200, "frank2", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
Limit: 1, "limit": 1,
IncludeParent: &constTrue, "include_parent": true,
}) }))
assertContains(t, body, []string{eventB.EventID()}) assertContains(t, body, []string{eventB.EventID()})
}) })
t.Run("returns the parent if include_parent is true", func(t *testing.T) { t.Run("returns the parent if include_parent is true", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
IncludeParent: &constTrue, "include_parent": true,
Limit: 2, "limit": 2,
}) }))
assertContains(t, body, []string{eventB.EventID(), eventA.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventA.EventID()})
}) })
t.Run("returns the children in the right order if include_children is true", func(t *testing.T) { t.Run("returns the children in the right order if include_children is true", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventD.EventID(), "event_id": eventD.EventID(),
IncludeChildren: &constTrue, "include_children": true,
RecentFirst: &constTrue, "recent_first": true,
Limit: 4, "limit": 4,
}) }))
assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) assertContains(t, body, []string{eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventD.EventID(), "event_id": eventD.EventID(),
IncludeChildren: &constTrue, "include_children": true,
RecentFirst: &constFalse, "recent_first": false,
Limit: 4, "limit": 4,
}) }))
assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) assertContains(t, body, []string{eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
}) })
t.Run("walks the graph depth first", func(t *testing.T) { t.Run("walks the graph depth first", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
RecentFirst: &constFalse, "recent_first": false,
DepthFirst: &constTrue, "depth_first": true,
Limit: 6, "limit": 6,
}) }))
// Oldest first so: // Oldest first so:
// A // A
// | // |
@ -257,12 +255,12 @@ func TestMSC2836(t *testing.T) {
// | // |
// 5H // 5H
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventH.EventID(), eventF.EventID()})
body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
RecentFirst: &constTrue, "recent_first": true,
DepthFirst: &constTrue, "depth_first": true,
Limit: 6, "limit": 6,
}) }))
// Recent first so: // Recent first so:
// A // A
// | // |
@ -276,12 +274,12 @@ func TestMSC2836(t *testing.T) {
assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID(), eventH.EventID()})
}) })
t.Run("walks the graph breadth first", func(t *testing.T) { t.Run("walks the graph breadth first", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
RecentFirst: &constFalse, "recent_first": false,
DepthFirst: &constFalse, "depth_first": false,
Limit: 6, "limit": 6,
}) }))
// Oldest first so: // Oldest first so:
// A // A
// | // |
@ -293,12 +291,12 @@ func TestMSC2836(t *testing.T) {
// | // |
// H // H
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
body = postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body = postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
RecentFirst: &constTrue, "recent_first": true,
DepthFirst: &constFalse, "depth_first": false,
Limit: 6, "limit": 6,
}) }))
// Recent first so: // Recent first so:
// A // A
// | // |
@ -312,43 +310,43 @@ func TestMSC2836(t *testing.T) {
assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventD.EventID(), eventC.EventID(), eventG.EventID(), eventF.EventID(), eventE.EventID()})
}) })
t.Run("caps via max_breadth", func(t *testing.T) { t.Run("caps via max_breadth", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
RecentFirst: &constFalse, "recent_first": false,
DepthFirst: &constFalse, "depth_first": false,
MaxBreadth: 2, "max_breadth": 2,
Limit: 10, "limit": 10,
}) }))
// Event G gets omitted because of max_breadth // Event G gets omitted because of max_breadth
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventH.EventID()})
}) })
t.Run("caps via max_depth", func(t *testing.T) { t.Run("caps via max_depth", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
RecentFirst: &constFalse, "recent_first": false,
DepthFirst: &constFalse, "depth_first": false,
MaxDepth: 2, "max_depth": 2,
Limit: 10, "limit": 10,
}) }))
// Event H gets omitted because of max_depth // Event H gets omitted because of max_depth
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID()})
}) })
t.Run("terminates when reaching the limit", func(t *testing.T) { t.Run("terminates when reaching the limit", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
RecentFirst: &constFalse, "recent_first": false,
DepthFirst: &constFalse, "depth_first": false,
Limit: 4, "limit": 4,
}) }))
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID()})
}) })
t.Run("returns all events with a high enough limit", func(t *testing.T) { t.Run("returns all events with a high enough limit", func(t *testing.T) {
body := postRelationships(t, 200, "alice", &msc2836.EventRelationshipRequest{ body := postRelationships(t, 200, "alice", newReq(t, map[string]interface{}{
EventID: eventB.EventID(), "event_id": eventB.EventID(),
RecentFirst: &constFalse, "recent_first": false,
DepthFirst: &constFalse, "depth_first": false,
Limit: 400, "limit": 400,
}) }))
assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()}) assertContains(t, body, []string{eventB.EventID(), eventC.EventID(), eventD.EventID(), eventE.EventID(), eventF.EventID(), eventG.EventID(), eventH.EventID()})
}) })
} }
@ -357,6 +355,19 @@ func TestMSC2836(t *testing.T) {
// TODO: TestMSC2836UnknownEventsSkipped // TODO: TestMSC2836UnknownEventsSkipped
// TODO: TestMSC2836SkipEventIfNotInRoom // TODO: TestMSC2836SkipEventIfNotInRoom
func newReq(t *testing.T, jsonBody map[string]interface{}) *msc2836.EventRelationshipRequest {
t.Helper()
b, err := json.Marshal(jsonBody)
if err != nil {
t.Fatalf("Failed to marshal request: %s", err)
}
r, err := msc2836.NewEventRelationshipRequest(bytes.NewBuffer(b))
if err != nil {
t.Fatalf("Failed to NewEventRelationshipRequest: %s", err)
}
return r
}
func runServer(t *testing.T, router *mux.Router) func() { func runServer(t *testing.T, router *mux.Router) func() {
t.Helper() t.Helper()
externalServ := &http.Server{ externalServ := &http.Server{
@ -376,6 +387,8 @@ func runServer(t *testing.T, router *mux.Router) func() {
func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse { func postRelationships(t *testing.T, expectCode int, accessToken string, req *msc2836.EventRelationshipRequest) *msc2836.EventRelationshipResponse {
t.Helper() t.Helper()
var r msc2836.EventRelationshipRequest
r.Defaults()
data, err := json.Marshal(req) data, err := json.Marshal(req)
if err != nil { if err != nil {
t.Fatalf("failed to marshal request: %s", err) t.Fatalf("failed to marshal request: %s", err)
@ -484,7 +497,7 @@ func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver
for _, eventID := range req.EventIDs { for _, eventID := range req.EventIDs {
ev := r.events[eventID] ev := r.events[eventID]
if ev != nil { if ev != nil {
res.Events = append(res.Events, *ev) res.Events = append(res.Events, ev)
} }
} }
return nil return nil
@ -521,7 +534,7 @@ func injectEvents(t *testing.T, userAPI userapi.UserInternalAPI, rsAPI roomserve
t.Fatalf("failed to enable MSC2836: %s", err) t.Fatalf("failed to enable MSC2836: %s", err)
} }
for _, ev := range events { for _, ev := range events {
hooks.Run(hooks.KindNewEvent, ev) hooks.Run(hooks.KindNewEventPersisted, ev)
} }
return base.PublicClientAPIMux return base.PublicClientAPIMux
} }
@ -557,5 +570,5 @@ func mustCreateEvent(t *testing.T, ev fledglingEvent) (result *gomatrixserverlib
t.Fatalf("mustCreateEvent: failed to sign event: %s", err) t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
} }
h := signedEvent.Headered(roomVer) h := signedEvent.Headered(roomVer)
return &h return h
} }

View file

@ -62,10 +62,10 @@ func (w *inputWorker) start() {
for { for {
select { select {
case task := <-w.input: case task := <-w.input:
hooks.Run(hooks.KindModifyNewEvent, &task.event.Event) hooks.Run(hooks.KindNewEventReceived, &task.event.Event)
_, task.err = w.r.processRoomEvent(task.ctx, task.event) _, task.err = w.r.processRoomEvent(task.ctx, task.event)
if task.err == nil { if task.err == nil {
hooks.Run(hooks.KindNewEvent, &task.event.Event) hooks.Run(hooks.KindNewEventPersisted, &task.event.Event)
} }
task.wg.Done() task.wg.Done()
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):