Allow more time for device list updates

This updates the device list updater so that it has a context per-request, rather
than a global 30 seconds for the entire server. This could mean that talking to a
slow remote server or requesting a lot of user IDs was pretty much guaranteed to
fail.

It also uses the process context to allow correct cancellation when Dendrite wants
to shut down cleanly.
This commit is contained in:
Neil Alexander 2022-09-29 15:16:02 +01:00
parent 68d6eb0a6f
commit 3f3e9de1f2
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
3 changed files with 109 additions and 96 deletions

View file

@ -31,6 +31,7 @@ import (
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/process"
) )
var ( var (
@ -45,6 +46,9 @@ var (
) )
) )
const defaultWaitTime = time.Second * 2
const requestTimeout = time.Second * 30
func init() { func init() {
prometheus.MustRegister( prometheus.MustRegister(
deviceListUpdateCount, deviceListUpdateCount,
@ -80,6 +84,7 @@ func init() {
// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is // In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried. // set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
type DeviceListUpdater struct { type DeviceListUpdater struct {
process *process.ProcessContext
// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
// request to the remote server and race. // request to the remote server and race.
// TODO: Put in an LRU cache to bound growth // TODO: Put in an LRU cache to bound growth
@ -131,10 +136,12 @@ type KeyChangeProducer interface {
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater( func NewDeviceListUpdater(
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
process: process,
userIDToMutex: make(map[string]*sync.Mutex), userIDToMutex: make(map[string]*sync.Mutex),
mu: &sync.Mutex{}, mu: &sync.Mutex{},
db: db, db: db,
@ -378,74 +385,92 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
} }
func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) {
deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() ctx := u.process.Context()
requestTimeout := time.Second * 30 // max amount of time we want to spend on each request
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
defer cancel()
logger := util.GetLogger(ctx).WithField("server_name", serverName) logger := util.GetLogger(ctx).WithField("server_name", serverName)
waitTime := 2 * time.Second deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
// fetch stale device lists
waitTime := defaultWaitTime // How long should we wait to try again?
successCount := 0 // How many user requests failed?
userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
if err != nil { if err != nil {
logger.WithError(err).Error("Failed to load stale device lists") logger.WithError(err).Error("Failed to load stale device lists")
return waitTime, true return waitTime, true
} }
failCount := 0
userLoop: defer func() {
for _, userID := range userIDs { for _, userID := range userIDs {
if ctx.Err() != nil { // always clear the channel to unblock Update calls regardless of success/failure
// we've timed out, give up and go to the back of the queue to let another server be processed. u.clearChannel(userID)
failCount += 1 }
waitTime = time.Minute * 10 }()
for _, userID := range userIDs {
userWait, err := u.processServerUser(ctx, serverName, userID)
if err != nil {
if userWait > waitTime {
waitTime = userWait
}
break break
} }
successCount++
}
allUsersSucceeded := successCount == len(userIDs)
if !allUsersSucceeded {
logger.WithFields(logrus.Fields{
"total": len(userIDs),
"succeeded": successCount,
"failed": len(userIDs) - successCount,
"wait_time": waitTime,
}).Debug("Failed to query device keys for some users")
}
return waitTime, !allUsersSucceeded
}
func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) {
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"server_name": serverName,
"user_id": userID,
})
res, err := u.fedClient.GetUserDevices(ctx, serverName, userID) res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
if err != nil { if err != nil {
failCount += 1 if err == context.DeadlineExceeded {
select { return time.Minute * 10, err
case <-ctx.Done():
// we've timed out, give up and go to the back of the queue to let another server be processed.
waitTime = time.Minute * 10
break userLoop
default:
} }
switch e := err.(type) { switch e := err.(type) {
case *fedsenderapi.FederationClientError: case *fedsenderapi.FederationClientError:
if e.RetryAfter > 0 { if e.RetryAfter > 0 {
waitTime = e.RetryAfter return e.RetryAfter, err
} else if e.Blacklisted { } else if e.Blacklisted {
waitTime = time.Hour * 8 return time.Hour * 8, err
break userLoop
} else if e.Code >= 300 { } else if e.Code >= 300 {
// We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError
// are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds.
waitTime = time.Hour return time.Hour, err
break userLoop
} }
case net.Error: case net.Error:
// Use the default waitTime, if it's a timeout. // Use the default waitTime, if it's a timeout.
// It probably doesn't make sense to try further users. // It probably doesn't make sense to try further users.
if !e.Timeout() { if !e.Timeout() {
waitTime = time.Minute * 10
logger.WithError(e).Error("GetUserDevices returned net.Error") logger.WithError(e).Error("GetUserDevices returned net.Error")
break userLoop return time.Minute * 10, err
} }
case gomatrix.HTTPError: case gomatrix.HTTPError:
// The remote server returned an error, give it some time to recover. // The remote server returned an error, give it some time to recover.
// This is to avoid spamming remote servers, which may not be Matrix servers anymore. // This is to avoid spamming remote servers, which may not be Matrix servers anymore.
if e.Code >= 300 { if e.Code >= 300 {
waitTime = time.Hour
logger.WithError(e).Error("GetUserDevices returned gomatrix.HTTPError") logger.WithError(e).Error("GetUserDevices returned gomatrix.HTTPError")
break userLoop return time.Hour, err
} }
default: default:
// Something else failed // Something else failed
waitTime = time.Minute * 10
logger.WithError(err).WithField("user_id", userID).Debugf("GetUserDevices returned unknown error type: %T", err) logger.WithError(err).WithField("user_id", userID).Debugf("GetUserDevices returned unknown error type: %T", err)
break userLoop return time.Minute * 10, err
} }
continue
} }
if res.MasterKey != nil || res.SelfSigningKey != nil { if res.MasterKey != nil || res.SelfSigningKey != nil {
uploadReq := &api.PerformUploadDeviceKeysRequest{ uploadReq := &api.PerformUploadDeviceKeysRequest{
@ -467,22 +492,9 @@ userLoop:
err = u.updateDeviceList(&res) err = u.updateDeviceList(&res)
if err != nil { if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("Fetched device list but failed to store/emit it") logger.WithError(err).WithField("user_id", userID).Error("Fetched device list but failed to store/emit it")
failCount += 1 return defaultWaitTime, err
} }
} return defaultWaitTime, nil
if failCount > 0 {
logger.WithFields(logrus.Fields{
"total": len(userIDs),
"failed": failCount,
"skipped": len(userIDs) - failCount,
"waittime": waitTime,
}).Warn("Failed to query device keys for some users")
}
for _, userID := range userIDs {
// always clear the channel to unblock Update calls regardless of success/failure
u.clearChannel(userID)
}
return waitTime, failCount > 0
} }
func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/process"
) )
var ( var (
@ -146,7 +147,7 @@ func TestUpdateHavePrevID(t *testing.T) {
} }
ap := &mockDeviceListUpdaterAPI{} ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{} producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(db, ap, producer, nil, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1)
event := gomatrixserverlib.DeviceListUpdateEvent{ event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar", DeviceDisplayName: "Foo Bar",
Deleted: false, Deleted: false,
@ -218,7 +219,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)), `)),
}, nil }, nil
}) })
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 2) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2)
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }
@ -287,7 +288,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq) close(incomingFedReq)
return <-fedCh, nil return <-fedCh, nil
}) })
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1)
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }

View file

@ -58,7 +58,7 @@ func NewInternalAPI(
FedClient: fedClient, FedClient: fedClient,
Producer: keyChangeProducer, Producer: keyChangeProducer,
} }
updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
ap.Updater = updater ap.Updater = updater
go func() { go func() {
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {