diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 6290ece0a..02683aeaf 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -102,12 +102,13 @@ func Send( type txnReq struct { gomatrixserverlib.Transaction - rsAPI api.RoomserverInternalAPI - eduAPI eduserverAPI.EDUServerInputAPI - keyAPI keyapi.KeyInternalAPI - keys gomatrixserverlib.JSONVerifier - federation txnFederationClient - servers []gomatrixserverlib.ServerName + rsAPI api.RoomserverInternalAPI + eduAPI eduserverAPI.EDUServerInputAPI + keyAPI keyapi.KeyInternalAPI + keys gomatrixserverlib.JSONVerifier + federation txnFederationClient + servers []gomatrixserverlib.ServerName + serversMutex sync.RWMutex // local cache of events for auth checks, etc - this may include events // which the roomserver is unaware of. haveEvents map[string]*gomatrixserverlib.HeaderedEvent @@ -405,16 +406,21 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli } func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName { - servers := []gomatrixserverlib.ServerName{t.Origin} + t.serversMutex.Lock() + defer t.serversMutex.Unlock() + if t.servers != nil { + return t.servers + } + t.servers = []gomatrixserverlib.ServerName{t.Origin} serverReq := &api.QueryServerJoinedToRoomRequest{ RoomID: roomID, } serverRes := &api.QueryServerJoinedToRoomResponse{} if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { - servers = append(servers, serverRes.ServerNames...) - util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(servers), roomID) + t.servers = append(t.servers, serverRes.ServerNames...) + util.GetLogger(ctx).Infof("Found %d server(s) to query for missing events in %q", len(t.servers), roomID) } - return servers + return t.servers } func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) error { @@ -486,7 +492,7 @@ func (t *txnReq) retrieveMissingAuthEvents( withNextEvent: for missingAuthEventID := range missingAuthEvents { withNextServer: - for _, server := range t.servers { + for _, server := range t.getServers(ctx, e.RoomID()) { logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) if err != nil { @@ -555,7 +561,7 @@ func (t *txnReq) processEventWithMissingState(ctx context.Context, e *gomatrixse // event ids and then use /event to fetch the individual events. // However not all version of synapse support /state_ids so you may // need to fallback to /state. - t.servers = t.getServers(ctx, e.RoomID()) + // Attempt to fill in the gap using /get_missing_events // This will either: // - fill in the gap completely then process event `e` returning no backwards extremity @@ -690,7 +696,7 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix } // fetch the event we're missing and add it to the pile - h, err := t.lookupEvent(ctx, roomVersion, eventID, false) + h, err := t.lookupEvent(ctx, roomVersion, roomID, eventID, false) switch err.(type) { case verifySigError: return respState, false, nil @@ -797,7 +803,7 @@ retryAllowedState: if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: - h, err2 := t.lookupEvent(ctx, roomVersion, missing.AuthEventID, true) + h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) switch err2.(type) { case verifySigError: return &gomatrixserverlib.RespState{ @@ -846,7 +852,8 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even } var missingResp *gomatrixserverlib.RespMissingEvents - for _, server := range t.servers { + servers := t.getServers(ctx, e.RoomID()) + for _, server := range servers { var m gomatrixserverlib.RespMissingEvents if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, @@ -865,7 +872,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even if missingResp == nil { logger.WithError(err).Errorf( "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", - t.Origin, len(t.servers), + t.Origin, len(servers), ) return nil, missingPrevEventsError{ eventID: e.EventID(), @@ -1018,7 +1025,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even // Define what we'll do in order to fetch the missing event ID. fetch := func(missingEventID string) { var h *gomatrixserverlib.HeaderedEvent - h, err = t.lookupEvent(ctx, roomVersion, missingEventID, false) + h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) switch err.(type) { case verifySigError: return @@ -1084,7 +1091,7 @@ func (t *txnReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStat return &respState, nil } -func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { +func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { if localFirst { // fetch from the roomserver queryReq := api.QueryEventsByIDRequest{ @@ -1099,7 +1106,8 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib. } var event *gomatrixserverlib.Event found := false - for _, serverName := range t.servers { + servers := t.getServers(ctx, roomID) + for _, serverName := range servers { txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) if err != nil || len(txn.PDUs) == 0 { util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID") @@ -1114,8 +1122,8 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib. break } if !found { - util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) - return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) + util.GetLogger(ctx).WithField("event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(servers)) + return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(servers)) } if err := gomatrixserverlib.VerifyAllEventSignatures(ctx, []*gomatrixserverlib.Event{event}, t.keys); err != nil { util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())