diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/session.go b/cmd/dendrite-demo-yggdrasil/yggconn/session.go index 854d525cb..13bb1daed 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/session.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/session.go @@ -3,6 +3,7 @@ package yggconn import ( "context" "net" + "strings" "time" "github.com/libp2p/go-yamux" @@ -24,9 +25,13 @@ func (n *Node) listenFromYgg() { n.log.Println("n.listener.Accept:", err) return } - session, err := yamux.Server(conn, n.yamuxConfig()) + var session *yamux.Session + if strings.Compare(n.EncryptionPublicKey(), conn.RemoteAddr().String()) < 0 { + session, err = yamux.Client(conn, n.yamuxConfig()) + } else { + session, err = yamux.Server(conn, n.yamuxConfig()) + } if err != nil { - n.log.Println("yamux.Server:", err) return } go n.listenFromYggConn(session) @@ -77,15 +82,19 @@ func (n *Node) DialContext(ctx context.Context, network, address string) (net.Co n.log.Println("n.dialer.DialContext:", err) return nil, err } - session, err = yamux.Client(conn, n.yamuxConfig()) + if strings.Compare(n.EncryptionPublicKey(), address) < 0 { + session, err = yamux.Client(conn, n.yamuxConfig()) + } else { + session, err = yamux.Server(conn, n.yamuxConfig()) + } if err != nil { - n.log.Println("yamux.Client.AcceptStream:", err) return nil, err } go n.listenFromYggConn(session) } st, err := session.OpenStream() if err != nil { + n.log.Println("session.OpenStream:", err) return nil, err } return st, nil