From f15a96bec0f9735c56846b355fbaf2edb1c14d4f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 6 Aug 2020 14:30:56 +0100 Subject: [PATCH] Reset sessions when coordinates change --- cmd/dendrite-demo-yggdrasil/yggconn/node.go | 22 ++++++- .../yggconn/session.go | 57 +++++++++---------- go.mod | 2 +- go.sum | 4 ++ 4 files changed, 52 insertions(+), 33 deletions(-) diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/node.go b/cmd/dendrite-demo-yggdrasil/yggconn/node.go index 120d41f1e..b0fb4d14f 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/node.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/node.go @@ -50,7 +50,7 @@ type Node struct { listener quic.Listener tlsConfig *tls.Config quicConfig *quic.Config - sessions sync.Map // string -> quic.Session + sessions sync.Map // string -> *session sessionCount atomic.Uint32 sessionFunc func(address string) coords sync.Map // string -> yggdrasil.Coords @@ -94,6 +94,24 @@ func Setup(instanceName, storageDirectory string) (*Node, error) { } } + n.core.SetCoordChangeCallback(func(old, new yggdrasil.Coords) { + fmt.Println("COORDINATE CHANGE!") + fmt.Println("Old:", old) + fmt.Println("New:", new) + n.coords.Range(func(k, _ interface{}) bool { + fmt.Println("Deleting cached coords for", k) + n.coords.Delete(k) + return true + }) + n.sessions.Range(func(k, v interface{}) bool { + if s, ok := v.(*session); ok { + fmt.Println("Killing session", k) + s.kill() + } + return true + }) + }) + n.config.Peers = []string{} n.config.AdminListen = "none" n.config.MulticastInterfaces = []string{} @@ -127,7 +145,7 @@ func Setup(instanceName, storageDirectory string) (*Node, error) { MaxIncomingStreams: 0, MaxIncomingUniStreams: 0, KeepAlive: true, - MaxIdleTimeout: time.Minute * 5, + MaxIdleTimeout: time.Minute * 30, HandshakeTimeout: time.Second * 15, } copy(n.quicConfig.StatelessResetKey, n.EncryptionPublicKey()) diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/session.go b/cmd/dendrite-demo-yggdrasil/yggconn/session.go index 087b1a2bb..0cf524d9b 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/session.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/session.go @@ -38,20 +38,23 @@ type session struct { node *Node session quic.Session address string - cancel chan struct{} + context context.Context + cancel context.CancelFunc } func (n *Node) newSession(sess quic.Session, address string) *session { + ctx, cancel := context.WithCancel(context.TODO()) return &session{ node: n, session: sess, address: address, - cancel: make(chan struct{}), + context: ctx, + cancel: cancel, } } func (s *session) kill() { - close(s.cancel) + s.cancel() } func (n *Node) listenFromYgg() { @@ -85,29 +88,22 @@ func (n *Node) listenFromYgg() { func (s *session) listenFromQUIC() { if existing, ok := s.node.sessions.Load(s.address); ok { - if existingSession, ok := existing.(session); ok { + if existingSession, ok := existing.(*session); ok { + fmt.Println("Killing existing session to replace", s.address) existingSession.kill() } } s.node.sessionCount.Inc() - s.node.sessions.Store(s.address, s.session) + s.node.sessions.Store(s.address, s) defer s.node.sessions.Delete(s.address) defer s.node.sessionCount.Dec() for { - select { - case <-s.cancel: - _ = s.session.CloseWithError(0, "killed") + st, err := s.session.AcceptStream(s.context) + if err != nil { + s.node.log.Println("session.AcceptStream:", err) return - default: - ctx, cancel := context.WithTimeout(context.TODO(), s.node.quicConfig.MaxIdleTimeout) - defer cancel() - st, err := s.session.AcceptStream(ctx) - if err != nil { - s.node.log.Println("session.AcceptStream:", err) - return - } - s.node.incoming <- QUICStream{st, s.session} } + s.node.incoming <- QUICStream{st, s.session} } } @@ -135,7 +131,7 @@ func (n *Node) Dial(network, address string) (net.Conn, error) { // nolint:gocyclo func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) { s, ok1 := n.sessions.Load(address) - session, ok2 := s.(quic.Session) + session, ok2 := s.(*session) if !ok1 || !ok2 { // First of all, check if we think we know the coords of this // node. If we do then we'll try to dial to it directly. This @@ -214,17 +210,17 @@ func (n *Node) DialContext(ctx context.Context, network, address string) (net.Co return nil, fmt.Errorf("should have found session but didn't") } - st, err := session.OpenStream() + st, err := session.session.OpenStream() if err != nil { n.log.Println("session.OpenStream:", err) - _ = session.CloseWithError(0, "expected to be able to open session") + _ = session.session.CloseWithError(0, "expected to be able to open session") return nil, err } - return QUICStream{st, session}, nil + return QUICStream{st, session.session}, nil } -func (n *Node) tryDial(address string, coords yggdrasil.Coords) (quic.Session, error) { - session, err := quic.Dial( +func (n *Node) tryDial(address string, coords yggdrasil.Coords) (*session, error) { + quicSession, err := quic.Dial( n.core, // yggdrasil.PacketConn coords, // dial address address, // dial SNI @@ -234,19 +230,20 @@ func (n *Node) tryDial(address string, coords yggdrasil.Coords) (quic.Session, e if err != nil { return nil, err } - if len(session.ConnectionState().PeerCertificates) != 1 { - _ = session.CloseWithError(0, "expected a peer certificate") + if len(quicSession.ConnectionState().PeerCertificates) != 1 { + _ = quicSession.CloseWithError(0, "expected a peer certificate") return nil, errors.New("didn't receive a peer certificate") } - if len(session.ConnectionState().PeerCertificates[0].DNSNames) != 1 { - _ = session.CloseWithError(0, "expected a DNS name") + if len(quicSession.ConnectionState().PeerCertificates[0].DNSNames) != 1 { + _ = quicSession.CloseWithError(0, "expected a DNS name") return nil, errors.New("didn't receive a DNS name") } - if gotAddress := session.ConnectionState().PeerCertificates[0].DNSNames[0]; address != gotAddress { - _ = session.CloseWithError(0, "you aren't the host I was hoping for") + if gotAddress := quicSession.ConnectionState().PeerCertificates[0].DNSNames[0]; address != gotAddress { + _ = quicSession.CloseWithError(0, "you aren't the host I was hoping for") return nil, fmt.Errorf("expected %q but dialled %q", address, gotAddress) } - go n.newSession(session, address).listenFromQUIC() + session := n.newSession(quicSession, address) + go session.listenFromQUIC() go n.sessionFunc(address) return session, nil } diff --git a/go.mod b/go.mod index 232328733..e367b87ac 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/uber-go/atomic v1.3.0 // indirect github.com/uber/jaeger-client-go v2.15.0+incompatible github.com/uber/jaeger-lib v1.5.0 - github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200715104113-1046b00c3be3 + github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200806125501-cd4685a3b4de go.uber.org/atomic v1.4.0 golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5 golang.org/x/mobile v0.0.0-20200801112145-973feb4309de // indirect diff --git a/go.sum b/go.sum index 926822063..9aa039091 100644 --- a/go.sum +++ b/go.sum @@ -655,6 +655,10 @@ github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1: github.com/yggdrasil-network/yggdrasil-extras v0.0.0-20200525205615-6c8a4a2e8855/go.mod h1:xQdsh08Io6nV4WRnOVTe6gI8/2iTvfLDQ0CYa5aMt+I= github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200715104113-1046b00c3be3 h1:teLoIJgPHysREs8P6GlcS/PgaU9W9+GQndikFCQ1lY0= github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200715104113-1046b00c3be3/go.mod h1:d+Nz6SPeG6kmeSPFL0cvfWfgwEql75fUnZiAONgvyBE= +github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200806124633-bd1bdd6be073 h1:Fg4Bszd2qp6eyz/yDMYfB8g2PC1FfNQphGRgZAyD0VU= +github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200806124633-bd1bdd6be073/go.mod h1:d+Nz6SPeG6kmeSPFL0cvfWfgwEql75fUnZiAONgvyBE= +github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200806125501-cd4685a3b4de h1:p91aw0Mvol825U+5bvV9BBPl+HQxIczj7wxIOxZs70M= +github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200806125501-cd4685a3b4de/go.mod h1:d+Nz6SPeG6kmeSPFL0cvfWfgwEql75fUnZiAONgvyBE= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.1/go.mod h1:Ap50jQcDJrx6rB6VgeeFPtuPIf3wMRvRfrfYDO6+BmA=