mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
More reliable QUIC session handling
This commit is contained in:
parent
29708638d5
commit
a662defa8c
|
|
@ -98,7 +98,7 @@ func (m *DendriteMonolith) Start() {
|
||||||
cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s/dendrite-syncapi.db", m.StorageDirectory))
|
cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s/dendrite-syncapi.db", m.StorageDirectory))
|
||||||
cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s/dendrite-roomserver.db", m.StorageDirectory))
|
cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s/dendrite-roomserver.db", m.StorageDirectory))
|
||||||
cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s/dendrite-serverkey.db", m.StorageDirectory))
|
cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s/dendrite-serverkey.db", m.StorageDirectory))
|
||||||
cfg.Database.E2EKey = config.DataSource(fmt.Sprintf("file:%s/dendrite-e2ekey.db", m.StorageDirectory))
|
cfg.Database.E2EKey = config.DataSource(fmt.Sprintf("file:%s/dendrite-keyserver.db", m.StorageDirectory))
|
||||||
cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s/dendrite-federationsender.db", m.StorageDirectory))
|
cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s/dendrite-federationsender.db", m.StorageDirectory))
|
||||||
cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s/dendrite-appservice.db", m.StorageDirectory))
|
cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s/dendrite-appservice.db", m.StorageDirectory))
|
||||||
cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s/dendrite-currentstate.db", m.StorageDirectory))
|
cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s/dendrite-currentstate.db", m.StorageDirectory))
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ func main() {
|
||||||
cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
cfg.Database.SyncAPI = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
||||||
cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
cfg.Database.RoomServer = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||||
cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName))
|
cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName))
|
||||||
cfg.Database.E2EKey = config.DataSource(fmt.Sprintf("file:%s-e2ekey.db", *instanceName))
|
cfg.Database.E2EKey = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
|
||||||
cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName))
|
cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName))
|
||||||
cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
|
cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
|
||||||
cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s-currentstate.db", *instanceName))
|
cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s-currentstate.db", *instanceName))
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ type Node struct {
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
quicConfig *quic.Config
|
quicConfig *quic.Config
|
||||||
sessions sync.Map // string -> quic.Session
|
sessions sync.Map // string -> quic.Session
|
||||||
|
coords sync.Map // string -> yggdrasil.Coords
|
||||||
incoming chan QUICStream
|
incoming chan QUICStream
|
||||||
NewSession func(remote gomatrixserverlib.ServerName)
|
NewSession func(remote gomatrixserverlib.ServerName)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ import (
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/yggdrasil-network/yggdrasil-go/src/crypto"
|
"github.com/yggdrasil-network/yggdrasil-go/src/crypto"
|
||||||
|
"github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (n *Node) listenFromYgg() {
|
func (n *Node) listenFromYgg() {
|
||||||
|
|
@ -95,55 +96,101 @@ func (n *Node) Dial(network, address string) (net.Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implements http.Transport.DialContext
|
// Implements http.Transport.DialContext
|
||||||
|
// nolint:gocyclo
|
||||||
func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
s, ok1 := n.sessions.Load(address)
|
s, ok1 := n.sessions.Load(address)
|
||||||
session, ok2 := s.(quic.Session)
|
session, ok2 := s.(quic.Session)
|
||||||
if !ok1 || !ok2 || (ok1 && ok2 && session.ConnectionState().HandshakeComplete) {
|
if !ok1 || !ok2 {
|
||||||
dest, err := hex.DecodeString(address)
|
// First of all, check if we think we know the coords of this
|
||||||
if err != nil {
|
// node. If we do then we'll try to dial to it directly. This
|
||||||
return nil, err
|
// will either succeed or fail.
|
||||||
}
|
if v, ok := n.coords.Load(address); ok {
|
||||||
if len(dest) != crypto.BoxPubKeyLen {
|
coords, ok := v.(yggdrasil.Coords)
|
||||||
return nil, errors.New("invalid key length supplied")
|
if !ok {
|
||||||
}
|
return nil, errors.New("should have found yggdrasil.Coords but didn't")
|
||||||
var pubKey crypto.BoxPubKey
|
}
|
||||||
copy(pubKey[:], dest)
|
n.log.Infof("Coords %s for %q cached, trying to dial", coords.String(), address)
|
||||||
nodeID := crypto.GetNodeID(&pubKey)
|
var err error
|
||||||
nodeMask := &crypto.NodeID{}
|
// We think we know the coords. Try to dial the node.
|
||||||
for i := range nodeMask {
|
if session, err = n.tryDial(address, coords); err != nil {
|
||||||
nodeMask[i] = 0xFF
|
// We thought we knew the coords but it didn't result
|
||||||
|
// in a successful dial. Nuke them from the cache.
|
||||||
|
n.coords.Delete(address)
|
||||||
|
n.log.Infof("Cached coords %s for %q failed", coords.String(), address)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("Resolving coords")
|
// We either don't know the coords for the node, or we failed
|
||||||
coords, err := n.core.Resolve(nodeID, nodeMask)
|
// to dial it before, in which case try to resolve the coords.
|
||||||
if err != nil {
|
if _, ok := n.coords.Load(address); !ok {
|
||||||
return nil, fmt.Errorf("n.core.Resolve: %w", err)
|
n.log.Infof("Searching for coords for %q", address)
|
||||||
}
|
dest, err := hex.DecodeString(address)
|
||||||
fmt.Println("Found coords:", coords)
|
if err != nil {
|
||||||
fmt.Println("Dialling")
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(dest) != crypto.BoxPubKeyLen {
|
||||||
|
return nil, errors.New("invalid key length supplied")
|
||||||
|
}
|
||||||
|
var pubKey crypto.BoxPubKey
|
||||||
|
copy(pubKey[:], dest)
|
||||||
|
nodeID := crypto.GetNodeID(&pubKey)
|
||||||
|
nodeMask := &crypto.NodeID{}
|
||||||
|
for i := range nodeMask {
|
||||||
|
nodeMask[i] = 0xFF
|
||||||
|
}
|
||||||
|
|
||||||
session, err = quic.Dial(
|
fmt.Println("Resolving coords")
|
||||||
n.core, // yggdrasil.PacketConn
|
coords, err := n.core.Resolve(nodeID, nodeMask)
|
||||||
coords, // dial address
|
if err != nil {
|
||||||
address, // dial SNI
|
return nil, fmt.Errorf("n.core.Resolve: %w", err)
|
||||||
n.tlsConfig, // TLS config
|
}
|
||||||
n.quicConfig, // QUIC config
|
fmt.Println("Found coords:", coords)
|
||||||
)
|
n.coords.Store(address, coords)
|
||||||
if err != nil {
|
|
||||||
n.log.Println("n.dialer.DialContext:", err)
|
// We now know the coords in theory. Let's try dialling the
|
||||||
return nil, err
|
// node again.
|
||||||
|
if session, err = n.tryDial(address, coords); err != nil {
|
||||||
|
return nil, fmt.Errorf("n.tryDial: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
fmt.Println("Dial OK")
|
|
||||||
go n.listenFromQUIC(session, address)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if session == nil {
|
||||||
|
return nil, fmt.Errorf("should have found session but didn't")
|
||||||
|
}
|
||||||
|
|
||||||
st, err := session.OpenStream()
|
st, err := session.OpenStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
n.log.Println("session.OpenStream:", err)
|
n.log.Println("session.OpenStream:", err)
|
||||||
|
_ = session.CloseWithError(0, "expected to be able to open session")
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return QUICStream{st, session}, nil
|
return QUICStream{st, session}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Node) tryDial(address string, coords yggdrasil.Coords) (quic.Session, error) {
|
||||||
|
session, err := quic.Dial(
|
||||||
|
n.core, // yggdrasil.PacketConn
|
||||||
|
coords, // dial address
|
||||||
|
address, // dial SNI
|
||||||
|
n.tlsConfig, // TLS config
|
||||||
|
n.quicConfig, // QUIC config
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(session.ConnectionState().PeerCertificates) != 1 {
|
||||||
|
_ = session.CloseWithError(0, "expected a peer certificate")
|
||||||
|
return nil, errors.New("didn't receive a peer certificate")
|
||||||
|
}
|
||||||
|
if gotAddress := session.ConnectionState().PeerCertificates[0].DNSNames[0]; address != gotAddress {
|
||||||
|
_ = session.CloseWithError(0, "you aren't the host I was hoping for")
|
||||||
|
return nil, fmt.Errorf("expected %q but dialled %q", address, gotAddress)
|
||||||
|
}
|
||||||
|
go n.listenFromQUIC(session, address)
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (n *Node) generateTLSConfig() *tls.Config {
|
func (n *Node) generateTLSConfig() *tls.Config {
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue