More reliable QUIC session handling

This commit is contained in:
Neil Alexander 2020-08-05 14:47:13 +01:00
parent 29708638d5
commit a662defa8c
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
4 changed files with 83 additions and 35 deletions

View file

@ -98,7 +98,7 @@ func (m *DendriteMonolith) Start() {
cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s/dendrite-syncapi.db", m.StorageDirectory)) cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s/dendrite-syncapi.db", m.StorageDirectory))
cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s/dendrite-roomserver.db", m.StorageDirectory)) cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s/dendrite-roomserver.db", m.StorageDirectory))
cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s/dendrite-serverkey.db", m.StorageDirectory)) cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s/dendrite-serverkey.db", m.StorageDirectory))
cfg.Database.E2EKey = config.DataSource(fmt.Sprintf("file:%s/dendrite-e2ekey.db", m.StorageDirectory)) cfg.Database.E2EKey = config.DataSource(fmt.Sprintf("file:%s/dendrite-keyserver.db", m.StorageDirectory))
cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s/dendrite-federationsender.db", m.StorageDirectory)) cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s/dendrite-federationsender.db", m.StorageDirectory))
cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s/dendrite-appservice.db", m.StorageDirectory)) cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s/dendrite-appservice.db", m.StorageDirectory))
cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s/dendrite-currentstate.db", m.StorageDirectory)) cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s/dendrite-currentstate.db", m.StorageDirectory))

View file

@ -83,7 +83,7 @@ func main() {
cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName)) cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName))
cfg.Database.E2EKey = config.DataSource(fmt.Sprintf("file:%s-e2ekey.db", *instanceName)) cfg.Database.E2EKey = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName)) cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName))
cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s-currentstate.db", *instanceName)) cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s-currentstate.db", *instanceName))

View file

@ -50,6 +50,7 @@ type Node struct {
tlsConfig *tls.Config tlsConfig *tls.Config
quicConfig *quic.Config quicConfig *quic.Config
sessions sync.Map // string -> quic.Session sessions sync.Map // string -> quic.Session
coords sync.Map // string -> yggdrasil.Coords
incoming chan QUICStream incoming chan QUICStream
NewSession func(remote gomatrixserverlib.ServerName) NewSession func(remote gomatrixserverlib.ServerName)
} }

View file

@ -31,6 +31,7 @@ import (
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/yggdrasil-network/yggdrasil-go/src/crypto" "github.com/yggdrasil-network/yggdrasil-go/src/crypto"
"github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil"
) )
func (n *Node) listenFromYgg() { func (n *Node) listenFromYgg() {
@ -95,55 +96,101 @@ func (n *Node) Dial(network, address string) (net.Conn, error) {
} }
// Implements http.Transport.DialContext // Implements http.Transport.DialContext
// nolint:gocyclo
func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
s, ok1 := n.sessions.Load(address) s, ok1 := n.sessions.Load(address)
session, ok2 := s.(quic.Session) session, ok2 := s.(quic.Session)
if !ok1 || !ok2 || (ok1 && ok2 && session.ConnectionState().HandshakeComplete) { if !ok1 || !ok2 {
dest, err := hex.DecodeString(address) // First of all, check if we think we know the coords of this
if err != nil { // node. If we do then we'll try to dial to it directly. This
return nil, err // will either succeed or fail.
} if v, ok := n.coords.Load(address); ok {
if len(dest) != crypto.BoxPubKeyLen { coords, ok := v.(yggdrasil.Coords)
return nil, errors.New("invalid key length supplied") if !ok {
} return nil, errors.New("should have found yggdrasil.Coords but didn't")
var pubKey crypto.BoxPubKey }
copy(pubKey[:], dest) n.log.Infof("Coords %s for %q cached, trying to dial", coords.String(), address)
nodeID := crypto.GetNodeID(&pubKey) var err error
nodeMask := &crypto.NodeID{} // We think we know the coords. Try to dial the node.
for i := range nodeMask { if session, err = n.tryDial(address, coords); err != nil {
nodeMask[i] = 0xFF // We thought we knew the coords but it didn't result
// in a successful dial. Nuke them from the cache.
n.coords.Delete(address)
n.log.Infof("Cached coords %s for %q failed", coords.String(), address)
}
} }
fmt.Println("Resolving coords") // We either don't know the coords for the node, or we failed
coords, err := n.core.Resolve(nodeID, nodeMask) // to dial it before, in which case try to resolve the coords.
if err != nil { if _, ok := n.coords.Load(address); !ok {
return nil, fmt.Errorf("n.core.Resolve: %w", err) n.log.Infof("Searching for coords for %q", address)
} dest, err := hex.DecodeString(address)
fmt.Println("Found coords:", coords) if err != nil {
fmt.Println("Dialling") return nil, err
}
if len(dest) != crypto.BoxPubKeyLen {
return nil, errors.New("invalid key length supplied")
}
var pubKey crypto.BoxPubKey
copy(pubKey[:], dest)
nodeID := crypto.GetNodeID(&pubKey)
nodeMask := &crypto.NodeID{}
for i := range nodeMask {
nodeMask[i] = 0xFF
}
session, err = quic.Dial( fmt.Println("Resolving coords")
n.core, // yggdrasil.PacketConn coords, err := n.core.Resolve(nodeID, nodeMask)
coords, // dial address if err != nil {
address, // dial SNI return nil, fmt.Errorf("n.core.Resolve: %w", err)
n.tlsConfig, // TLS config }
n.quicConfig, // QUIC config fmt.Println("Found coords:", coords)
) n.coords.Store(address, coords)
if err != nil {
n.log.Println("n.dialer.DialContext:", err) // We now know the coords in theory. Let's try dialling the
return nil, err // node again.
if session, err = n.tryDial(address, coords); err != nil {
return nil, fmt.Errorf("n.tryDial: %w", err)
}
} }
fmt.Println("Dial OK")
go n.listenFromQUIC(session, address)
} }
if session == nil {
return nil, fmt.Errorf("should have found session but didn't")
}
st, err := session.OpenStream() st, err := session.OpenStream()
if err != nil { if err != nil {
n.log.Println("session.OpenStream:", err) n.log.Println("session.OpenStream:", err)
_ = session.CloseWithError(0, "expected to be able to open session")
return nil, err return nil, err
} }
return QUICStream{st, session}, nil return QUICStream{st, session}, nil
} }
func (n *Node) tryDial(address string, coords yggdrasil.Coords) (quic.Session, error) {
session, err := quic.Dial(
n.core, // yggdrasil.PacketConn
coords, // dial address
address, // dial SNI
n.tlsConfig, // TLS config
n.quicConfig, // QUIC config
)
if err != nil {
return nil, err
}
if len(session.ConnectionState().PeerCertificates) != 1 {
_ = session.CloseWithError(0, "expected a peer certificate")
return nil, errors.New("didn't receive a peer certificate")
}
if gotAddress := session.ConnectionState().PeerCertificates[0].DNSNames[0]; address != gotAddress {
_ = session.CloseWithError(0, "you aren't the host I was hoping for")
return nil, fmt.Errorf("expected %q but dialled %q", address, gotAddress)
}
go n.listenFromQUIC(session, address)
return session, nil
}
func (n *Node) generateTLSConfig() *tls.Config { func (n *Node) generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024) key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil { if err != nil {