Allow more time for device list updates (#2749)

This updates the device list updater so that it has a context
per-request, rather than a global 30 seconds for the entire server. This
could mean that talking to a slow remote server or requesting a lot of
user IDs was pretty much guaranteed to fail.

It also uses the process context to allow correct cancellation when
Dendrite wants to shut down cleanly.
This commit is contained in:
Neil Alexander 2022-09-30 09:41:16 +01:00 committed by GitHub
parent 9005e5b4a8
commit 8a82f10046
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 117 additions and 96 deletions

View file

@ -17,6 +17,7 @@ package internal
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"net" "net"
@ -31,6 +32,7 @@ import (
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/process"
) )
var ( var (
@ -45,6 +47,9 @@ var (
) )
) )
const defaultWaitTime = time.Minute
const requestTimeout = time.Second * 30
func init() { func init() {
prometheus.MustRegister( prometheus.MustRegister(
deviceListUpdateCount, deviceListUpdateCount,
@ -80,6 +85,7 @@ func init() {
// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is // In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried. // set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
type DeviceListUpdater struct { type DeviceListUpdater struct {
process *process.ProcessContext
// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
// request to the remote server and race. // request to the remote server and race.
// TODO: Put in an LRU cache to bound growth // TODO: Put in an LRU cache to bound growth
@ -131,10 +137,12 @@ type KeyChangeProducer interface {
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. // NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater( func NewDeviceListUpdater(
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
process: process,
userIDToMutex: make(map[string]*sync.Mutex), userIDToMutex: make(map[string]*sync.Mutex),
mu: &sync.Mutex{}, mu: &sync.Mutex{},
db: db, db: db,
@ -234,7 +242,7 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
"prev_ids": event.PrevID, "prev_ids": event.PrevID,
"display_name": event.DeviceDisplayName, "display_name": event.DeviceDisplayName,
"deleted": event.Deleted, "deleted": event.Deleted,
}).Info("DeviceListUpdater.Update") }).Trace("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users // if we haven't missed anything update the database and notify users
if exists || event.Deleted { if exists || event.Deleted {
@ -378,74 +386,99 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
} }
func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) {
deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() ctx := u.process.Context()
requestTimeout := time.Second * 30 // max amount of time we want to spend on each request
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
defer cancel()
logger := util.GetLogger(ctx).WithField("server_name", serverName) logger := util.GetLogger(ctx).WithField("server_name", serverName)
waitTime := 2 * time.Second deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
// fetch stale device lists
waitTime := defaultWaitTime // How long should we wait to try again?
successCount := 0 // How many user requests failed?
userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
if err != nil { if err != nil {
logger.WithError(err).Error("Failed to load stale device lists") logger.WithError(err).Error("Failed to load stale device lists")
return waitTime, true return waitTime, true
} }
failCount := 0
userLoop: defer func() {
for _, userID := range userIDs { for _, userID := range userIDs {
if ctx.Err() != nil { // always clear the channel to unblock Update calls regardless of success/failure
// we've timed out, give up and go to the back of the queue to let another server be processed. u.clearChannel(userID)
failCount += 1 }
waitTime = time.Minute * 10 }()
for _, userID := range userIDs {
userWait, err := u.processServerUser(ctx, serverName, userID)
if err != nil {
if userWait > waitTime {
waitTime = userWait
}
break break
} }
successCount++
}
allUsersSucceeded := successCount == len(userIDs)
if !allUsersSucceeded {
logger.WithFields(logrus.Fields{
"total": len(userIDs),
"succeeded": successCount,
"failed": len(userIDs) - successCount,
"wait_time": waitTime,
}).Warn("Failed to query device keys for some users")
}
return waitTime, !allUsersSucceeded
}
func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) {
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"server_name": serverName,
"user_id": userID,
})
res, err := u.fedClient.GetUserDevices(ctx, serverName, userID) res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
if err != nil { if err != nil {
failCount += 1 if errors.Is(err, context.DeadlineExceeded) {
select { return time.Minute * 10, err
case <-ctx.Done():
// we've timed out, give up and go to the back of the queue to let another server be processed.
waitTime = time.Minute * 10
break userLoop
default:
} }
switch e := err.(type) { switch e := err.(type) {
case *json.UnmarshalTypeError, *json.SyntaxError:
logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID)
return defaultWaitTime, nil
case *fedsenderapi.FederationClientError: case *fedsenderapi.FederationClientError:
if e.RetryAfter > 0 { if e.RetryAfter > 0 {
waitTime = e.RetryAfter return e.RetryAfter, err
} else if e.Blacklisted { } else if e.Blacklisted {
waitTime = time.Hour * 8 return time.Hour * 8, err
break userLoop
} else if e.Code >= 300 { } else if e.Code >= 300 {
// We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError
// are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds.
waitTime = time.Hour return time.Hour, err
break userLoop
} }
case net.Error: case net.Error:
// Use the default waitTime, if it's a timeout. // Use the default waitTime, if it's a timeout.
// It probably doesn't make sense to try further users. // It probably doesn't make sense to try further users.
if !e.Timeout() { if !e.Timeout() {
waitTime = time.Minute * 10 logger.WithError(e).Debug("GetUserDevices returned net.Error")
logger.WithError(e).Error("GetUserDevices returned net.Error") return time.Minute * 10, err
break userLoop
} }
case gomatrix.HTTPError: case gomatrix.HTTPError:
// The remote server returned an error, give it some time to recover. // The remote server returned an error, give it some time to recover.
// This is to avoid spamming remote servers, which may not be Matrix servers anymore. // This is to avoid spamming remote servers, which may not be Matrix servers anymore.
if e.Code >= 300 { if e.Code >= 300 {
waitTime = time.Hour logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
logger.WithError(e).Error("GetUserDevices returned gomatrix.HTTPError") return time.Hour, err
break userLoop
} }
default: default:
// Something else failed // Something else failed
waitTime = time.Minute * 10 logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err)
logger.WithError(err).WithField("user_id", userID).Debugf("GetUserDevices returned unknown error type: %T", err) return time.Minute * 10, err
break userLoop
} }
continue }
if res.UserID != userID {
logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID)
return defaultWaitTime, nil
} }
if res.MasterKey != nil || res.SelfSigningKey != nil { if res.MasterKey != nil || res.SelfSigningKey != nil {
uploadReq := &api.PerformUploadDeviceKeysRequest{ uploadReq := &api.PerformUploadDeviceKeysRequest{
@ -466,23 +499,10 @@ userLoop:
} }
err = u.updateDeviceList(&res) err = u.updateDeviceList(&res)
if err != nil { if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("Fetched device list but failed to store/emit it") logger.WithError(err).Error("Fetched device list but failed to store/emit it")
failCount += 1 return defaultWaitTime, err
} }
} return defaultWaitTime, nil
if failCount > 0 {
logger.WithFields(logrus.Fields{
"total": len(userIDs),
"failed": failCount,
"skipped": len(userIDs) - failCount,
"waittime": waitTime,
}).Warn("Failed to query device keys for some users")
}
for _, userID := range userIDs {
// always clear the channel to unblock Update calls regardless of success/failure
u.clearChannel(userID)
}
return waitTime, failCount > 0
} }
func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/process"
) )
var ( var (
@ -146,7 +147,7 @@ func TestUpdateHavePrevID(t *testing.T) {
} }
ap := &mockDeviceListUpdaterAPI{} ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{} producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(db, ap, producer, nil, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1)
event := gomatrixserverlib.DeviceListUpdateEvent{ event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar", DeviceDisplayName: "Foo Bar",
Deleted: false, Deleted: false,
@ -218,7 +219,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)), `)),
}, nil }, nil
}) })
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 2) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2)
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }
@ -287,7 +288,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq) close(incomingFedReq)
return <-fedCh, nil return <-fedCh, nil
}) })
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 1) updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1)
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err) t.Fatalf("failed to start updater: %s", err)
} }

View file

@ -58,7 +58,7 @@ func NewInternalAPI(
FedClient: fedClient, FedClient: fedClient,
Producer: keyChangeProducer, Producer: keyChangeProducer,
} }
updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
ap.Updater = updater ap.Updater = updater
go func() { go func() {
if err := updater.Start(); err != nil { if err := updater.Start(); err != nil {