diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 09fd574c1..8c28014ac 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -75,7 +75,8 @@ func createFederationClient( p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")), ) return gomatrixserverlib.NewFederationClientWithTransport( - base.Base.Cfg.Global.ServerName, base.Base.Cfg.Global.KeyID, base.Base.Cfg.Global.PrivateKey, tr, + base.Base.Cfg.Global.ServerName, base.Base.Cfg.Global.KeyID, + base.Base.Cfg.Global.PrivateKey, true, tr, ) } @@ -87,7 +88,7 @@ func createClient( "matrix", p2phttp.NewTransport(base.LibP2P, p2phttp.ProtocolOption("/matrix")), ) - return gomatrixserverlib.NewClientWithTransport(tr) + return gomatrixserverlib.NewClientWithTransport(true, tr) } func main() { diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go index c99449d63..1236c5530 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/client.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -33,7 +33,7 @@ func (n *Node) CreateClient( }, }, ) - return gomatrixserverlib.NewClientWithTransport(tr) + return gomatrixserverlib.NewClientWithTransport(true, tr) } func (n *Node) CreateFederationClient( @@ -54,6 +54,7 @@ func (n *Node) CreateFederationClient( }, ) return gomatrixserverlib.NewFederationClientWithTransport( - base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, base.Cfg.Global.PrivateKey, tr, + base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, + base.Cfg.Global.PrivateKey, true, tr, ) } diff --git a/cmd/dendrite-media-api-server/main.go b/cmd/dendrite-media-api-server/main.go index cada8b742..1bbb62bdc 100644 --- a/cmd/dendrite-media-api-server/main.go +++ b/cmd/dendrite-media-api-server/main.go @@ -26,7 +26,7 @@ func main() { defer base.Close() // nolint: errcheck userAPI := base.UserAPIClient() - client := gomatrixserverlib.NewClient() + client := gomatrixserverlib.NewClient(cfg.FederationSender.DisableTLSValidation) mediaapi.AddPublicRoutes(base.PublicAPIMux, &base.Cfg.MediaAPI, userAPI, client) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 02e5df012..8f98cdd04 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -132,7 +132,7 @@ func main() { Config: base.Cfg, AccountDB: accountDB, DeviceDB: deviceDB, - Client: gomatrixserverlib.NewClient(), + Client: gomatrixserverlib.NewClient(cfg.FederationSender.DisableTLSValidation), FedClient: federation, KeyRing: keyRing, KafkaConsumer: base.KafkaConsumer, diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 3c2fc0ab5..ce7812fa9 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -139,16 +139,16 @@ func createFederationClient(cfg *config.Dendrite, node *go_http_js_libp2p.P2pLoc tr := go_http_js_libp2p.NewP2pTransport(node) fed := gomatrixserverlib.NewFederationClient( - cfg.Global.ServerName, cfg.Global.KeyID, cfg.Global.PrivateKey, + cfg.Global.ServerName, cfg.Global.KeyID, cfg.Global.PrivateKey, true, ) - fed.Client = *gomatrixserverlib.NewClientWithTransport(tr) + fed.Client = *gomatrixserverlib.NewClientWithTransport(true, tr) return fed } func createClient(node *go_http_js_libp2p.P2pLocalNode) *gomatrixserverlib.Client { tr := go_http_js_libp2p.NewP2pTransport(node) - return gomatrixserverlib.NewClientWithTransport(tr) + return gomatrixserverlib.NewClientWithTransport(true, tr) } func createP2PNode(privKey ed25519.PrivateKey) (serverName string, node *go_http_js_libp2p.P2pLocalNode) { diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 8a0ecdaed..8f1448754 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -27,6 +27,9 @@ matrix: # public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ # Disables new users from registering (except via shared secrets) registration_disabled: false + # Whether to disable TLS certificate validation. Warning: this reduces federation + # security and should not be enabled in production! + federation_disable_tls_validation: false # The media repository config media: diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 64faa1a0a..b31326dc4 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -44,7 +44,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) - fedCli := gomatrixserverlib.NewFederationClient(serverName, cfg.Global.KeyID, cfg.Global.PrivateKey) + fedCli := gomatrixserverlib.NewFederationClient(serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, true) testCases := []struct { roomVer gomatrixserverlib.RoomVersion diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index cf3fe933f..d1aa728cf 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -218,7 +218,9 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe } t.processEDUs(t.EDUs) - util.GetLogger(t.context).Infof("Processed %d PDUs from transaction %q", len(results), t.TransactionID) + if c := len(results); c > 0 { + util.GetLogger(t.context).Infof("Processed %d PDUs from transaction %q", c, t.TransactionID) + } return &gomatrixserverlib.RespSend{PDUs: results}, nil } @@ -315,7 +317,7 @@ func (t *txnReq) processEDUs(edus []gomatrixserverlib.EDU) { case gomatrixserverlib.MDeviceListUpdate: t.processDeviceListUpdate(e) default: - util.GetLogger(t.context).WithField("type", e.Type).Warn("unhandled edu") + util.GetLogger(t.context).WithField("type", e.Type).Debug("Unhandled EDU") } } } diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index dc2d40910..9ccfbacec 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -255,22 +255,20 @@ func (oq *destinationQueue) backgroundSend() { // The worker is idle so stop the goroutine. It'll get // restarted automatically the next time we have an event to // send. - log.Infof("Queue %q has been idle for %s, going to sleep", oq.destination, queueIdleTimeout) + log.Debugf("Queue %q has been idle for %s, going to sleep", oq.destination, queueIdleTimeout) return } // If we are backing off this server then wait for the // backoff duration to complete first, or until explicitly // told to retry. - if backoff, duration := oq.statistics.BackoffDuration(); backoff { - log.WithField("duration", duration).Infof("Backing off %s", oq.destination) - oq.backingOff.Store(true) - select { - case <-time.After(duration): - case <-oq.interruptBackoff: - log.Infof("Interrupting backoff for %q", oq.destination) - } - oq.backingOff.Store(false) + if _, giveUp := oq.statistics.BackoffIfRequired(oq.backingOff, oq.interruptBackoff); giveUp { + // It's been suggested that we should give up because the backoff + // has exceeded a maximum allowable value. Clean up the in-memory + // buffers at this point. The PDU clean-up is already on a defer. + oq.cleanPendingInvites() + log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) + return } // If we have pending PDUs or EDUs then construct a transaction. @@ -278,24 +276,8 @@ func (oq *destinationQueue) backgroundSend() { // Try sending the next transaction and see what happens. transaction, terr := oq.nextTransaction() if terr != nil { - // We failed to send the transaction. - if giveUp := oq.statistics.Failure(); giveUp { - // It's been suggested that we should give up because the backoff - // has exceeded a maximum allowable value. Clean up the in-memory - // buffers at this point. The PDU clean-up is already on a defer. - oq.cleanPendingInvites() - log.Infof("Blacklisting %q due to errors", oq.destination) - return - } else { - // We haven't been told to give up terminally yet but we still have - // PDUs waiting to be sent. By sending a message into the wake chan, - // the next loop iteration will try processing these PDUs again, - // subject to the backoff. - select { - case oq.notifyPDUs <- true: - default: - } - } + // We failed to send the transaction. Mark it as a failure. + oq.statistics.Failure() } else if transaction { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. @@ -307,14 +289,8 @@ func (oq *destinationQueue) backgroundSend() { if len(oq.pendingInvites) > 0 { sent, ierr := oq.nextInvites(oq.pendingInvites) if ierr != nil { - // We failed to send the transaction so increase the - // backoff and give it another go shortly. - if giveUp := oq.statistics.Failure(); giveUp { - // It's been suggested that we should give up because - // the backoff has exceeded a maximum allowable value. - log.Infof("Blacklisting %q due to errors", oq.destination) - return - } + // We failed to send the transaction. Mark it as a failure. + oq.statistics.Failure() } else if sent > 0 { // If we successfully sent the invites then clear out // the pending invites. @@ -414,7 +390,7 @@ func (oq *destinationQueue) nextTransaction() (bool, error) { t.EDUs = append(t.EDUs, *edu) } - logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) + logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) // Try to send the transaction to the destination server. // TODO: we should check for 500-ish fails vs 400-ish here, diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 5651fba26..6d5403e82 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -136,7 +136,7 @@ func (oqs *OutgoingQueues) SendEvent( log.WithFields(log.Fields{ "destinations": destinations, "event": ev.EventID(), - }).Info("Sending event") + }).Infof("Sending event") headeredJSON, err := json.Marshal(ev) if err != nil { diff --git a/federationsender/statistics/statistics.go b/federationsender/statistics/statistics.go index 17dd896d5..0dd8da200 100644 --- a/federationsender/statistics/statistics.go +++ b/federationsender/statistics/statistics.go @@ -65,8 +65,8 @@ type ServerStatistics struct { statistics *Statistics // serverName gomatrixserverlib.ServerName // blacklisted atomic.Bool // is the node blacklisted - backoffUntil atomic.Value // time.Time to wait until before sending requests - failCounter atomic.Uint32 // how many times have we failed? + backoffStarted atomic.Bool // is the backoff started + backoffCount atomic.Uint32 // number of times BackoffDuration has been called successCounter atomic.Uint32 // how many times have we succeeded? } @@ -76,55 +76,67 @@ type ServerStatistics struct { // we will unblacklist it. func (s *ServerStatistics) Success() { s.successCounter.Add(1) - s.failCounter.Store(0) + s.backoffStarted.Store(false) + s.backoffCount.Store(0) s.blacklisted.Store(false) - if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + if s.statistics.DB != nil { + if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + } } } -// Failure marks a failure and works out when to backoff until. It -// returns true if the worker should give up altogether because of -// too many consecutive failures. At this point the host is marked -// as blacklisted. -func (s *ServerStatistics) Failure() bool { - // Increase the fail counter. - failCounter := s.failCounter.Add(1) +// Failure marks a failure and starts backing off if needed. +// The next call to BackoffIfRequired will do the right thing +// after this. +func (s *ServerStatistics) Failure() { + if s.backoffStarted.CAS(false, true) { + s.backoffCount.Store(0) + } +} - // Check that we haven't failed more times than is acceptable. - if failCounter >= s.statistics.FailuresUntilBlacklist { +// BackoffIfRequired will block for as long as the current +// backoff requires, if needed. Otherwise it will do nothing. +func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <-chan bool) (time.Duration, bool) { + if started := s.backoffStarted.Load(); !started { + return 0, false + } + + // Work out how many times we've backed off so far. + count := s.backoffCount.Inc() + duration := time.Second * time.Duration(math.Exp2(float64(count))) + + // Work out if we should be blacklisting at this point. + if count >= s.statistics.FailuresUntilBlacklist { // We've exceeded the maximum amount of times we're willing // to back off, which is probably in the region of hours by // now. Mark the host as blacklisted and tell the caller to // give up. s.blacklisted.Store(true) - if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) + if s.statistics.DB != nil { + if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) + } } - return true + return duration, true } - // We're still under the threshold so work out the exponential - // backoff based on how many times we have failed already. The - // worker goroutine will wait until this time before processing - // anything from the queue. - backoffSeconds := time.Second * time.Duration(math.Exp2(float64(failCounter))) - s.backoffUntil.Store( - time.Now().Add(backoffSeconds), - ) - return false -} + // Notify the destination queue that we're backing off now. + backingOff.Store(true) + defer backingOff.Store(false) -// BackoffDuration returns both a bool stating whether to wait, -// and then if true, a duration to wait for. -func (s *ServerStatistics) BackoffDuration() (bool, time.Duration) { - backoff, until := false, time.Second - if b, ok := s.backoffUntil.Load().(time.Time); ok { - if b.After(time.Now()) { - backoff, until = true, time.Until(b) - } + // Work out how long we should be backing off for. + logrus.Warnf("Backing off %q for %s", s.serverName, duration) + + // Wait for either an interruption or for the backoff to + // complete. + select { + case <-interrupt: + logrus.Debugf("Interrupting backoff for %q", s.serverName) + case <-time.After(duration): } - return backoff, until + + return duration, false } // Blacklisted returns true if the server is blacklisted and false diff --git a/federationsender/statistics/statistics_test.go b/federationsender/statistics/statistics_test.go new file mode 100644 index 000000000..9050662ec --- /dev/null +++ b/federationsender/statistics/statistics_test.go @@ -0,0 +1,60 @@ +package statistics + +import ( + "math" + "testing" + "time" + + "go.uber.org/atomic" +) + +func TestBackoff(t *testing.T) { + stats := Statistics{ + FailuresUntilBlacklist: 5, + } + server := ServerStatistics{ + statistics: &stats, + serverName: "test.com", + } + + // Start by checking that counting successes works. + server.Success() + if successes := server.SuccessCount(); successes != 1 { + t.Fatalf("Expected success count 1, got %d", successes) + } + + // Register a failure. + server.Failure() + + t.Logf("Backoff counter: %d", server.backoffCount.Load()) + backingOff := atomic.Bool{} + + // Now we're going to simulate backing off a few times to see + // what happens. + for i := uint32(1); i <= 10; i++ { + // Interrupt the backoff - it doesn't really matter if it + // completes but we will find out how long the backoff should + // have been. + interrupt := make(chan bool, 1) + close(interrupt) + + // Get the duration. + duration, blacklist := server.BackoffIfRequired(backingOff, interrupt) + + // Check if we should be blacklisted by now. + if i > stats.FailuresUntilBlacklist { + if !blacklist { + t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) + } else { + t.Logf("Backoff %d is blacklisted as expected", i) + continue + } + } + + // Check if the duration is what we expect. + t.Logf("Backoff %d is for %s", i, duration) + if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted { + t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration) + } + } +} diff --git a/go.mod b/go.mod index ce18a95b9..d5cf91713 100644 --- a/go.mod +++ b/go.mod @@ -23,9 +23,9 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 - github.com/matrix-org/gomatrixserverlib v0.0.0-20200807122736-eb1a0b991914 + github.com/matrix-org/gomatrixserverlib v0.0.0-20200807145008-79c173b65786 github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f - github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 + github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v2.0.2+incompatible github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 diff --git a/go.sum b/go.sum index 131617e3c..8ed62ffd6 100644 --- a/go.sum +++ b/go.sum @@ -425,12 +425,14 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 h1:Yb+Wlf github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200807122736-eb1a0b991914 h1:VSGCvSUB1/Y32F/JSjmTaIW9jr1BmBHEd0ok4AaT/lo= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200807122736-eb1a0b991914/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200807145008-79c173b65786 h1:HQclx5J2CrCBqP88t5Di9IkVDJZn5+h4ZL48viY4FJ4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200807145008-79c173b65786/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f h1:pRz4VTiRCO4zPlEMc3ESdUOcW4PXHH4Kj+YDz1XyE+Y= github.com/matrix-org/naffka v0.0.0-20200422140631-181f1ee7401f/go.mod h1:y0oDTjZDv5SM9a2rp3bl+CU+bvTRINQsdb7YlDql5Go= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= +github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= +github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1 h1:G1f5SKeVxmagw/IyvzvtZE4Gybcc4Tr1tf7I8z0XgOg= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= diff --git a/internal/config/config_federationsender.go b/internal/config/config_federationsender.go index d42c83f64..9b9b83ea2 100644 --- a/internal/config/config_federationsender.go +++ b/internal/config/config_federationsender.go @@ -14,7 +14,11 @@ type FederationSender struct { // tolerate when sending federation requests to a specific server. The backoff // is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds, etc. // The default value is 16 if not specified, which is circa 18 hours. - FederationMaxRetries uint32 `yaml:"federation_max_retries"` + FederationMaxRetries uint32 `yaml:"send_max_retries"` + + // FederationDisableTLSValidation disables the validation of X.509 TLS certs + // on remote federation endpoints. This is not recommended in production! + DisableTLSValidation bool `yaml:"disable_tls_validation"` Proxy Proxy `yaml:"proxy_outbound"` } @@ -26,6 +30,7 @@ func (c *FederationSender) Defaults() { c.Database.ConnectionString = "file:federationsender.db" c.FederationMaxRetries = 16 + c.DisableTLSValidation = false c.Proxy.Defaults() } diff --git a/internal/setup/base.go b/internal/setup/base.go index c5a7a8a85..65f386620 100644 --- a/internal/setup/base.go +++ b/internal/setup/base.go @@ -251,6 +251,7 @@ func (b *BaseDendrite) CreateAccountsDB() accounts.Database { func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { return gomatrixserverlib.NewFederationClient( b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey, + b.Cfg.FederationSender.DisableTLSValidation, ) } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 19d8463d8..ec7dff560 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -23,7 +23,6 @@ import ( "time" "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -65,7 +64,7 @@ type DeviceListUpdater struct { mu *sync.Mutex // protects UserIDToMutex db DeviceListUpdaterDatabase - producer *producers.KeyChange + producer KeyChangeProducer fedClient *gomatrixserverlib.FederationClient workerChans []chan gomatrixserverlib.ServerName } @@ -88,9 +87,14 @@ type DeviceListUpdaterDatabase interface { PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) } +// KeyChangeProducer is the interface for producers.KeyChange useful for testing. +type KeyChangeProducer interface { + ProduceKeyChanges(keys []api.DeviceMessage) error +} + // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. func NewDeviceListUpdater( - db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient, + db DeviceListUpdaterDatabase, producer KeyChangeProducer, fedClient *gomatrixserverlib.FederationClient, numWorkers int, ) *DeviceListUpdater { return &DeviceListUpdater{ @@ -154,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. if err != nil { return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) } + // if this is the first time we're hearing about this user, sync the device list manually. + if len(event.PrevID) == 0 { + exists = false + } util.GetLogger(ctx).WithFields(logrus.Fields{ "prev_ids_exist": exists, "user_id": event.UserID, "device_id": event.DeviceID, "stream_id": event.StreamID, "prev_ids": event.PrevID, + "display_name": event.DeviceDisplayName, }).Info("DeviceListUpdater.Update") // if we haven't missed anything update the database and notify users @@ -263,16 +272,17 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam hasFailures = true continue } - err = u.updateDeviceList(ctx, &res) + err = u.updateDeviceList(&res) if err != nil { - logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it") + logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it") hasFailures = true } } return hasFailures } -func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error { +func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { + ctx := context.Background() // we've got the keys, don't time out when persisting them to the database. keys := make([]api.DeviceMessage, len(res.Devices)) for i, device := range res.Devices { keyJSON, err := json.Marshal(device.Keys) @@ -292,7 +302,15 @@ func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixs } err := u.db.StoreRemoteDeviceKeys(ctx, keys) if err != nil { - return err + return fmt.Errorf("failed to store remote device keys: %w", err) } - return u.db.MarkDeviceListStale(ctx, res.UserID, false) + err = u.db.MarkDeviceListStale(ctx, res.UserID, false) + if err != nil { + return fmt.Errorf("failed to mark device list as fresh: %w", err) + } + err = u.producer.ProduceKeyChanges(keys) + if err != nil { + return fmt.Errorf("failed to emit key changes for fresh device list: %w", err) + } + return nil } diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go new file mode 100644 index 000000000..b07148bbd --- /dev/null +++ b/keyserver/internal/device_list_update_test.go @@ -0,0 +1,242 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "crypto/ed25519" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +var ( + ctx = context.Background() +) + +type mockKeyChangeProducer struct { + events []api.DeviceMessage +} + +func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error { + p.events = append(p.events, keys...) + return nil +} + +type mockDeviceListUpdaterDatabase struct { + staleUsers map[string]bool + prevIDsExist func(string, []int) bool + storedKeys []api.DeviceMessage +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + var result []string + for userID := range d.staleUsers { + _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return nil, err + } + if len(domains) == 0 { + result = append(result, userID) + continue + } + for _, d := range domains { + if remoteServer == d { + result = append(result, userID) + break + } + } + } + return result, nil +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + d.staleUsers[userID] = isStale + return nil +} + +// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key +// for this (user, device). Does not modify the stream ID for keys. +func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + d.storedKeys = append(d.storedKeys, keys...) + return nil +} + +// PrevIDsExists returns true if all prev IDs exist for this user. +func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { + return d.prevIDsExist(userID, prevIDs), nil +} + +type roundTripper struct { + fn func(*http.Request) (*http.Response, error) +} + +func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.fn(req) +} + +func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { + _, pkey, _ := ed25519.GenerateKey(nil) + fedClient := gomatrixserverlib.NewFederationClient( + gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, true, + ) + fedClient.Client = *gomatrixserverlib.NewClientWithTransport(true, &roundTripper{tripper}) + return fedClient +} + +// Test that the device keys get persisted and emitted if we have the previous IDs. +func TestUpdateHavePrevID(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int) bool { + return true + }, + } + producer := &mockKeyChangeProducer{} + updater := NewDeviceListUpdater(db, producer, nil, 1) + event := gomatrixserverlib.DeviceListUpdateEvent{ + DeviceDisplayName: "Foo Bar", + Deleted: false, + DeviceID: "FOO", + Keys: []byte(`{"key":"value"}`), + PrevID: []int{0}, + StreamID: 1, + UserID: "@alice:localhost", + } + err := updater.Update(ctx, event) + if err != nil { + t.Fatalf("Update returned an error: %s", err) + } + want := api.DeviceMessage{ + StreamID: event.StreamID, + DeviceKeys: api.DeviceKeys{ + DeviceID: event.DeviceID, + DisplayName: event.DeviceDisplayName, + KeyJSON: event.Keys, + UserID: event.UserID, + }, + } + if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { + t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) + } + if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { + t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + } + if db.staleUsers[event.UserID] { + t.Errorf("%s incorrectly marked as stale", event.UserID) + } +} + +// Test that device keys are fetched from the remote server if we are missing prev IDs +// and that the user's devices are marked as stale until it succeeds. +func TestUpdateNoPrevID(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int) bool { + return false + }, + } + producer := &mockKeyChangeProducer{} + remoteUserID := "@alice:example.somewhere" + var wg sync.WaitGroup + wg.Add(1) + keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + defer wg.Done() + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader(` + { + "user_id": "` + remoteUserID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + }, nil + }) + updater := NewDeviceListUpdater(db, producer, fedClient, 2) + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + event := gomatrixserverlib.DeviceListUpdateEvent{ + DeviceDisplayName: "Mobile Phone", + Deleted: false, + DeviceID: "another_device_id", + Keys: []byte(`{"key":"value"}`), + PrevID: []int{3}, + StreamID: 4, + UserID: remoteUserID, + } + err := updater.Update(ctx, event) + if err != nil { + t.Fatalf("Update returned an error: %s", err) + } + // At this point we show have this device list marked as stale and not store the keys or emitted anything + if !db.staleUsers[event.UserID] { + t.Errorf("%s not marked as stale", event.UserID) + } + if len(producer.events) > 0 { + t.Errorf("Update incorrect emitted %d device change events", len(producer.events)) + } + if len(db.storedKeys) > 0 { + t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys)) + } + t.Log("waiting for /users/devices to be called...") + wg.Wait() + // wait a bit for db to be updated... + time.Sleep(100 * time.Millisecond) + want := api.DeviceMessage{ + StreamID: 5, + DeviceKeys: api.DeviceKeys{ + DeviceID: "JLAFKJWSCS", + DisplayName: "Mobile Phone", + UserID: remoteUserID, + KeyJSON: []byte(keyJSON), + }, + } + // Now we should have a fresh list and the keys and emitted something + if db.staleUsers[event.UserID] { + t.Errorf("%s still marked as stale", event.UserID) + } + if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { + t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON)) + t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) + } + if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { + t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + } + +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index ff298c07c..075622b7c 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques if len(dk.KeyJSON) == 0 { continue // don't include blank keys } - // inject display name if known + // inject display name if known (either locally or remotely) + displayName := dk.DisplayName + if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { + displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName + } dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { DisplayName string `json:"device_display_name,omitempty"` - }{queryRes.DeviceInfo[dk.DeviceID].DisplayName}) + }{displayName}) res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON } } else { @@ -261,12 +265,49 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) } } - // TODO: set device display names when they are known + + // attempt to satisfy key queries from the local database first as we should get device updates pushed to us + domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) + if len(domainToDeviceKeys) == 0 { + return // nothing to query + } // perform key queries for remote devices a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) } +func (a *KeyInternalAPI) remoteKeysFromDatabase( + ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, +) map[string]map[string][]string { + fetchRemote := make(map[string]map[string][]string) + for domain, userToDeviceMap := range domainToDeviceKeys { + for userID, deviceIDs := range userToDeviceMap { + keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + // if we can't query the db or there are fewer keys than requested, fetch from remote. + // Likewise, we can't safely return keys from the db when all devices are requested as we don't + // know if one has just been added. + if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) { + if _, ok := fetchRemote[domain]; !ok { + fetchRemote[domain] = make(map[string][]string) + } + fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...) + continue + } + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + for _, key := range keys { + // inject the display name + key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{key.DisplayName}) + res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + } + } + } + return fetchRemote +} + func (a *KeyInternalAPI) queryRemoteKeys( ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, ) { diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index d321860d4..b9d5d4c36 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( -- required in the spec because in the event of a missed update the server fetches the entire -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. stream_id BIGINT NOT NULL, + display_name TEXT, -- Clobber based on tuple of user/device. CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + - " VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4, stream_id = $5" + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] for i, key := range keys { var keyJSONStr string var streamID int - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } } return nil } @@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { return err @@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.UserID = userID var keyJSON string var streamID int - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { + var displayName sql.NullString + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID + if displayName.Valid { + dk.DisplayName = displayName.String + } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go new file mode 100644 index 000000000..63281adfb --- /dev/null +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt +} + +func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index 134ef4657..1c693f5b2 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -39,10 +39,15 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) if err != nil { return nil, err } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } return &shared.Database{ - DB: db, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, + DB: db, + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, }, nil } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 68964be67..4279eae77 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -26,10 +26,11 @@ import ( ) type Database struct { - DB *sql.DB - OneTimeKeysTable tables.OneTimeKeys - DeviceKeysTable tables.DeviceKeys - KeyChangesTable tables.KeyChanges + DB *sql.DB + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists } func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { @@ -129,10 +130,10 @@ func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - return nil, nil // TODO + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) } // MarkDeviceListStale sets the stale bit for this user to isStale. func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return nil // TODO + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 15d9c775f..abe6636af 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, stream_id BIGINT NOT NULL, + display_name TEXT, -- Clobber based on tuple of user/device. UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + - " VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4, stream_id = $5" + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.UserID = userID var keyJSON string var streamID int - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { + var displayName sql.NullString + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID + if displayName.Valid { + dk.DisplayName = displayName.String + } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) @@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] for i, key := range keys { var keyJSONStr string var streamID int - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } } return nil } @@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { return err diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index f910479f5..907966a7a 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -196,6 +196,9 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) return err }) + if keyJSON == "" { + return nil, nil + } return map[string]json.RawMessage{ algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go new file mode 100644 index 000000000..a989476d1 --- /dev/null +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt +} + +func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index 248db99af..bb2935582 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -37,10 +37,15 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) if err != nil { return nil, err } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } return &shared.Database{ - DB: db, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, + DB: db, + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, }, nil } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index ac932d56d..a4d5dede2 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/gomatrixserverlib" ) type OneTimeKeys interface { @@ -45,3 +46,8 @@ type KeyChanges interface { // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset. SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) } + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) +} diff --git a/serverkeyapi/serverkeyapi_test.go b/serverkeyapi/serverkeyapi_test.go index b8c462c7b..152a853e3 100644 --- a/serverkeyapi/serverkeyapi_test.go +++ b/serverkeyapi/serverkeyapi_test.go @@ -88,7 +88,7 @@ func TestMain(m *testing.M) { // Create the federation client. s.fedclient = gomatrixserverlib.NewFederationClientWithTransport( - s.config.Matrix.ServerName, serverKeyID, testPriv, transport, + s.config.Matrix.ServerName, serverKeyID, testPriv, true, transport, ) // Finally, build the server key APIs. diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 66134d791..e0379aafb 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -46,6 +46,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, // DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. +// nolint:gocyclo func DeviceListCatchup( ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, res *types.Response, from, to types.StreamingToken, @@ -68,22 +69,20 @@ func DeviceListCatchup( var partition int32 var offset int64 + partition = -1 + offset = sarama.OffsetOldest // Extract partition/offset from sync token // TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make. logOffset := from.Log(DeviceListLogName) if logOffset != nil { partition = logOffset.Partition offset = logOffset.Offset - } else { - partition = -1 - offset = sarama.OffsetOldest } var toOffset int64 + toOffset = sarama.OffsetNewest toLog := to.Log(DeviceListLogName) - if toLog != nil { + if toLog != nil && toLog.Offset > 0 { toOffset = toLog.Offset - } else { - toOffset = sarama.OffsetNewest } var queryRes api.QueryKeyChangesResponse keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{ @@ -96,6 +95,10 @@ func DeviceListCatchup( util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed") return hasNew, nil } + util.GetLogger(ctx).Debugf( + "QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v", + partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs, + ) userSet := make(map[string]bool) for _, userID := range res.DeviceLists.Changed { userSet[userID] = true @@ -116,6 +119,13 @@ func DeviceListCatchup( userSet[userID] = true } } + // set the new token + to.SetLog(DeviceListLogName, &types.LogPosition{ + Partition: queryRes.Partition, + Offset: queryRes.Offset, + }) + res.NextBatch = to.String() + return hasNew, nil } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index f465d9fff..f3324800f 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -112,6 +112,9 @@ type StreamingToken struct { } func (t *StreamingToken) SetLog(name string, lp *LogPosition) { + if t.logs == nil { + t.logs = make(map[string]*LogPosition) + } t.logs[name] = lp } @@ -173,12 +176,14 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) } ret.Positions[i] = other.Positions[i] } + ret.logs = make(map[string]*LogPosition) for name := range t.logs { otherLog := other.Log(name) if otherLog == nil { continue } - t.logs[name] = otherLog + copy := *otherLog + ret.logs[name] = © } return ret } diff --git a/sytest-whitelist b/sytest-whitelist index 18978bbe6..cc49bf389 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -138,6 +138,7 @@ Users receive device_list updates for their own devices Get left notifs for other users in sync and /keys/changes when user leaves Local device key changes get to remote servers Local device key changes get to remote servers with correct prev_id +#Server correctly handles incoming m.device_list_update Can add account data Can add account data to room Can get account data without syncing