mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Merge branch 'master' into neilalexander/config
This commit is contained in:
commit
253175edc2
|
|
@ -75,7 +75,8 @@ func createFederationClient(
|
||||||
p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")),
|
p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")),
|
||||||
)
|
)
|
||||||
return gomatrixserverlib.NewFederationClientWithTransport(
|
return gomatrixserverlib.NewFederationClientWithTransport(
|
||||||
base.Base.Cfg.Global.ServerName, base.Base.Cfg.Global.KeyID, base.Base.Cfg.Global.PrivateKey, tr,
|
base.Base.Cfg.Global.ServerName, base.Base.Cfg.Global.KeyID,
|
||||||
|
base.Base.Cfg.Global.PrivateKey, true, tr,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -87,7 +88,7 @@ func createClient(
|
||||||
"matrix",
|
"matrix",
|
||||||
p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")),
|
p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")),
|
||||||
)
|
)
|
||||||
return gomatrixserverlib.NewClientWithTransport(tr)
|
return gomatrixserverlib.NewClientWithTransport(true, tr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ func (n *Node) CreateClient(
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return gomatrixserverlib.NewClientWithTransport(tr)
|
return gomatrixserverlib.NewClientWithTransport(true, tr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Node) CreateFederationClient(
|
func (n *Node) CreateFederationClient(
|
||||||
|
|
@ -54,6 +54,7 @@ func (n *Node) CreateFederationClient(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return gomatrixserverlib.NewFederationClientWithTransport(
|
return gomatrixserverlib.NewFederationClientWithTransport(
|
||||||
base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, base.Cfg.Global.PrivateKey, tr,
|
base.Cfg.Global.ServerName, base.Cfg.Global.KeyID,
|
||||||
|
base.Cfg.Global.PrivateKey, true, tr,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ func main() {
|
||||||
defer base.Close() // nolint: errcheck
|
defer base.Close() // nolint: errcheck
|
||||||
|
|
||||||
userAPI := base.UserAPIClient()
|
userAPI := base.UserAPIClient()
|
||||||
client := gomatrixserverlib.NewClient()
|
client := gomatrixserverlib.NewClient(cfg.FederationSender.DisableTLSValidation)
|
||||||
|
|
||||||
mediaapi.AddPublicRoutes(base.PublicAPIMux, &base.Cfg.MediaAPI, userAPI, client)
|
mediaapi.AddPublicRoutes(base.PublicAPIMux, &base.Cfg.MediaAPI, userAPI, client)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -132,7 +132,7 @@ func main() {
|
||||||
Config: base.Cfg,
|
Config: base.Cfg,
|
||||||
AccountDB: accountDB,
|
AccountDB: accountDB,
|
||||||
DeviceDB: deviceDB,
|
DeviceDB: deviceDB,
|
||||||
Client: gomatrixserverlib.NewClient(),
|
Client: gomatrixserverlib.NewClient(cfg.FederationSender.DisableTLSValidation),
|
||||||
FedClient: federation,
|
FedClient: federation,
|
||||||
KeyRing: keyRing,
|
KeyRing: keyRing,
|
||||||
KafkaConsumer: base.KafkaConsumer,
|
KafkaConsumer: base.KafkaConsumer,
|
||||||
|
|
|
||||||
|
|
@ -139,16 +139,16 @@ func createFederationClient(cfg *config.Dendrite, node *go_http_js_libp2p.P2pLoc
|
||||||
tr := go_http_js_libp2p.NewP2pTransport(node)
|
tr := go_http_js_libp2p.NewP2pTransport(node)
|
||||||
|
|
||||||
fed := gomatrixserverlib.NewFederationClient(
|
fed := gomatrixserverlib.NewFederationClient(
|
||||||
cfg.Global.ServerName, cfg.Global.KeyID, cfg.Global.PrivateKey,
|
cfg.Global.ServerName, cfg.Global.KeyID, cfg.Global.PrivateKey, true,
|
||||||
)
|
)
|
||||||
fed.Client = *gomatrixserverlib.NewClientWithTransport(tr)
|
fed.Client = *gomatrixserverlib.NewClientWithTransport(true, tr)
|
||||||
|
|
||||||
return fed
|
return fed
|
||||||
}
|
}
|
||||||
|
|
||||||
func createClient(node *go_http_js_libp2p.P2pLocalNode) *gomatrixserverlib.Client {
|
func createClient(node *go_http_js_libp2p.P2pLocalNode) *gomatrixserverlib.Client {
|
||||||
tr := go_http_js_libp2p.NewP2pTransport(node)
|
tr := go_http_js_libp2p.NewP2pTransport(node)
|
||||||
return gomatrixserverlib.NewClientWithTransport(tr)
|
return gomatrixserverlib.NewClientWithTransport(true, tr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createP2PNode(privKey ed25519.PrivateKey) (serverName string, node *go_http_js_libp2p.P2pLocalNode) {
|
func createP2PNode(privKey ed25519.PrivateKey) (serverName string, node *go_http_js_libp2p.P2pLocalNode) {
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,9 @@ matrix:
|
||||||
# public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ
|
# public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ
|
||||||
# Disables new users from registering (except via shared secrets)
|
# Disables new users from registering (except via shared secrets)
|
||||||
registration_disabled: false
|
registration_disabled: false
|
||||||
|
# Whether to disable TLS certificate validation. Warning: this reduces federation
|
||||||
|
# security and should not be enabled in production!
|
||||||
|
federation_disable_tls_validation: false
|
||||||
|
|
||||||
# The media repository config
|
# The media repository config
|
||||||
media:
|
media:
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://"))
|
serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://"))
|
||||||
|
|
||||||
fedCli := gomatrixserverlib.NewFederationClient(serverName, cfg.Global.KeyID, cfg.Global.PrivateKey)
|
fedCli := gomatrixserverlib.NewFederationClient(serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, true)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
roomVer gomatrixserverlib.RoomVersion
|
roomVer gomatrixserverlib.RoomVersion
|
||||||
|
|
|
||||||
|
|
@ -218,7 +218,9 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe
|
||||||
}
|
}
|
||||||
|
|
||||||
t.processEDUs(t.EDUs)
|
t.processEDUs(t.EDUs)
|
||||||
util.GetLogger(t.context).Infof("Processed %d PDUs from transaction %q", len(results), t.TransactionID)
|
if c := len(results); c > 0 {
|
||||||
|
util.GetLogger(t.context).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID)
|
||||||
|
}
|
||||||
return &gomatrixserverlib.RespSend{PDUs: results}, nil
|
return &gomatrixserverlib.RespSend{PDUs: results}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -315,7 +317,7 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) {
|
||||||
case gomatrixserverlib.MDeviceListUpdate:
|
case gomatrixserverlib.MDeviceListUpdate:
|
||||||
t.processDeviceListUpdate(e)
|
t.processDeviceListUpdate(e)
|
||||||
default:
|
default:
|
||||||
util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu")
|
util.GetLogger(t.context).WithField("type", e.Type).Debug("Unhandled EDU")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -255,22 +255,20 @@ func (oq *destinationQueue) backgroundSend() {
|
||||||
// The worker is idle so stop the goroutine. It'll get
|
// The worker is idle so stop the goroutine. It'll get
|
||||||
// restarted automatically the next time we have an event to
|
// restarted automatically the next time we have an event to
|
||||||
// send.
|
// send.
|
||||||
log.Infof("Queue %q has been idle for %s, going to sleep", oq.destination, queueIdleTimeout)
|
log.Debugf("Queue %q has been idle for %s, going to sleep", oq.destination, queueIdleTimeout)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we are backing off this server then wait for the
|
// If we are backing off this server then wait for the
|
||||||
// backoff duration to complete first, or until explicitly
|
// backoff duration to complete first, or until explicitly
|
||||||
// told to retry.
|
// told to retry.
|
||||||
if backoff, duration := oq.statistics.BackoffDuration(); backoff {
|
if _, giveUp := oq.statistics.BackoffIfRequired(oq.backingOff, oq.interruptBackoff); giveUp {
|
||||||
log.WithField("duration", duration).Infof("Backing off %s", oq.destination)
|
// It's been suggested that we should give up because the backoff
|
||||||
oq.backingOff.Store(true)
|
// has exceeded a maximum allowable value. Clean up the in-memory
|
||||||
select {
|
// buffers at this point. The PDU clean-up is already on a defer.
|
||||||
case <-time.After(duration):
|
oq.cleanPendingInvites()
|
||||||
case <-oq.interruptBackoff:
|
log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
|
||||||
log.Infof("Interrupting backoff for %q", oq.destination)
|
return
|
||||||
}
|
|
||||||
oq.backingOff.Store(false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have pending PDUs or EDUs then construct a transaction.
|
// If we have pending PDUs or EDUs then construct a transaction.
|
||||||
|
|
@ -278,24 +276,8 @@ func (oq *destinationQueue) backgroundSend() {
|
||||||
// Try sending the next transaction and see what happens.
|
// Try sending the next transaction and see what happens.
|
||||||
transaction, terr := oq.nextTransaction()
|
transaction, terr := oq.nextTransaction()
|
||||||
if terr != nil {
|
if terr != nil {
|
||||||
// We failed to send the transaction.
|
// We failed to send the transaction. Mark it as a failure.
|
||||||
if giveUp := oq.statistics.Failure(); giveUp {
|
oq.statistics.Failure()
|
||||||
// It's been suggested that we should give up because the backoff
|
|
||||||
// has exceeded a maximum allowable value. Clean up the in-memory
|
|
||||||
// buffers at this point. The PDU clean-up is already on a defer.
|
|
||||||
oq.cleanPendingInvites()
|
|
||||||
log.Infof("Blacklisting %q due to errors", oq.destination)
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
// We haven't been told to give up terminally yet but we still have
|
|
||||||
// PDUs waiting to be sent. By sending a message into the wake chan,
|
|
||||||
// the next loop iteration will try processing these PDUs again,
|
|
||||||
// subject to the backoff.
|
|
||||||
select {
|
|
||||||
case oq.notifyPDUs <- true:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if transaction {
|
} else if transaction {
|
||||||
// If we successfully sent the transaction then clear out
|
// If we successfully sent the transaction then clear out
|
||||||
// the pending events and EDUs, and wipe our transaction ID.
|
// the pending events and EDUs, and wipe our transaction ID.
|
||||||
|
|
@ -307,14 +289,8 @@ func (oq *destinationQueue) backgroundSend() {
|
||||||
if len(oq.pendingInvites) > 0 {
|
if len(oq.pendingInvites) > 0 {
|
||||||
sent, ierr := oq.nextInvites(oq.pendingInvites)
|
sent, ierr := oq.nextInvites(oq.pendingInvites)
|
||||||
if ierr != nil {
|
if ierr != nil {
|
||||||
// We failed to send the transaction so increase the
|
// We failed to send the transaction. Mark it as a failure.
|
||||||
// backoff and give it another go shortly.
|
oq.statistics.Failure()
|
||||||
if giveUp := oq.statistics.Failure(); giveUp {
|
|
||||||
// It's been suggested that we should give up because
|
|
||||||
// the backoff has exceeded a maximum allowable value.
|
|
||||||
log.Infof("Blacklisting %q due to errors", oq.destination)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if sent > 0 {
|
} else if sent > 0 {
|
||||||
// If we successfully sent the invites then clear out
|
// If we successfully sent the invites then clear out
|
||||||
// the pending invites.
|
// the pending invites.
|
||||||
|
|
@ -414,7 +390,7 @@ func (oq *destinationQueue) nextTransaction() (bool, error) {
|
||||||
t.EDUs = append(t.EDUs, *edu)
|
t.EDUs = append(t.EDUs, *edu)
|
||||||
}
|
}
|
||||||
|
|
||||||
logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
|
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
|
||||||
|
|
||||||
// Try to send the transaction to the destination server.
|
// Try to send the transaction to the destination server.
|
||||||
// TODO: we should check for 500-ish fails vs 400-ish here,
|
// TODO: we should check for 500-ish fails vs 400-ish here,
|
||||||
|
|
|
||||||
|
|
@ -136,7 +136,7 @@ func (oqs *OutgoingQueues) SendEvent(
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"destinations": destinations, "event": ev.EventID(),
|
"destinations": destinations, "event": ev.EventID(),
|
||||||
}).Info("Sending event")
|
}).Infof("Sending event")
|
||||||
|
|
||||||
headeredJSON, err := json.Marshal(ev)
|
headeredJSON, err := json.Marshal(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -65,8 +65,8 @@ type ServerStatistics struct {
|
||||||
statistics *Statistics //
|
statistics *Statistics //
|
||||||
serverName gomatrixserverlib.ServerName //
|
serverName gomatrixserverlib.ServerName //
|
||||||
blacklisted atomic.Bool // is the node blacklisted
|
blacklisted atomic.Bool // is the node blacklisted
|
||||||
backoffUntil atomic.Value // time.Time to wait until before sending requests
|
backoffStarted atomic.Bool // is the backoff started
|
||||||
failCounter atomic.Uint32 // how many times have we failed?
|
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
|
||||||
successCounter atomic.Uint32 // how many times have we succeeded?
|
successCounter atomic.Uint32 // how many times have we succeeded?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -76,55 +76,67 @@ type ServerStatistics struct {
|
||||||
// we will unblacklist it.
|
// we will unblacklist it.
|
||||||
func (s *ServerStatistics) Success() {
|
func (s *ServerStatistics) Success() {
|
||||||
s.successCounter.Add(1)
|
s.successCounter.Add(1)
|
||||||
s.failCounter.Store(0)
|
s.backoffStarted.Store(false)
|
||||||
|
s.backoffCount.Store(0)
|
||||||
s.blacklisted.Store(false)
|
s.blacklisted.Store(false)
|
||||||
|
if s.statistics.DB != nil {
|
||||||
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
|
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
|
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Failure marks a failure and works out when to backoff until. It
|
// Failure marks a failure and starts backing off if needed.
|
||||||
// returns true if the worker should give up altogether because of
|
// The next call to BackoffIfRequired will do the right thing
|
||||||
// too many consecutive failures. At this point the host is marked
|
// after this.
|
||||||
// as blacklisted.
|
func (s *ServerStatistics) Failure() {
|
||||||
func (s *ServerStatistics) Failure() bool {
|
if s.backoffStarted.CAS(false, true) {
|
||||||
// Increase the fail counter.
|
s.backoffCount.Store(0)
|
||||||
failCounter := s.failCounter.Add(1)
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check that we haven't failed more times than is acceptable.
|
// BackoffIfRequired will block for as long as the current
|
||||||
if failCounter >= s.statistics.FailuresUntilBlacklist {
|
// backoff requires, if needed. Otherwise it will do nothing.
|
||||||
|
func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <-chan bool) (time.Duration, bool) {
|
||||||
|
if started := s.backoffStarted.Load(); !started {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work out how many times we've backed off so far.
|
||||||
|
count := s.backoffCount.Inc()
|
||||||
|
duration := time.Second * time.Duration(math.Exp2(float64(count)))
|
||||||
|
|
||||||
|
// Work out if we should be blacklisting at this point.
|
||||||
|
if count >= s.statistics.FailuresUntilBlacklist {
|
||||||
// We've exceeded the maximum amount of times we're willing
|
// We've exceeded the maximum amount of times we're willing
|
||||||
// to back off, which is probably in the region of hours by
|
// to back off, which is probably in the region of hours by
|
||||||
// now. Mark the host as blacklisted and tell the caller to
|
// now. Mark the host as blacklisted and tell the caller to
|
||||||
// give up.
|
// give up.
|
||||||
s.blacklisted.Store(true)
|
s.blacklisted.Store(true)
|
||||||
|
if s.statistics.DB != nil {
|
||||||
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
|
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
|
||||||
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
|
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
|
||||||
}
|
}
|
||||||
return true
|
}
|
||||||
|
return duration, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// We're still under the threshold so work out the exponential
|
// Notify the destination queue that we're backing off now.
|
||||||
// backoff based on how many times we have failed already. The
|
backingOff.Store(true)
|
||||||
// worker goroutine will wait until this time before processing
|
defer backingOff.Store(false)
|
||||||
// anything from the queue.
|
|
||||||
backoffSeconds := time.Second * time.Duration(math.Exp2(float64(failCounter)))
|
|
||||||
s.backoffUntil.Store(
|
|
||||||
time.Now().Add(backoffSeconds),
|
|
||||||
)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// BackoffDuration returns both a bool stating whether to wait,
|
// Work out how long we should be backing off for.
|
||||||
// and then if true, a duration to wait for.
|
logrus.Warnf("Backing off %q for %s", s.serverName, duration)
|
||||||
func (s *ServerStatistics) BackoffDuration() (bool, time.Duration) {
|
|
||||||
backoff, until := false, time.Second
|
// Wait for either an interruption or for the backoff to
|
||||||
if b, ok := s.backoffUntil.Load().(time.Time); ok {
|
// complete.
|
||||||
if b.After(time.Now()) {
|
select {
|
||||||
backoff, until = true, time.Until(b)
|
case <-interrupt:
|
||||||
|
logrus.Debugf("Interrupting backoff for %q", s.serverName)
|
||||||
|
case <-time.After(duration):
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return backoff, until
|
return duration, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Blacklisted returns true if the server is blacklisted and false
|
// Blacklisted returns true if the server is blacklisted and false
|
||||||
|
|
|
||||||
60
federationsender/statistics/statistics_test.go
Normal file
60
federationsender/statistics/statistics_test.go
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
package statistics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBackoff(t *testing.T) {
|
||||||
|
stats := Statistics{
|
||||||
|
FailuresUntilBlacklist: 5,
|
||||||
|
}
|
||||||
|
server := ServerStatistics{
|
||||||
|
statistics: &stats,
|
||||||
|
serverName: "test.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start by checking that counting successes works.
|
||||||
|
server.Success()
|
||||||
|
if successes := server.SuccessCount(); successes != 1 {
|
||||||
|
t.Fatalf("Expected success count 1, got %d", successes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register a failure.
|
||||||
|
server.Failure()
|
||||||
|
|
||||||
|
t.Logf("Backoff counter: %d", server.backoffCount.Load())
|
||||||
|
backingOff := atomic.Bool{}
|
||||||
|
|
||||||
|
// Now we're going to simulate backing off a few times to see
|
||||||
|
// what happens.
|
||||||
|
for i := uint32(1); i <= 10; i++ {
|
||||||
|
// Interrupt the backoff - it doesn't really matter if it
|
||||||
|
// completes but we will find out how long the backoff should
|
||||||
|
// have been.
|
||||||
|
interrupt := make(chan bool, 1)
|
||||||
|
close(interrupt)
|
||||||
|
|
||||||
|
// Get the duration.
|
||||||
|
duration, blacklist := server.BackoffIfRequired(backingOff, interrupt)
|
||||||
|
|
||||||
|
// Check if we should be blacklisted by now.
|
||||||
|
if i > stats.FailuresUntilBlacklist {
|
||||||
|
if !blacklist {
|
||||||
|
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
|
||||||
|
} else {
|
||||||
|
t.Logf("Backoff %d is blacklisted as expected", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the duration is what we expect.
|
||||||
|
t.Logf("Backoff %d is for %s", i, duration)
|
||||||
|
if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted {
|
||||||
|
t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
4
go.mod
4
go.mod
|
|
@ -23,9 +23,9 @@ require (
|
||||||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200807122736-eb1a0b991914
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200807145008-79c173b65786
|
||||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
|
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
github.com/mattn/go-sqlite3 v2.0.2+incompatible
|
github.com/mattn/go-sqlite3 v2.0.2+incompatible
|
||||||
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5
|
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5
|
||||||
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
|
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
|
||||||
|
|
|
||||||
6
go.sum
6
go.sum
|
|
@ -425,12 +425,14 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 h1:Yb+Wlf
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200807122736-eb1a0b991914 h1:VSGCvSUB1/Y32F/JSjmTaIW9jr1BmBHEd0ok4AaT/lo=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200807145008-79c173b65786 h1:HQclx5J2CrCBqP88t5Di9IkVDJZn5+h4ZL48viY4FJ4=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200807122736-eb1a0b991914/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200807145008-79c173b65786/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
|
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
|
github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go=
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||||
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||||
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||||
github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg=
|
github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg=
|
||||||
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,11 @@ type FederationSender struct {
|
||||||
// tolerate when sending federation requests to a specific server. The backoff
|
// tolerate when sending federation requests to a specific server. The backoff
|
||||||
// is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds, etc.
|
// is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds, etc.
|
||||||
// The default value is 16 if not specified, which is circa 18 hours.
|
// The default value is 16 if not specified, which is circa 18 hours.
|
||||||
FederationMaxRetries uint32 `yaml:"federation_max_retries"`
|
FederationMaxRetries uint32 `yaml:"send_max_retries"`
|
||||||
|
|
||||||
|
// FederationDisableTLSValidation disables the validation of X.509 TLS certs
|
||||||
|
// on remote federation endpoints. This is not recommended in production!
|
||||||
|
DisableTLSValidation bool `yaml:"disable_tls_validation"`
|
||||||
|
|
||||||
Proxy Proxy `yaml:"proxy_outbound"`
|
Proxy Proxy `yaml:"proxy_outbound"`
|
||||||
}
|
}
|
||||||
|
|
@ -26,6 +30,7 @@ func (c *FederationSender) Defaults() {
|
||||||
c.Database.ConnectionString = "file:federationsender.db"
|
c.Database.ConnectionString = "file:federationsender.db"
|
||||||
|
|
||||||
c.FederationMaxRetries = 16
|
c.FederationMaxRetries = 16
|
||||||
|
c.DisableTLSValidation = false
|
||||||
|
|
||||||
c.Proxy.Defaults()
|
c.Proxy.Defaults()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -251,6 +251,7 @@ func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
|
||||||
func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient {
|
func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient {
|
||||||
return gomatrixserverlib.NewFederationClient(
|
return gomatrixserverlib.NewFederationClient(
|
||||||
b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey,
|
b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey,
|
||||||
|
b.Cfg.FederationSender.DisableTLSValidation,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/producers"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
@ -65,7 +64,7 @@ type DeviceListUpdater struct {
|
||||||
mu *sync.Mutex // protects UserIDToMutex
|
mu *sync.Mutex // protects UserIDToMutex
|
||||||
|
|
||||||
db DeviceListUpdaterDatabase
|
db DeviceListUpdaterDatabase
|
||||||
producer *producers.KeyChange
|
producer KeyChangeProducer
|
||||||
fedClient *gomatrixserverlib.FederationClient
|
fedClient *gomatrixserverlib.FederationClient
|
||||||
workerChans []chan gomatrixserverlib.ServerName
|
workerChans []chan gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
@ -88,9 +87,14 @@ type DeviceListUpdaterDatabase interface {
|
||||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
|
||||||
|
type KeyChangeProducer interface {
|
||||||
|
ProduceKeyChanges(keys []api.DeviceMessage) error
|
||||||
|
}
|
||||||
|
|
||||||
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
||||||
func NewDeviceListUpdater(
|
func NewDeviceListUpdater(
|
||||||
db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient,
|
db DeviceListUpdaterDatabase, producer KeyChangeProducer, fedClient *gomatrixserverlib.FederationClient,
|
||||||
numWorkers int,
|
numWorkers int,
|
||||||
) *DeviceListUpdater {
|
) *DeviceListUpdater {
|
||||||
return &DeviceListUpdater{
|
return &DeviceListUpdater{
|
||||||
|
|
@ -154,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
|
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||||
}
|
}
|
||||||
|
// if this is the first time we're hearing about this user, sync the device list manually.
|
||||||
|
if len(event.PrevID) == 0 {
|
||||||
|
exists = false
|
||||||
|
}
|
||||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
"prev_ids_exist": exists,
|
"prev_ids_exist": exists,
|
||||||
"user_id": event.UserID,
|
"user_id": event.UserID,
|
||||||
"device_id": event.DeviceID,
|
"device_id": event.DeviceID,
|
||||||
"stream_id": event.StreamID,
|
"stream_id": event.StreamID,
|
||||||
"prev_ids": event.PrevID,
|
"prev_ids": event.PrevID,
|
||||||
|
"display_name": event.DeviceDisplayName,
|
||||||
}).Info("DeviceListUpdater.Update")
|
}).Info("DeviceListUpdater.Update")
|
||||||
|
|
||||||
// if we haven't missed anything update the database and notify users
|
// if we haven't missed anything update the database and notify users
|
||||||
|
|
@ -263,16 +272,17 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
|
||||||
hasFailures = true
|
hasFailures = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = u.updateDeviceList(ctx, &res)
|
err = u.updateDeviceList(&res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it")
|
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
|
||||||
hasFailures = true
|
hasFailures = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return hasFailures
|
return hasFailures
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error {
|
func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {
|
||||||
|
ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
|
||||||
keys := make([]api.DeviceMessage, len(res.Devices))
|
keys := make([]api.DeviceMessage, len(res.Devices))
|
||||||
for i, device := range res.Devices {
|
for i, device := range res.Devices {
|
||||||
keyJSON, err := json.Marshal(device.Keys)
|
keyJSON, err := json.Marshal(device.Keys)
|
||||||
|
|
@ -292,7 +302,15 @@ func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixs
|
||||||
}
|
}
|
||||||
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
|
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to store remote device keys: %w", err)
|
||||||
}
|
}
|
||||||
return u.db.MarkDeviceListStale(ctx, res.UserID, false)
|
err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to mark device list as fresh: %w", err)
|
||||||
|
}
|
||||||
|
err = u.producer.ProduceKeyChanges(keys)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
242
keyserver/internal/device_list_update_test.go
Normal file
242
keyserver/internal/device_list_update_test.go
Normal file
|
|
@ -0,0 +1,242 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ctx = context.Background()
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockKeyChangeProducer struct {
|
||||||
|
events []api.DeviceMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||||
|
p.events = append(p.events, keys...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockDeviceListUpdaterDatabase struct {
|
||||||
|
staleUsers map[string]bool
|
||||||
|
prevIDsExist func(string, []int) bool
|
||||||
|
storedKeys []api.DeviceMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||||
|
// If no domains are given, all user IDs with stale device lists are returned.
|
||||||
|
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||||
|
var result []string
|
||||||
|
for userID := range d.staleUsers {
|
||||||
|
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(domains) == 0 {
|
||||||
|
result = append(result, userID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, d := range domains {
|
||||||
|
if remoteServer == d {
|
||||||
|
result = append(result, userID)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||||
|
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||||
|
d.staleUsers[userID] = isStale
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
|
// for this (user, device). Does not modify the stream ID for keys.
|
||||||
|
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
|
d.storedKeys = append(d.storedKeys, keys...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||||
|
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
|
||||||
|
return d.prevIDsExist(userID, prevIDs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type roundTripper struct {
|
||||||
|
fn func(*http.Request) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return t.fn(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
|
||||||
|
_, pkey, _ := ed25519.GenerateKey(nil)
|
||||||
|
fedClient := gomatrixserverlib.NewFederationClient(
|
||||||
|
gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, true,
|
||||||
|
)
|
||||||
|
fedClient.Client = *gomatrixserverlib.NewClientWithTransport(true, &roundTripper{tripper})
|
||||||
|
return fedClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that the device keys get persisted and emitted if we have the previous IDs.
|
||||||
|
func TestUpdateHavePrevID(t *testing.T) {
|
||||||
|
db := &mockDeviceListUpdaterDatabase{
|
||||||
|
staleUsers: make(map[string]bool),
|
||||||
|
prevIDsExist: func(string, []int) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
producer := &mockKeyChangeProducer{}
|
||||||
|
updater := NewDeviceListUpdater(db, producer, nil, 1)
|
||||||
|
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||||
|
DeviceDisplayName: "Foo Bar",
|
||||||
|
Deleted: false,
|
||||||
|
DeviceID: "FOO",
|
||||||
|
Keys: []byte(`{"key":"value"}`),
|
||||||
|
PrevID: []int{0},
|
||||||
|
StreamID: 1,
|
||||||
|
UserID: "@alice:localhost",
|
||||||
|
}
|
||||||
|
err := updater.Update(ctx, event)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update returned an error: %s", err)
|
||||||
|
}
|
||||||
|
want := api.DeviceMessage{
|
||||||
|
StreamID: event.StreamID,
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: event.DeviceID,
|
||||||
|
DisplayName: event.DeviceDisplayName,
|
||||||
|
KeyJSON: event.Keys,
|
||||||
|
UserID: event.UserID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||||
|
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||||
|
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||||
|
}
|
||||||
|
if db.staleUsers[event.UserID] {
|
||||||
|
t.Errorf("%s incorrectly marked as stale", event.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that device keys are fetched from the remote server if we are missing prev IDs
|
||||||
|
// and that the user's devices are marked as stale until it succeeds.
|
||||||
|
func TestUpdateNoPrevID(t *testing.T) {
|
||||||
|
db := &mockDeviceListUpdaterDatabase{
|
||||||
|
staleUsers: make(map[string]bool),
|
||||||
|
prevIDsExist: func(string, []int) bool {
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
}
|
||||||
|
producer := &mockKeyChangeProducer{}
|
||||||
|
remoteUserID := "@alice:example.somewhere"
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
|
||||||
|
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
|
||||||
|
defer wg.Done()
|
||||||
|
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
|
||||||
|
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: ioutil.NopCloser(strings.NewReader(`
|
||||||
|
{
|
||||||
|
"user_id": "` + remoteUserID + `",
|
||||||
|
"stream_id": 5,
|
||||||
|
"devices": [
|
||||||
|
{
|
||||||
|
"device_id": "JLAFKJWSCS",
|
||||||
|
"keys": ` + keyJSON + `,
|
||||||
|
"device_display_name": "Mobile Phone"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
`)),
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
updater := NewDeviceListUpdater(db, producer, fedClient, 2)
|
||||||
|
if err := updater.Start(); err != nil {
|
||||||
|
t.Fatalf("failed to start updater: %s", err)
|
||||||
|
}
|
||||||
|
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||||
|
DeviceDisplayName: "Mobile Phone",
|
||||||
|
Deleted: false,
|
||||||
|
DeviceID: "another_device_id",
|
||||||
|
Keys: []byte(`{"key":"value"}`),
|
||||||
|
PrevID: []int{3},
|
||||||
|
StreamID: 4,
|
||||||
|
UserID: remoteUserID,
|
||||||
|
}
|
||||||
|
err := updater.Update(ctx, event)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Update returned an error: %s", err)
|
||||||
|
}
|
||||||
|
// At this point we show have this device list marked as stale and not store the keys or emitted anything
|
||||||
|
if !db.staleUsers[event.UserID] {
|
||||||
|
t.Errorf("%s not marked as stale", event.UserID)
|
||||||
|
}
|
||||||
|
if len(producer.events) > 0 {
|
||||||
|
t.Errorf("Update incorrect emitted %d device change events", len(producer.events))
|
||||||
|
}
|
||||||
|
if len(db.storedKeys) > 0 {
|
||||||
|
t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys))
|
||||||
|
}
|
||||||
|
t.Log("waiting for /users/devices to be called...")
|
||||||
|
wg.Wait()
|
||||||
|
// wait a bit for db to be updated...
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
want := api.DeviceMessage{
|
||||||
|
StreamID: 5,
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: "JLAFKJWSCS",
|
||||||
|
DisplayName: "Mobile Phone",
|
||||||
|
UserID: remoteUserID,
|
||||||
|
KeyJSON: []byte(keyJSON),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// Now we should have a fresh list and the keys and emitted something
|
||||||
|
if db.staleUsers[event.UserID] {
|
||||||
|
t.Errorf("%s still marked as stale", event.UserID)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||||
|
t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
|
||||||
|
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||||
|
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
if len(dk.KeyJSON) == 0 {
|
if len(dk.KeyJSON) == 0 {
|
||||||
continue // don't include blank keys
|
continue // don't include blank keys
|
||||||
}
|
}
|
||||||
// inject display name if known
|
// inject display name if known (either locally or remotely)
|
||||||
|
displayName := dk.DisplayName
|
||||||
|
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
|
||||||
|
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
|
||||||
|
}
|
||||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
||||||
DisplayName string `json:"device_display_name,omitempty"`
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName})
|
}{displayName})
|
||||||
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -261,12 +265,49 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
|
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: set device display names when they are known
|
|
||||||
|
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
||||||
|
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
|
||||||
|
if len(domainToDeviceKeys) == 0 {
|
||||||
|
return // nothing to query
|
||||||
|
}
|
||||||
|
|
||||||
// perform key queries for remote devices
|
// perform key queries for remote devices
|
||||||
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
||||||
|
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
|
||||||
|
) map[string]map[string][]string {
|
||||||
|
fetchRemote := make(map[string]map[string][]string)
|
||||||
|
for domain, userToDeviceMap := range domainToDeviceKeys {
|
||||||
|
for userID, deviceIDs := range userToDeviceMap {
|
||||||
|
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
||||||
|
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||||
|
// Likewise, we can't safely return keys from the db when all devices are requested as we don't
|
||||||
|
// know if one has just been added.
|
||||||
|
if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
|
||||||
|
if _, ok := fetchRemote[domain]; !ok {
|
||||||
|
fetchRemote[domain] = make(map[string][]string)
|
||||||
|
}
|
||||||
|
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if res.DeviceKeys[userID] == nil {
|
||||||
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
// inject the display name
|
||||||
|
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||||
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
|
}{key.DisplayName})
|
||||||
|
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fetchRemote
|
||||||
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) queryRemoteKeys(
|
func (a *KeyInternalAPI) queryRemoteKeys(
|
||||||
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
|
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
|
||||||
) {
|
) {
|
||||||
|
|
|
||||||
|
|
@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||||
-- required in the spec because in the event of a missed update the server fetches the entire
|
-- required in the spec because in the event of a missed update the server fetches the entire
|
||||||
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
|
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
|
||||||
stream_id BIGINT NOT NULL,
|
stream_id BIGINT NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
-- Clobber based on tuple of user/device.
|
-- Clobber based on tuple of user/device.
|
||||||
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const upsertDeviceKeysSQL = "" +
|
const upsertDeviceKeysSQL = "" +
|
||||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
|
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5)" +
|
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
||||||
" DO UPDATE SET key_json = $4, stream_id = $5"
|
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||||
|
|
||||||
const selectDeviceKeysSQL = "" +
|
const selectDeviceKeysSQL = "" +
|
||||||
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
var keyJSONStr string
|
var keyJSONStr string
|
||||||
var streamID int
|
var streamID int
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
|
var displayName sql.NullString
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||||
keys[i].StreamID = streamID
|
keys[i].StreamID = streamID
|
||||||
|
if displayName.Valid {
|
||||||
|
keys[i].DisplayName = displayName.String
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
var streamID int
|
var streamID int
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
|
var displayName sql.NullString
|
||||||
|
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
dk.KeyJSON = []byte(keyJSON)
|
||||||
dk.StreamID = streamID
|
dk.StreamID = streamID
|
||||||
|
if displayName.Valid {
|
||||||
|
dk.DisplayName = displayName.String
|
||||||
|
}
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
result = append(result, dk)
|
result = append(result, dk)
|
||||||
|
|
|
||||||
118
keyserver/storage/postgres/stale_device_lists.go
Normal file
118
keyserver/storage/postgres/stale_device_lists.go
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
var staleDeviceListsSchema = `
|
||||||
|
-- Stores whether a user's device lists are stale or not.
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
||||||
|
user_id TEXT PRIMARY KEY NOT NULL,
|
||||||
|
domain TEXT NOT NULL,
|
||||||
|
is_stale BOOLEAN NOT NULL,
|
||||||
|
ts_added_secs BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
||||||
|
`
|
||||||
|
|
||||||
|
const upsertStaleDeviceListSQL = "" +
|
||||||
|
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||||
|
" VALUES ($1, $2, $3, $4)" +
|
||||||
|
" ON CONFLICT (user_id)" +
|
||||||
|
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||||
|
|
||||||
|
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||||
|
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
||||||
|
|
||||||
|
const selectStaleDeviceListsSQL = "" +
|
||||||
|
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||||
|
|
||||||
|
type staleDeviceListsStatements struct {
|
||||||
|
upsertStaleDeviceListStmt *sql.Stmt
|
||||||
|
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||||
|
selectStaleDeviceListsStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||||
|
s := &staleDeviceListsStatements{}
|
||||||
|
_, err := db.Exec(staleDeviceListsSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||||
|
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||||
|
// we only query for 1 domain or all domains so optimise for those use cases
|
||||||
|
if len(domains) == 0 {
|
||||||
|
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rowsToUserIDs(ctx, rows)
|
||||||
|
}
|
||||||
|
var result []string
|
||||||
|
for _, domain := range domains {
|
||||||
|
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userIDs, err := rowsToUserIDs(ctx, rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, userIDs...)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, userID)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
@ -39,10 +39,15 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
sdl, err := NewPostgresStaleDeviceListsTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
|
StaleDeviceListsTable: sdl,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ type Database struct {
|
||||||
OneTimeKeysTable tables.OneTimeKeys
|
OneTimeKeysTable tables.OneTimeKeys
|
||||||
DeviceKeysTable tables.DeviceKeys
|
DeviceKeysTable tables.DeviceKeys
|
||||||
KeyChangesTable tables.KeyChanges
|
KeyChangesTable tables.KeyChanges
|
||||||
|
StaleDeviceListsTable tables.StaleDeviceLists
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
|
@ -129,10 +130,10 @@ func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset,
|
||||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||||
// If no domains are given, all user IDs with stale device lists are returned.
|
// If no domains are given, all user IDs with stale device lists are returned.
|
||||||
func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||||
return nil, nil // TODO
|
return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||||
func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||||
return nil // TODO
|
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||||
ts_added_secs BIGINT NOT NULL,
|
ts_added_secs BIGINT NOT NULL,
|
||||||
key_json TEXT NOT NULL,
|
key_json TEXT NOT NULL,
|
||||||
stream_id BIGINT NOT NULL,
|
stream_id BIGINT NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
-- Clobber based on tuple of user/device.
|
-- Clobber based on tuple of user/device.
|
||||||
UNIQUE (user_id, device_id)
|
UNIQUE (user_id, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const upsertDeviceKeysSQL = "" +
|
const upsertDeviceKeysSQL = "" +
|
||||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
|
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5)" +
|
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||||
" ON CONFLICT (user_id, device_id)" +
|
" ON CONFLICT (user_id, device_id)" +
|
||||||
" DO UPDATE SET key_json = $4, stream_id = $5"
|
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||||
|
|
||||||
const selectDeviceKeysSQL = "" +
|
const selectDeviceKeysSQL = "" +
|
||||||
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
var streamID int
|
var streamID int
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
|
var displayName sql.NullString
|
||||||
|
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
dk.KeyJSON = []byte(keyJSON)
|
||||||
dk.StreamID = streamID
|
dk.StreamID = streamID
|
||||||
|
if displayName.Valid {
|
||||||
|
dk.DisplayName = displayName.String
|
||||||
|
}
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
result = append(result, dk)
|
result = append(result, dk)
|
||||||
|
|
@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
var keyJSONStr string
|
var keyJSONStr string
|
||||||
var streamID int
|
var streamID int
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
|
var displayName sql.NullString
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||||
keys[i].StreamID = streamID
|
keys[i].StreamID = streamID
|
||||||
|
if displayName.Valid {
|
||||||
|
keys[i].DisplayName = displayName.String
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -196,6 +196,9 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||||
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
if keyJSON == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return map[string]json.RawMessage{
|
return map[string]json.RawMessage{
|
||||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
}, err
|
}, err
|
||||||
|
|
|
||||||
118
keyserver/storage/sqlite3/stale_device_lists.go
Normal file
118
keyserver/storage/sqlite3/stale_device_lists.go
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
var staleDeviceListsSchema = `
|
||||||
|
-- Stores whether a user's device lists are stale or not.
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
||||||
|
user_id TEXT PRIMARY KEY NOT NULL,
|
||||||
|
domain TEXT NOT NULL,
|
||||||
|
is_stale BOOLEAN NOT NULL,
|
||||||
|
ts_added_secs BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
||||||
|
`
|
||||||
|
|
||||||
|
const upsertStaleDeviceListSQL = "" +
|
||||||
|
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||||
|
" VALUES ($1, $2, $3, $4)" +
|
||||||
|
" ON CONFLICT (user_id)" +
|
||||||
|
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||||
|
|
||||||
|
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||||
|
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
|
||||||
|
|
||||||
|
const selectStaleDeviceListsSQL = "" +
|
||||||
|
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
|
||||||
|
|
||||||
|
type staleDeviceListsStatements struct {
|
||||||
|
upsertStaleDeviceListStmt *sql.Stmt
|
||||||
|
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||||
|
selectStaleDeviceListsStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||||
|
s := &staleDeviceListsStatements{}
|
||||||
|
_, err := db.Exec(staleDeviceListsSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||||
|
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||||
|
// we only query for 1 domain or all domains so optimise for those use cases
|
||||||
|
if len(domains) == 0 {
|
||||||
|
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rowsToUserIDs(ctx, rows)
|
||||||
|
}
|
||||||
|
var result []string
|
||||||
|
for _, domain := range domains {
|
||||||
|
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userIDs, err := rowsToUserIDs(ctx, rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, userIDs...)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, userID)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
@ -37,10 +37,15 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
sdl, err := NewSqliteStaleDeviceListsTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
|
StaleDeviceListsTable: sdl,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OneTimeKeys interface {
|
type OneTimeKeys interface {
|
||||||
|
|
@ -45,3 +46,8 @@ type KeyChanges interface {
|
||||||
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
|
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
|
||||||
SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StaleDeviceLists interface {
|
||||||
|
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
|
||||||
|
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ func TestMain(m *testing.M) {
|
||||||
|
|
||||||
// Create the federation client.
|
// Create the federation client.
|
||||||
s.fedclient = gomatrixserverlib.NewFederationClientWithTransport(
|
s.fedclient = gomatrixserverlib.NewFederationClientWithTransport(
|
||||||
s.config.Matrix.ServerName, serverKeyID, testPriv, transport,
|
s.config.Matrix.ServerName, serverKeyID, testPriv, true, transport,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Finally, build the server key APIs.
|
// Finally, build the server key APIs.
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID,
|
||||||
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
||||||
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
||||||
// be already filled in with join/leave information.
|
// be already filled in with join/leave information.
|
||||||
|
// nolint:gocyclo
|
||||||
func DeviceListCatchup(
|
func DeviceListCatchup(
|
||||||
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
|
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
|
||||||
userID string, res *types.Response, from, to types.StreamingToken,
|
userID string, res *types.Response, from, to types.StreamingToken,
|
||||||
|
|
@ -68,22 +69,20 @@ func DeviceListCatchup(
|
||||||
|
|
||||||
var partition int32
|
var partition int32
|
||||||
var offset int64
|
var offset int64
|
||||||
|
partition = -1
|
||||||
|
offset = sarama.OffsetOldest
|
||||||
// Extract partition/offset from sync token
|
// Extract partition/offset from sync token
|
||||||
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
|
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
|
||||||
logOffset := from.Log(DeviceListLogName)
|
logOffset := from.Log(DeviceListLogName)
|
||||||
if logOffset != nil {
|
if logOffset != nil {
|
||||||
partition = logOffset.Partition
|
partition = logOffset.Partition
|
||||||
offset = logOffset.Offset
|
offset = logOffset.Offset
|
||||||
} else {
|
|
||||||
partition = -1
|
|
||||||
offset = sarama.OffsetOldest
|
|
||||||
}
|
}
|
||||||
var toOffset int64
|
var toOffset int64
|
||||||
toLog := to.Log(DeviceListLogName)
|
|
||||||
if toLog != nil {
|
|
||||||
toOffset = toLog.Offset
|
|
||||||
} else {
|
|
||||||
toOffset = sarama.OffsetNewest
|
toOffset = sarama.OffsetNewest
|
||||||
|
toLog := to.Log(DeviceListLogName)
|
||||||
|
if toLog != nil && toLog.Offset > 0 {
|
||||||
|
toOffset = toLog.Offset
|
||||||
}
|
}
|
||||||
var queryRes api.QueryKeyChangesResponse
|
var queryRes api.QueryKeyChangesResponse
|
||||||
keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{
|
keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{
|
||||||
|
|
@ -96,6 +95,10 @@ func DeviceListCatchup(
|
||||||
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
|
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
|
||||||
return hasNew, nil
|
return hasNew, nil
|
||||||
}
|
}
|
||||||
|
util.GetLogger(ctx).Debugf(
|
||||||
|
"QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v",
|
||||||
|
partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs,
|
||||||
|
)
|
||||||
userSet := make(map[string]bool)
|
userSet := make(map[string]bool)
|
||||||
for _, userID := range res.DeviceLists.Changed {
|
for _, userID := range res.DeviceLists.Changed {
|
||||||
userSet[userID] = true
|
userSet[userID] = true
|
||||||
|
|
@ -116,6 +119,13 @@ func DeviceListCatchup(
|
||||||
userSet[userID] = true
|
userSet[userID] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// set the new token
|
||||||
|
to.SetLog(DeviceListLogName, &types.LogPosition{
|
||||||
|
Partition: queryRes.Partition,
|
||||||
|
Offset: queryRes.Offset,
|
||||||
|
})
|
||||||
|
res.NextBatch = to.String()
|
||||||
|
|
||||||
return hasNew, nil
|
return hasNew, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -112,6 +112,9 @@ type StreamingToken struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *StreamingToken) SetLog(name string, lp *LogPosition) {
|
func (t *StreamingToken) SetLog(name string, lp *LogPosition) {
|
||||||
|
if t.logs == nil {
|
||||||
|
t.logs = make(map[string]*LogPosition)
|
||||||
|
}
|
||||||
t.logs[name] = lp
|
t.logs[name] = lp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -173,12 +176,14 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken)
|
||||||
}
|
}
|
||||||
ret.Positions[i] = other.Positions[i]
|
ret.Positions[i] = other.Positions[i]
|
||||||
}
|
}
|
||||||
|
ret.logs = make(map[string]*LogPosition)
|
||||||
for name := range t.logs {
|
for name := range t.logs {
|
||||||
otherLog := other.Log(name)
|
otherLog := other.Log(name)
|
||||||
if otherLog == nil {
|
if otherLog == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t.logs[name] = otherLog
|
copy := *otherLog
|
||||||
|
ret.logs[name] = ©
|
||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -138,6 +138,7 @@ Users receive device_list updates for their own devices
|
||||||
Get left notifs for other users in sync and /keys/changes when user leaves
|
Get left notifs for other users in sync and /keys/changes when user leaves
|
||||||
Local device key changes get to remote servers
|
Local device key changes get to remote servers
|
||||||
Local device key changes get to remote servers with correct prev_id
|
Local device key changes get to remote servers with correct prev_id
|
||||||
|
#Server correctly handles incoming m.device_list_update
|
||||||
Can add account data
|
Can add account data
|
||||||
Can add account data to room
|
Can add account data to room
|
||||||
Can get account data without syncing
|
Can get account data without syncing
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue