mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Merge branch 'master' into neilalexander/config
This commit is contained in:
commit
4f7eb37792
|
|
@ -25,7 +25,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type DendriteMonolith struct {
|
||||
|
|
@ -34,8 +33,6 @@ type DendriteMonolith struct {
|
|||
StorageDirectory string
|
||||
listener net.Listener
|
||||
httpServer *http.Server
|
||||
httpListening atomic.Bool
|
||||
yggListening atomic.Bool
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) BaseURL() string {
|
||||
|
|
@ -46,6 +43,10 @@ func (m *DendriteMonolith) PeerCount() int {
|
|||
return m.YggdrasilNode.PeerCount()
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) SessionCount() int {
|
||||
return m.YggdrasilNode.SessionCount()
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) {
|
||||
m.YggdrasilNode.SetMulticastEnabled(enabled)
|
||||
}
|
||||
|
|
@ -98,13 +99,13 @@ func (m *DendriteMonolith) Start() {
|
|||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-syncapi.db", m.StorageDirectory))
|
||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-roomserver.db", m.StorageDirectory))
|
||||
cfg.ServerKeyAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-serverkey.db", m.StorageDirectory))
|
||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-e2ekey.db", m.StorageDirectory))
|
||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-keyserver.db", m.StorageDirectory))
|
||||
cfg.FederationSender.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-federationsender.db", m.StorageDirectory))
|
||||
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-appservice.db", m.StorageDirectory))
|
||||
cfg.CurrentStateServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-currentstate.db", m.StorageDirectory))
|
||||
cfg.MediaAPI.BasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
|
||||
cfg.MediaAPI.AbsBasePath = config.Path(fmt.Sprintf("%s/tmp", m.StorageDirectory))
|
||||
cfg.FederationSender.FederationMaxRetries = 6
|
||||
cfg.FederationSender.FederationMaxRetries = 8
|
||||
if err = cfg.Derive(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
@ -136,6 +137,18 @@ func (m *DendriteMonolith) Start() {
|
|||
base, federation, rsAPI, stateAPI, keyRing,
|
||||
)
|
||||
|
||||
ygg.SetSessionFunc(func(address string) {
|
||||
req := &api.PerformServersAliveRequest{
|
||||
Servers: []gomatrixserverlib.ServerName{
|
||||
gomatrixserverlib.ServerName(address),
|
||||
},
|
||||
}
|
||||
res := &api.PerformServersAliveResponse{}
|
||||
if err := fsAPI.PerformServersAlive(context.TODO(), req, res); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send wake-up message to newly connected node")
|
||||
}
|
||||
})
|
||||
|
||||
// The underlying roomserver implementation needs to be able to call the fedsender.
|
||||
// This is different to rsAPI which can be the http client which doesn't need this dependency
|
||||
rsAPI.SetFederationSenderAPI(fsAPI)
|
||||
|
|
@ -175,9 +188,9 @@ func (m *DendriteMonolith) Start() {
|
|||
m.httpServer = &http.Server{
|
||||
Addr: ":0",
|
||||
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){},
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
IdleTimeout: 30 * time.Second,
|
||||
BaseContext: func(_ net.Listener) context.Context {
|
||||
return context.Background()
|
||||
},
|
||||
|
|
|
|||
|
|
@ -78,19 +78,18 @@ func main() {
|
|||
cfg.Global.Kafka.Topics.OutputTypingEvent = "typingServerOutput"
|
||||
cfg.Global.Kafka.Topics.OutputSendToDeviceEvent = "sendToDeviceOutput"
|
||||
cfg.Global.Kafka.Topics.OutputKeyChangeEvent = "keyChangeOutput"
|
||||
cfg.FederationSender.FederationMaxRetries = 6
|
||||
cfg.FederationSender.FederationMaxRetries = 8
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
|
||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||
cfg.ServerKeyAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName))
|
||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-e2ekey.db", *instanceName))
|
||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", *instanceName))
|
||||
cfg.FederationSender.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName))
|
||||
cfg.AppServiceAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName))
|
||||
cfg.CurrentStateServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-currentstate.db", *instanceName))
|
||||
cfg.Global.Kafka.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName))
|
||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-e2ekey.db", *instanceName))
|
||||
if err = cfg.Derive(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
@ -124,6 +123,18 @@ func main() {
|
|||
base, federation, rsAPI, stateAPI, keyRing,
|
||||
)
|
||||
|
||||
ygg.SetSessionFunc(func(address string) {
|
||||
req := &api.PerformServersAliveRequest{
|
||||
Servers: []gomatrixserverlib.ServerName{
|
||||
gomatrixserverlib.ServerName(address),
|
||||
},
|
||||
}
|
||||
res := &api.PerformServersAliveResponse{}
|
||||
if err := fsAPI.PerformServersAlive(context.TODO(), req, res); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send wake-up message to newly connected node")
|
||||
}
|
||||
})
|
||||
|
||||
rsComponent.SetFederationSenderAPI(fsAPI)
|
||||
|
||||
embed.Embed(base.BaseMux, *instancePort, "Yggdrasil Demo")
|
||||
|
|
@ -163,9 +174,9 @@ func main() {
|
|||
httpServer := &http.Server{
|
||||
Addr: ":0",
|
||||
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){},
|
||||
ReadTimeout: 15 * time.Second,
|
||||
WriteTimeout: 45 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
IdleTimeout: 30 * time.Second,
|
||||
BaseContext: func(_ net.Listener) context.Context {
|
||||
return context.Background()
|
||||
},
|
||||
|
|
|
|||
|
|
@ -24,9 +24,11 @@ func (n *Node) CreateClient(
|
|||
tr.RegisterProtocol(
|
||||
"matrix", &yggroundtripper{
|
||||
inner: &http.Transport{
|
||||
TLSHandshakeTimeout: 20 * time.Second,
|
||||
MaxIdleConns: -1,
|
||||
MaxIdleConnsPerHost: -1,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
IdleConnTimeout: 60 * time.Second,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DialContext: n.DialerContext,
|
||||
},
|
||||
},
|
||||
|
|
@ -41,9 +43,11 @@ func (n *Node) CreateFederationClient(
|
|||
tr.RegisterProtocol(
|
||||
"matrix", &yggroundtripper{
|
||||
inner: &http.Transport{
|
||||
TLSHandshakeTimeout: 20 * time.Second,
|
||||
MaxIdleConns: -1,
|
||||
MaxIdleConnsPerHost: -1,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
IdleConnTimeout: 60 * time.Second,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DialContext: n.DialerContext,
|
||||
TLSClientConfig: n.tlsConfig,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/convert"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
yggdrasilconfig "github.com/yggdrasil-network/yggdrasil-go/src/config"
|
||||
yggdrasilmulticast "github.com/yggdrasil-network/yggdrasil-go/src/multicast"
|
||||
|
|
@ -41,17 +42,20 @@ import (
|
|||
)
|
||||
|
||||
type Node struct {
|
||||
core *yggdrasil.Core
|
||||
config *yggdrasilconfig.NodeConfig
|
||||
state *yggdrasilconfig.NodeState
|
||||
multicast *yggdrasilmulticast.Multicast
|
||||
log *gologme.Logger
|
||||
listener quic.Listener
|
||||
tlsConfig *tls.Config
|
||||
quicConfig *quic.Config
|
||||
sessions sync.Map // string -> quic.Session
|
||||
incoming chan QUICStream
|
||||
NewSession func(remote gomatrixserverlib.ServerName)
|
||||
core *yggdrasil.Core
|
||||
config *yggdrasilconfig.NodeConfig
|
||||
state *yggdrasilconfig.NodeState
|
||||
multicast *yggdrasilmulticast.Multicast
|
||||
log *gologme.Logger
|
||||
listener quic.Listener
|
||||
tlsConfig *tls.Config
|
||||
quicConfig *quic.Config
|
||||
sessions sync.Map // string -> *session
|
||||
sessionCount atomic.Uint32
|
||||
sessionFunc func(address string)
|
||||
coords sync.Map // string -> yggdrasil.Coords
|
||||
incoming chan QUICStream
|
||||
NewSession func(remote gomatrixserverlib.ServerName)
|
||||
}
|
||||
|
||||
func (n *Node) Dialer(_, address string) (net.Conn, error) {
|
||||
|
|
@ -90,6 +94,19 @@ func Setup(instanceName, storageDirectory string) (*Node, error) {
|
|||
}
|
||||
}
|
||||
|
||||
n.core.SetCoordChangeCallback(func(old, new yggdrasil.Coords) {
|
||||
fmt.Println("COORDINATE CHANGE!")
|
||||
fmt.Println("Old:", old)
|
||||
fmt.Println("New:", new)
|
||||
n.sessions.Range(func(k, v interface{}) bool {
|
||||
if s, ok := v.(*session); ok {
|
||||
fmt.Println("Killing session", k)
|
||||
s.kill()
|
||||
}
|
||||
return true
|
||||
})
|
||||
})
|
||||
|
||||
n.config.Peers = []string{}
|
||||
n.config.AdminListen = "none"
|
||||
n.config.MulticastInterfaces = []string{}
|
||||
|
|
@ -124,8 +141,9 @@ func Setup(instanceName, storageDirectory string) (*Node, error) {
|
|||
MaxIncomingUniStreams: 0,
|
||||
KeepAlive: true,
|
||||
MaxIdleTimeout: time.Minute * 30,
|
||||
HandshakeTimeout: time.Second * 30,
|
||||
HandshakeTimeout: time.Second * 15,
|
||||
}
|
||||
copy(n.quicConfig.StatelessResetKey, n.EncryptionPublicKey())
|
||||
|
||||
n.log.Println("Public curve25519:", n.core.EncryptionPublicKey())
|
||||
n.log.Println("Public ed25519:", n.core.SigningPublicKey())
|
||||
|
|
@ -173,17 +191,25 @@ func (n *Node) SigningPrivateKey() ed25519.PrivateKey {
|
|||
return ed25519.PrivateKey(privBytes)
|
||||
}
|
||||
|
||||
func (n *Node) SetSessionFunc(f func(address string)) {
|
||||
n.sessionFunc = f
|
||||
}
|
||||
|
||||
func (n *Node) PeerCount() int {
|
||||
return len(n.core.GetPeers()) - 1
|
||||
}
|
||||
|
||||
func (n *Node) SessionCount() int {
|
||||
return int(n.sessionCount.Load())
|
||||
}
|
||||
|
||||
func (n *Node) KnownNodes() []gomatrixserverlib.ServerName {
|
||||
nodemap := map[string]struct{}{
|
||||
"b5ae50589e50991dd9dd7d59c5c5f7a4521e8da5b603b7f57076272abc58b374": struct{}{},
|
||||
"b5ae50589e50991dd9dd7d59c5c5f7a4521e8da5b603b7f57076272abc58b374": {},
|
||||
}
|
||||
/*
|
||||
for _, peer := range n.core.GetSwitchPeers() {
|
||||
nodemap[hex.EncodeToString(peer.SigningKey[:])] = struct{}{}
|
||||
nodemap[hex.EncodeToString(peer.PublicKey[:])] = struct{}{}
|
||||
}
|
||||
*/
|
||||
n.sessions.Range(func(_, v interface{}) bool {
|
||||
|
|
|
|||
|
|
@ -31,8 +31,32 @@ import (
|
|||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/yggdrasil-network/yggdrasil-go/src/crypto"
|
||||
"github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil"
|
||||
)
|
||||
|
||||
type session struct {
|
||||
node *Node
|
||||
session quic.Session
|
||||
address string
|
||||
context context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (n *Node) newSession(sess quic.Session, address string) *session {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
return &session{
|
||||
node: n,
|
||||
session: sess,
|
||||
address: address,
|
||||
context: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) kill() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
func (n *Node) listenFromYgg() {
|
||||
var err error
|
||||
n.listener, err = quic.Listen(
|
||||
|
|
@ -55,22 +79,31 @@ func (n *Node) listenFromYgg() {
|
|||
_ = session.CloseWithError(0, "expected a peer certificate")
|
||||
continue
|
||||
}
|
||||
address := session.ConnectionState().PeerCertificates[0].Subject.CommonName
|
||||
address := session.ConnectionState().PeerCertificates[0].DNSNames[0]
|
||||
n.log.Infoln("Accepted connection from", address)
|
||||
go n.listenFromQUIC(session, address)
|
||||
go n.newSession(session, address).listenFromQUIC()
|
||||
go n.sessionFunc(address)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Node) listenFromQUIC(session quic.Session, address string) {
|
||||
n.sessions.Store(address, session)
|
||||
defer n.sessions.Delete(address)
|
||||
func (s *session) listenFromQUIC() {
|
||||
if existing, ok := s.node.sessions.Load(s.address); ok {
|
||||
if existingSession, ok := existing.(*session); ok {
|
||||
fmt.Println("Killing existing session to replace", s.address)
|
||||
existingSession.kill()
|
||||
}
|
||||
}
|
||||
s.node.sessionCount.Inc()
|
||||
s.node.sessions.Store(s.address, s)
|
||||
defer s.node.sessions.Delete(s.address)
|
||||
defer s.node.sessionCount.Dec()
|
||||
for {
|
||||
st, err := session.AcceptStream(context.TODO())
|
||||
st, err := s.session.AcceptStream(s.context)
|
||||
if err != nil {
|
||||
n.log.Println("session.AcceptStream:", err)
|
||||
s.node.log.Println("session.AcceptStream:", err)
|
||||
return
|
||||
}
|
||||
n.incoming <- QUICStream{st, session}
|
||||
s.node.incoming <- QUICStream{st, s.session}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -95,53 +128,124 @@ func (n *Node) Dial(network, address string) (net.Conn, error) {
|
|||
}
|
||||
|
||||
// Implements http.Transport.DialContext
|
||||
// nolint:gocyclo
|
||||
func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
s, ok1 := n.sessions.Load(address)
|
||||
session, ok2 := s.(quic.Session)
|
||||
if !ok1 || !ok2 || (ok1 && ok2 && session.ConnectionState().HandshakeComplete) {
|
||||
dest, err := hex.DecodeString(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dest) != crypto.BoxPubKeyLen {
|
||||
return nil, errors.New("invalid key length supplied")
|
||||
}
|
||||
var pubKey crypto.BoxPubKey
|
||||
copy(pubKey[:], dest)
|
||||
nodeID := crypto.GetNodeID(&pubKey)
|
||||
nodeMask := &crypto.NodeID{}
|
||||
for i := range nodeMask {
|
||||
nodeMask[i] = 0xFF
|
||||
session, ok2 := s.(*session)
|
||||
if !ok1 || !ok2 {
|
||||
// First of all, check if we think we know the coords of this
|
||||
// node. If we do then we'll try to dial to it directly. This
|
||||
// will either succeed or fail.
|
||||
if v, ok := n.coords.Load(address); ok {
|
||||
coords, ok := v.(yggdrasil.Coords)
|
||||
if !ok {
|
||||
n.coords.Delete(address)
|
||||
return nil, errors.New("should have found yggdrasil.Coords but didn't")
|
||||
}
|
||||
n.log.Infof("Coords %s for %q cached, trying to dial", coords.String(), address)
|
||||
var err error
|
||||
// We think we know the coords. Try to dial the node.
|
||||
if session, err = n.tryDial(address, coords); err != nil {
|
||||
// We thought we knew the coords but it didn't result
|
||||
// in a successful dial. Nuke them from the cache.
|
||||
n.coords.Delete(address)
|
||||
n.log.Infof("Cached coords %s for %q failed", coords.String(), address)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Resolving coords")
|
||||
coords, err := n.core.Resolve(nodeID, nodeMask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("n.core.Resolve: %w", err)
|
||||
}
|
||||
fmt.Println("Found coords:", coords)
|
||||
fmt.Println("Dialling")
|
||||
// We either don't know the coords for the node, or we failed
|
||||
// to dial it before, in which case try to resolve the coords.
|
||||
if _, ok := n.coords.Load(address); !ok {
|
||||
var coords yggdrasil.Coords
|
||||
var err error
|
||||
|
||||
session, err = quic.Dial(
|
||||
n.core, // yggdrasil.PacketConn
|
||||
coords, // dial address
|
||||
address, // dial SNI
|
||||
n.tlsConfig, // TLS config
|
||||
n.quicConfig, // QUIC config
|
||||
)
|
||||
if err != nil {
|
||||
n.log.Println("n.dialer.DialContext:", err)
|
||||
return nil, err
|
||||
// First look and see if the node is something that we already
|
||||
// know about from our direct switch peers.
|
||||
for _, peer := range n.core.GetSwitchPeers() {
|
||||
if peer.PublicKey.String() == address {
|
||||
coords = peer.Coords
|
||||
n.log.Infof("%q is a direct peer, coords are %s", address, coords.String())
|
||||
n.coords.Store(address, coords)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If it isn' a node that we know directly then try to search
|
||||
// the network.
|
||||
if coords == nil {
|
||||
n.log.Infof("Searching for coords for %q", address)
|
||||
dest, derr := hex.DecodeString(address)
|
||||
if derr != nil {
|
||||
return nil, derr
|
||||
}
|
||||
if len(dest) != crypto.BoxPubKeyLen {
|
||||
return nil, errors.New("invalid key length supplied")
|
||||
}
|
||||
var pubKey crypto.BoxPubKey
|
||||
copy(pubKey[:], dest)
|
||||
nodeID := crypto.GetNodeID(&pubKey)
|
||||
nodeMask := &crypto.NodeID{}
|
||||
for i := range nodeMask {
|
||||
nodeMask[i] = 0xFF
|
||||
}
|
||||
|
||||
fmt.Println("Resolving coords")
|
||||
coords, err = n.core.Resolve(nodeID, nodeMask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("n.core.Resolve: %w", err)
|
||||
}
|
||||
fmt.Println("Found coords:", coords)
|
||||
n.coords.Store(address, coords)
|
||||
}
|
||||
|
||||
// We now know the coords in theory. Let's try dialling the
|
||||
// node again.
|
||||
if session, err = n.tryDial(address, coords); err != nil {
|
||||
return nil, fmt.Errorf("n.tryDial: %w", err)
|
||||
}
|
||||
}
|
||||
fmt.Println("Dial OK")
|
||||
go n.listenFromQUIC(session, address)
|
||||
}
|
||||
st, err := session.OpenStream()
|
||||
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("should have found session but didn't")
|
||||
}
|
||||
|
||||
st, err := session.session.OpenStream()
|
||||
if err != nil {
|
||||
n.log.Println("session.OpenStream:", err)
|
||||
_ = session.session.CloseWithError(0, "expected to be able to open session")
|
||||
return nil, err
|
||||
}
|
||||
return QUICStream{st, session}, nil
|
||||
return QUICStream{st, session.session}, nil
|
||||
}
|
||||
|
||||
func (n *Node) tryDial(address string, coords yggdrasil.Coords) (*session, error) {
|
||||
quicSession, err := quic.Dial(
|
||||
n.core, // yggdrasil.PacketConn
|
||||
coords, // dial address
|
||||
address, // dial SNI
|
||||
n.tlsConfig, // TLS config
|
||||
n.quicConfig, // QUIC config
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(quicSession.ConnectionState().PeerCertificates) != 1 {
|
||||
_ = quicSession.CloseWithError(0, "expected a peer certificate")
|
||||
return nil, errors.New("didn't receive a peer certificate")
|
||||
}
|
||||
if len(quicSession.ConnectionState().PeerCertificates[0].DNSNames) != 1 {
|
||||
_ = quicSession.CloseWithError(0, "expected a DNS name")
|
||||
return nil, errors.New("didn't receive a DNS name")
|
||||
}
|
||||
if gotAddress := quicSession.ConnectionState().PeerCertificates[0].DNSNames[0]; address != gotAddress {
|
||||
_ = quicSession.CloseWithError(0, "you aren't the host I was hoping for")
|
||||
return nil, fmt.Errorf("expected %q but dialled %q", address, gotAddress)
|
||||
}
|
||||
session := n.newSession(quicSession, address)
|
||||
go session.listenFromQUIC()
|
||||
go n.sessionFunc(address)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (n *Node) generateTLSConfig() *tls.Config {
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ func Setup(
|
|||
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
|
||||
return Send(
|
||||
httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]),
|
||||
cfg, rsAPI, eduAPI, keys, federation,
|
||||
cfg, rsAPI, eduAPI, keyAPI, keys, federation,
|
||||
)
|
||||
},
|
||||
)).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/internal/config"
|
||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
|
|
@ -37,6 +38,7 @@ func Send(
|
|||
cfg *config.FederationAPI,
|
||||
rsAPI api.RoomserverInternalAPI,
|
||||
eduAPI eduserverAPI.EDUServerInputAPI,
|
||||
keyAPI keyapi.KeyInternalAPI,
|
||||
keys gomatrixserverlib.JSONVerifier,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
) util.JSONResponse {
|
||||
|
|
@ -48,6 +50,7 @@ func Send(
|
|||
federation: federation,
|
||||
haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent),
|
||||
newEvents: make(map[string]bool),
|
||||
keyAPI: keyAPI,
|
||||
}
|
||||
|
||||
var txnEvents struct {
|
||||
|
|
@ -100,6 +103,7 @@ type txnReq struct {
|
|||
context context.Context
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
eduAPI eduserverAPI.EDUServerInputAPI
|
||||
keyAPI keyapi.KeyInternalAPI
|
||||
keys gomatrixserverlib.JSONVerifier
|
||||
federation txnFederationClient
|
||||
// local cache of events for auth checks, etc - this may include events
|
||||
|
|
@ -308,12 +312,29 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) {
|
|||
}
|
||||
}
|
||||
}
|
||||
case gomatrixserverlib.MDeviceListUpdate:
|
||||
t.processDeviceListUpdate(e)
|
||||
default:
|
||||
util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *txnReq) processDeviceListUpdate(e gomatrixserverlib.EDU) {
|
||||
var payload gomatrixserverlib.DeviceListUpdateEvent
|
||||
if err := json.Unmarshal(e.Content, &payload); err != nil {
|
||||
util.GetLogger(t.context).WithError(err).Error("Failed to unmarshal device list update event")
|
||||
return
|
||||
}
|
||||
var inputRes keyapi.InputDeviceListUpdateResponse
|
||||
t.keyAPI.InputDeviceListUpdate(context.Background(), &keyapi.InputDeviceListUpdateRequest{
|
||||
Event: payload,
|
||||
}, &inputRes)
|
||||
if inputRes.Error != nil {
|
||||
util.GetLogger(t.context).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) error {
|
||||
prevEventIDs := e.PrevEventIDs()
|
||||
|
||||
|
|
|
|||
|
|
@ -319,6 +319,11 @@ func (r *FederationSenderInternalAPI) PerformBroadcastEDU(
|
|||
if err != nil {
|
||||
return fmt.Errorf("r.db.GetAllJoinedHosts: %w", err)
|
||||
}
|
||||
if len(destinations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logrus.WithContext(ctx).Infof("Sending wake-up EDU to %d destination(s)", len(destinations))
|
||||
|
||||
edu := &gomatrixserverlib.EDU{
|
||||
Type: "org.matrix.dendrite.wakeup",
|
||||
|
|
@ -328,5 +333,13 @@ func (r *FederationSenderInternalAPI) PerformBroadcastEDU(
|
|||
return fmt.Errorf("r.queues.SendEDU: %w", err)
|
||||
}
|
||||
|
||||
wakeReq := &api.PerformServersAliveRequest{
|
||||
Servers: destinations,
|
||||
}
|
||||
wakeRes := &api.PerformServersAliveResponse{}
|
||||
if err := r.PerformServersAlive(ctx, wakeReq, wakeRes); err != nil {
|
||||
return fmt.Errorf("r.PerformServersAlive: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
6
go.mod
6
go.mod
|
|
@ -18,12 +18,12 @@ require (
|
|||
github.com/libp2p/go-libp2p-pubsub v0.2.5
|
||||
github.com/libp2p/go-libp2p-record v0.1.2
|
||||
github.com/libp2p/go-yamux v1.3.7 // indirect
|
||||
github.com/lucas-clemente/quic-go v0.17.2
|
||||
github.com/lucas-clemente/quic-go v0.17.3
|
||||
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5
|
||||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200804124807-5012a626de1d
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200807122736-eb1a0b991914
|
||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
|
||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
|
||||
github.com/mattn/go-sqlite3 v2.0.2+incompatible
|
||||
|
|
@ -38,7 +38,7 @@ require (
|
|||
github.com/uber-go/atomic v1.3.0 // indirect
|
||||
github.com/uber/jaeger-client-go v2.15.0+incompatible
|
||||
github.com/uber/jaeger-lib v1.5.0
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200715104113-1046b00c3be3
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200806125501-cd4685a3b4de
|
||||
go.uber.org/atomic v1.4.0
|
||||
golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5
|
||||
gopkg.in/h2non/bimg.v1 v1.0.18
|
||||
|
|
|
|||
22
go.sum
22
go.sum
|
|
@ -403,8 +403,8 @@ github.com/libp2p/go-yamux v1.3.0 h1:FsYzT16Wq2XqUGJsBbOxoz9g+dFklvNi7jN6YFPfl7U
|
|||
github.com/libp2p/go-yamux v1.3.0/go.mod h1:FGTiPvoV/3DVdgWpX+tM0OW3tsM+W5bSE3gZwqQTcow=
|
||||
github.com/libp2p/go-yamux v1.3.7 h1:v40A1eSPJDIZwz2AvrV3cxpTZEGDP11QJbukmEhYyQI=
|
||||
github.com/libp2p/go-yamux v1.3.7/go.mod h1:fr7aVgmdNGJK+N1g+b6DW6VxzbRCjCOejR/hkmpooHE=
|
||||
github.com/lucas-clemente/quic-go v0.17.2 h1:4iQInIuNQkPNZmsy9rCnwuOzpH0qGnDo4jn0QfI/qE4=
|
||||
github.com/lucas-clemente/quic-go v0.17.2/go.mod h1:I0+fcNTdb9eS1ZcjQZbDVPGchJ86chcIxPALn9lEJqE=
|
||||
github.com/lucas-clemente/quic-go v0.17.3 h1:jMX/MmDNCljfisgMmPGUcBJ+zUh9w3d3ia4YJjYS3TM=
|
||||
github.com/lucas-clemente/quic-go v0.17.3/go.mod h1:I0+fcNTdb9eS1ZcjQZbDVPGchJ86chcIxPALn9lEJqE=
|
||||
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
|
||||
github.com/lxn/walk v0.0.0-20191128110447-55ccb3a9f5c1/go.mod h1:E23UucZGqpuUANJooIbHWCufXvOcT6E7Stq81gU+CSQ=
|
||||
github.com/lxn/win v0.0.0-20191128105842-2da648fda5b4/go.mod h1:ouWl4wViUNh8tPSIwxTVMuS014WakR1hqvBc2I0bMoA=
|
||||
|
|
@ -425,18 +425,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 h1:Yb+Wlf
|
|||
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4=
|
||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b h1:ul/Jc5q5+QBHNvhd9idfglOwyGf/Tc3ittINEbKJPsQ=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200721145051-cea6eafced2b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d h1:WZXyd8YI+PQIDYjN8HxtqNRJ1DCckt9wPTi2P8cdnKM=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200722124340-16fba816840d/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200803165250-352235625587 h1:n2IZkm5LI4lACulOa5WU6QwWUhHUtBZez7YIFr1fCOs=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200803165250-352235625587/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200803165739-3bd1ef0f0852 h1:OBvHjLWaT2KS9kGarX2ES0yKBL/wMxAeQB39tRrAAls=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200803165739-3bd1ef0f0852/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200804110046-7abbc2918807 h1:ufr+e2FBDuxcO5t/7PMfoiQoma4uyYzS/sLuJSR6tng=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200804110046-7abbc2918807/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200804124807-5012a626de1d h1:zYk/bQ5bmHDsRqHBl57aBxo5bizsknWU3sunZf9WnWI=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200804124807-5012a626de1d/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200807122736-eb1a0b991914 h1:VSGCvSUB1/Y32F/JSjmTaIW9jr1BmBHEd0ok4AaT/lo=
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200807122736-eb1a0b991914/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
|
||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
|
||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
||||
|
|
@ -666,8 +656,8 @@ github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhe
|
|||
github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y=
|
||||
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
|
||||
github.com/yggdrasil-network/yggdrasil-extras v0.0.0-20200525205615-6c8a4a2e8855/go.mod h1:xQdsh08Io6nV4WRnOVTe6gI8/2iTvfLDQ0CYa5aMt+I=
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200715104113-1046b00c3be3 h1:teLoIJgPHysREs8P6GlcS/PgaU9W9+GQndikFCQ1lY0=
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200715104113-1046b00c3be3/go.mod h1:d+Nz6SPeG6kmeSPFL0cvfWfgwEql75fUnZiAONgvyBE=
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200806125501-cd4685a3b4de h1:p91aw0Mvol825U+5bvV9BBPl+HQxIczj7wxIOxZs70M=
|
||||
github.com/yggdrasil-network/yggdrasil-go v0.3.15-0.20200806125501-cd4685a3b4de/go.mod h1:d+Nz6SPeG6kmeSPFL0cvfWfgwEql75fUnZiAONgvyBE=
|
||||
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
|
||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||
go.opencensus.io v0.22.1/go.mod h1:Ap50jQcDJrx6rB6VgeeFPtuPIf3wMRvRfrfYDO6+BmA=
|
||||
|
|
|
|||
|
|
@ -21,11 +21,14 @@ import (
|
|||
"time"
|
||||
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type KeyInternalAPI interface {
|
||||
// SetUserAPI assigns a user API to query when extracting device names.
|
||||
SetUserAPI(i userapi.UserInternalAPI)
|
||||
// InputDeviceListUpdate from a federated server EDU
|
||||
InputDeviceListUpdate(ctx context.Context, req *InputDeviceListUpdateRequest, res *InputDeviceListUpdateResponse)
|
||||
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
|
||||
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
||||
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
||||
|
|
@ -200,3 +203,11 @@ type QueryDeviceMessagesResponse struct {
|
|||
Devices []DeviceMessage
|
||||
Error *KeyError
|
||||
}
|
||||
|
||||
type InputDeviceListUpdateRequest struct {
|
||||
Event gomatrixserverlib.DeviceListUpdateEvent
|
||||
}
|
||||
|
||||
type InputDeviceListUpdateResponse struct {
|
||||
Error *KeyError
|
||||
}
|
||||
|
|
|
|||
298
keyserver/internal/device_list_update.go
Normal file
298
keyserver/internal/device_list_update.go
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/producers"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DeviceListUpdater handles device list updates from remote servers.
|
||||
//
|
||||
// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock).
|
||||
// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies
|
||||
// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id
|
||||
// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device:
|
||||
// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the
|
||||
// updater stores the latest list along with the latest stream ID.
|
||||
//
|
||||
// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers.
|
||||
// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing
|
||||
// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved
|
||||
// from the database (which allows us to batch requests to the same server). This has a number of desirable properties:
|
||||
// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible
|
||||
// for that domain.
|
||||
// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where
|
||||
// we have many many servers)
|
||||
// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers.
|
||||
// The downsides are that:
|
||||
// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free
|
||||
// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts)
|
||||
// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
|
||||
// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse
|
||||
// than being stuck behind foo.bar
|
||||
// In the event that the query fails, the worker spins up a short-lived goroutine whose sole purpose is to inject the server
|
||||
// name back into the channel after a certain amount of time. If in the interim the device lists have been updated, then
|
||||
// the database query will return no stale lists. Reinjection into the channel continues until success or the server terminates,
|
||||
// when it will be reloaded on startup.
|
||||
type DeviceListUpdater struct {
|
||||
// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
|
||||
// request to the remote server and race.
|
||||
// TODO: Put in an LRU cache to bound growth
|
||||
userIDToMutex map[string]*sync.Mutex
|
||||
mu *sync.Mutex // protects UserIDToMutex
|
||||
|
||||
db DeviceListUpdaterDatabase
|
||||
producer *producers.KeyChange
|
||||
fedClient *gomatrixserverlib.FederationClient
|
||||
workerChans []chan gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
|
||||
// Useful for testing.
|
||||
type DeviceListUpdaterDatabase interface {
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||
}
|
||||
|
||||
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
||||
func NewDeviceListUpdater(
|
||||
db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient,
|
||||
numWorkers int,
|
||||
) *DeviceListUpdater {
|
||||
return &DeviceListUpdater{
|
||||
userIDToMutex: make(map[string]*sync.Mutex),
|
||||
mu: &sync.Mutex{},
|
||||
db: db,
|
||||
producer: producer,
|
||||
fedClient: fedClient,
|
||||
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
|
||||
}
|
||||
}
|
||||
|
||||
// Start the device list updater, which will try to refresh any stale device lists.
|
||||
func (u *DeviceListUpdater) Start() error {
|
||||
for i := 0; i < len(u.workerChans); i++ {
|
||||
// Allocate a small buffer per channel.
|
||||
// If the buffer limit is reached, backpressure will cause the processing of EDUs
|
||||
// to stop (in this transaction) until key requests can be made.
|
||||
ch := make(chan gomatrixserverlib.ServerName, 10)
|
||||
u.workerChans[i] = ch
|
||||
go u.worker(ch)
|
||||
}
|
||||
|
||||
staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, userID := range staleLists {
|
||||
u.notifyWorkers(userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
if u.userIDToMutex[userID] == nil {
|
||||
u.userIDToMutex[userID] = &sync.Mutex{}
|
||||
}
|
||||
return u.userIDToMutex[userID]
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
|
||||
isDeviceListStale, err := u.update(ctx, event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isDeviceListStale {
|
||||
// poke workers to handle stale device lists
|
||||
u.notifyWorkers(event.UserID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) {
|
||||
mu := u.mutex(event.UserID)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// check if we have the prev IDs
|
||||
exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"prev_ids_exist": exists,
|
||||
"user_id": event.UserID,
|
||||
"device_id": event.DeviceID,
|
||||
"stream_id": event.StreamID,
|
||||
"prev_ids": event.PrevID,
|
||||
}).Info("DeviceListUpdater.Update")
|
||||
|
||||
// if we haven't missed anything update the database and notify users
|
||||
if exists {
|
||||
keys := []api.DeviceMessage{
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: event.DeviceID,
|
||||
DisplayName: event.DeviceDisplayName,
|
||||
KeyJSON: event.Keys,
|
||||
UserID: event.UserID,
|
||||
},
|
||||
StreamID: event.StreamID,
|
||||
},
|
||||
}
|
||||
err = u.db.StoreRemoteDeviceKeys(ctx, keys)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
// ALWAYS emit key changes when we've been poked over federation even if there's no change
|
||||
// just in case this poke is important for something.
|
||||
err = u.producer.ProduceKeyChanges(keys)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = u.db.MarkDeviceListStale(ctx, event.UserID, true)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) notifyWorkers(userID string) {
|
||||
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
hash := fnv.New32a()
|
||||
_, _ = hash.Write([]byte(remoteServer))
|
||||
index := int(hash.Sum32()) % len(u.workerChans)
|
||||
u.workerChans[index] <- remoteServer
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
|
||||
// It's possible to get many of the same server name in the channel, so in order
|
||||
// to prevent processing the same server over and over we keep track of when we
|
||||
// last made a request to the server. If we get the server name during the cooloff
|
||||
// period, we'll ignore the poke.
|
||||
lastProcessed := make(map[gomatrixserverlib.ServerName]time.Time)
|
||||
cooloffPeriod := time.Minute
|
||||
shouldProcess := func(srv gomatrixserverlib.ServerName) bool {
|
||||
// we should process requests when now is after the last process time + cooloff
|
||||
return time.Now().After(lastProcessed[srv].Add(cooloffPeriod))
|
||||
}
|
||||
|
||||
// on failure, spin up a short-lived goroutine to inject the server name again.
|
||||
inject := func(srv gomatrixserverlib.ServerName, duration time.Duration) {
|
||||
time.Sleep(duration)
|
||||
ch <- srv
|
||||
}
|
||||
|
||||
for serverName := range ch {
|
||||
if !shouldProcess(serverName) {
|
||||
// do not inject into the channel as we know there will be a sleeping goroutine
|
||||
// which will do it after the cooloff period expires
|
||||
continue
|
||||
}
|
||||
lastProcessed[serverName] = time.Now()
|
||||
shouldRetry := u.processServer(serverName)
|
||||
if shouldRetry {
|
||||
go inject(serverName, cooloffPeriod) // TODO: Backoff?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) bool {
|
||||
requestTimeout := time.Minute // max amount of time we want to spend on each request
|
||||
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
|
||||
defer cancel()
|
||||
logger := util.GetLogger(ctx).WithField("server_name", serverName)
|
||||
// fetch stale device lists
|
||||
userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("failed to load stale device lists")
|
||||
return true
|
||||
}
|
||||
hasFailures := false
|
||||
for _, userID := range userIDs {
|
||||
if ctx.Err() != nil {
|
||||
// we've timed out, give up and go to the back of the queue to let another server be processed.
|
||||
hasFailures = true
|
||||
break
|
||||
}
|
||||
res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("user_id", userID).Error("failed to query device keys for user")
|
||||
hasFailures = true
|
||||
continue
|
||||
}
|
||||
err = u.updateDeviceList(ctx, &res)
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it")
|
||||
hasFailures = true
|
||||
}
|
||||
}
|
||||
return hasFailures
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error {
|
||||
keys := make([]api.DeviceMessage, len(res.Devices))
|
||||
for i, device := range res.Devices {
|
||||
keyJSON, err := json.Marshal(device.Keys)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
|
||||
continue
|
||||
}
|
||||
keys[i] = api.DeviceMessage{
|
||||
StreamID: res.StreamID,
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: device.DeviceID,
|
||||
DisplayName: device.DisplayName,
|
||||
UserID: res.UserID,
|
||||
KeyJSON: keyJSON,
|
||||
},
|
||||
}
|
||||
}
|
||||
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return u.db.MarkDeviceListStale(ctx, res.UserID, false)
|
||||
}
|
||||
|
|
@ -38,12 +38,24 @@ type KeyInternalAPI struct {
|
|||
FedClient *gomatrixserverlib.FederationClient
|
||||
UserAPI userapi.UserInternalAPI
|
||||
Producer *producers.KeyChange
|
||||
Updater *DeviceListUpdater
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
||||
a.UserAPI = i
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) InputDeviceListUpdate(
|
||||
ctx context.Context, req *api.InputDeviceListUpdateRequest, res *api.InputDeviceListUpdateResponse,
|
||||
) {
|
||||
err := a.Updater.Update(ctx, req.Event)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to update device list: %s", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
||||
if req.Partition < 0 {
|
||||
req.Partition = a.Producer.DefaultPartition()
|
||||
|
|
@ -351,7 +363,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
|||
return
|
||||
}
|
||||
// store the device keys and emit changes
|
||||
err := a.DB.StoreDeviceKeys(ctx, keysToStore)
|
||||
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||
|
|
|
|||
|
|
@ -27,12 +27,13 @@ import (
|
|||
|
||||
// HTTP paths for the internal HTTP APIs
|
||||
const (
|
||||
PerformUploadKeysPath = "/keyserver/performUploadKeys"
|
||||
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
||||
QueryKeysPath = "/keyserver/queryKeys"
|
||||
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
|
||||
QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
|
||||
QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages"
|
||||
InputDeviceListUpdatePath = "/keyserver/inputDeviceListUpdate"
|
||||
PerformUploadKeysPath = "/keyserver/performUploadKeys"
|
||||
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
||||
QueryKeysPath = "/keyserver/queryKeys"
|
||||
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
|
||||
QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
|
||||
QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages"
|
||||
)
|
||||
|
||||
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
|
||||
|
|
@ -58,6 +59,20 @@ type httpKeyInternalAPI struct {
|
|||
func (h *httpKeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
||||
// no-op: doesn't need it
|
||||
}
|
||||
func (h *httpKeyInternalAPI) InputDeviceListUpdate(
|
||||
ctx context.Context, req *api.InputDeviceListUpdateRequest, res *api.InputDeviceListUpdateResponse,
|
||||
) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "InputDeviceListUpdate")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + InputDeviceListUpdatePath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *httpKeyInternalAPI) PerformClaimKeys(
|
||||
ctx context.Context,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,17 @@ import (
|
|||
)
|
||||
|
||||
func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
|
||||
internalAPIMux.Handle(InputDeviceListUpdatePath,
|
||||
httputil.MakeInternalAPI("inputDeviceListUpdate", func(req *http.Request) util.JSONResponse {
|
||||
request := api.InputDeviceListUpdateRequest{}
|
||||
response := api.InputDeviceListUpdateResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
s.InputDeviceListUpdate(req.Context(), &request, &response)
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(PerformClaimKeysPath,
|
||||
httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse {
|
||||
request := api.PerformClaimKeysRequest{}
|
||||
|
|
|
|||
|
|
@ -47,10 +47,16 @@ func NewInternalAPI(
|
|||
Producer: producer,
|
||||
DB: db,
|
||||
}
|
||||
updater := internal.NewDeviceListUpdater(db, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
|
||||
err = updater.Start()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start device list updater")
|
||||
}
|
||||
return &internal.KeyInternalAPI{
|
||||
DB: db,
|
||||
ThisServer: cfg.Matrix.ServerName,
|
||||
FedClient: fedClient,
|
||||
Producer: keyChangeProducer,
|
||||
Updater: updater,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import (
|
|||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
|
|
@ -35,11 +36,18 @@ type Database interface {
|
|||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device).
|
||||
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
|
||||
// Returns an error if there was a problem storing the keys.
|
||||
StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||
|
||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||
|
|
@ -57,4 +65,11 @@ type Database interface {
|
|||
// A to offset of sarama.OffsetNewest means no upper limit.
|
||||
// Returns the offset of the latest key change.
|
||||
KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import (
|
|||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
|
|
@ -56,12 +57,16 @@ const selectBatchDeviceKeysSQL = "" +
|
|||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
countStreamIDsForUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
|
|
@ -84,6 +89,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
|
@ -115,6 +123,19 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
|||
return
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt32
|
||||
err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int32), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
|
|
@ -47,7 +48,25 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage)
|
|||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||
}
|
||||
|
||||
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
|
||||
sids := make([]int64, len(prevIDs))
|
||||
for i := range prevIDs {
|
||||
sids[i] = int64(prevIDs[i])
|
||||
}
|
||||
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count == len(prevIDs), nil
|
||||
}
|
||||
|
||||
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
// work out the latest stream IDs for each user
|
||||
userIDToStreamID := make(map[string]int)
|
||||
for _, k := range keys {
|
||||
|
|
@ -106,3 +125,14 @@ func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset i
|
|||
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||
return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset, toOffset)
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
return nil, nil // TODO
|
||||
}
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||
return nil // TODO
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
|
|
@ -53,6 +54,9 @@ const selectBatchDeviceKeysSQL = "" +
|
|||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
|
|
@ -143,6 +147,25 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
|
|||
return
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||
iStreamIDs := make([]interface{}, len(streamIDs)+1)
|
||||
iStreamIDs[0] = userID
|
||||
for i := range streamIDs {
|
||||
iStreamIDs[i+1] = streamIDs[i]
|
||||
}
|
||||
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt32
|
||||
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int32), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
for _, key := range keys {
|
||||
|
|
|
|||
|
|
@ -129,15 +129,15 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
// StreamID: 2 as this is a 2nd device key
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||
}
|
||||
if msgs[1].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||
}
|
||||
if msgs[2].StreamID != 2 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||
}
|
||||
|
||||
// updating a device sets the next stream ID for that user
|
||||
|
|
@ -151,9 +151,9 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
// StreamID: 3
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 3 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||
}
|
||||
|
||||
// Querying for device keys returns the latest stream IDs
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ type DeviceKeys interface {
|
|||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,9 @@ func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneT
|
|||
}
|
||||
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) {
|
||||
|
||||
}
|
||||
func (k *mockKeyAPI) InputDeviceListUpdate(ctx context.Context, req *keyapi.InputDeviceListUpdateRequest, res *keyapi.InputDeviceListUpdateResponse) {
|
||||
|
||||
}
|
||||
|
||||
type mockCurrentStateAPI struct {
|
||||
|
|
|
|||
Loading…
Reference in a new issue