From d9c3df87384c50aaa20047b703ec77f1e18797ae Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 12 Oct 2020 15:34:37 +0100 Subject: [PATCH] Fix test --- federationapi/routing/send.go | 33 ++++++++---------------------- federationapi/routing/send_test.go | 17 +++++++++++++++ 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 2142ead38..d10c9fed8 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -468,27 +468,19 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser // - fill in the gap completely then process event `e` returning no backwards extremity // - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - newEvents, backwardsExtremity, err := t.getMissingEvents(gmectx, e, roomVersion, true) + newEvents, err := t.getMissingEvents(gmectx, e, roomVersion) if err != nil { return err } - if backwardsExtremity == nil { - // we filled in the gap! - fmt.Println("No backwards extremity") - //return nil - } if len(newEvents) == 0 { - fmt.Println("No new events") return nil } - backwardsExtremity = &newEvents[0] + backwardsExtremity := &newEvents[0] newEvents = newEvents[1:] fmt.Println(len(newEvents), "new events") - fmt.Println("GO!") - // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. @@ -509,8 +501,6 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser states = append(states, prevState) } - fmt.Println("CHECKPOINT 1") - // Now that we have collected all of the state from the prev_events, we'll // run the state through the appropriate state resolution algorithm for the // room. This does a couple of things: @@ -523,8 +513,6 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser return err } - fmt.Println("CHECKPOINT 2") - // First of all, send the backward extremity into the roomserver with the // newly resolved state. This marks the "oldest" point in the backfill and // sets the baseline state for any new events after this. @@ -552,11 +540,10 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e gomatrixser if err = api.SendEvents( context.Background(), t.rsAPI, - headeredNewEvents, + append(headeredNewEvents, e.Headered(roomVersion)), api.DoNotSendToOtherServers, nil, ); err != nil { - fmt.Println("ERROR!", err) return fmt.Errorf("api.SendEvents: %w", err) } fmt.Println("SUCCESS!") @@ -717,11 +704,7 @@ retryAllowedState: // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. // This means that we may recursively call this function, as we spider back up prev_events. // nolint:gocyclo -func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (newEvents []gomatrixserverlib.Event, backwardsExtremity *gomatrixserverlib.Event, err error) { - if !isInboundTxn { - // we've recursed here, so just take a state snapshot please! - return nil, &e, nil - } +func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []gomatrixserverlib.Event, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) // query latest events (our trusted forward extremities) @@ -732,7 +715,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event var res api.QueryLatestEventsAndStateResponse if err = t.rsAPI.QueryLatestEventsAndState(ctx, &req, &res); err != nil { logger.WithError(err).Warn("Failed to query latest events") - return nil, &e, nil + return nil, err } latestEvents := make([]string, len(res.LatestEvents)) for i := range res.LatestEvents { @@ -771,7 +754,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.Origin, len(servers), ) - return nil, nil, missingPrevEventsError{ + return nil, missingPrevEventsError{ eventID: e.EventID(), err: err, } @@ -809,14 +792,14 @@ Event: "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.Origin, ) - return nil, nil, missingPrevEventsError{ + return nil, missingPrevEventsError{ eventID: e.EventID(), err: err, } } // we processed everything! - return newEvents, nil, nil + return newEvents, nil } func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index ba653c1e8..d7e422479 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -516,6 +516,23 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions rsAPI = &testRoomserverAPI{ + queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { + res := api.QueryEventsByIDResponse{} + for _, ev := range testEvents { + for _, id := range req.EventIDs { + if ev.EventID() == id { + res.Events = append(res.Events, ev) + } + } + } + return res + }, + queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { + return api.QueryStateAfterEventsResponse{ + PrevEventsExist: true, + StateEvents: testEvents[:5], + } + }, queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { missingPrevEvent := []string{"missing_prev_event"} if len(req.PrevEventIDs) == 1 {