Add sync mechanism to block when updating device lists

With a timeout, mainly for sytest to fix the test
"Server correctly handles incoming m.device_list_update"
which is flakey because it assumes that when `/send` 200 OKs
that the server has updated the device lists in prep for
`/keys/query` which is not always true when using workers.
This commit is contained in:
Kegan Dougal 2020-08-12 11:51:23 +01:00
parent b8b854d642
commit e87740fd7b

View file

@ -67,6 +67,11 @@ type DeviceListUpdater struct {
producer KeyChangeProducer producer KeyChangeProducer
fedClient *gomatrixserverlib.FederationClient fedClient *gomatrixserverlib.FederationClient
workerChans []chan gomatrixserverlib.ServerName workerChans []chan gomatrixserverlib.ServerName
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
// block on or timeout via a select.
userIDToChan map[string]chan bool
userIDToChanMu *sync.Mutex
} }
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. // DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
@ -98,12 +103,14 @@ func NewDeviceListUpdater(
numWorkers int, numWorkers int,
) *DeviceListUpdater { ) *DeviceListUpdater {
return &DeviceListUpdater{ return &DeviceListUpdater{
userIDToMutex: make(map[string]*sync.Mutex), userIDToMutex: make(map[string]*sync.Mutex),
mu: &sync.Mutex{}, mu: &sync.Mutex{},
db: db, db: db,
producer: producer, producer: producer,
fedClient: fedClient, fedClient: fedClient,
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{},
} }
} }
@ -137,6 +144,8 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
return u.userIDToMutex[userID] return u.userIDToMutex[userID]
} }
// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
// which assumes when /send 200 OKs that the device lists have been updated.
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error { func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
isDeviceListStale, err := u.update(ctx, event) isDeviceListStale, err := u.update(ctx, event)
if err != nil { if err != nil {
@ -213,7 +222,35 @@ func (u *DeviceListUpdater) notifyWorkers(userID string) {
hash := fnv.New32a() hash := fnv.New32a()
_, _ = hash.Write([]byte(remoteServer)) _, _ = hash.Write([]byte(remoteServer))
index := int(hash.Sum32()) % len(u.workerChans) index := int(hash.Sum32()) % len(u.workerChans)
ch := u.assignChannel(userID)
u.workerChans[index] <- remoteServer u.workerChans[index] <- remoteServer
select {
case <-ch:
case <-time.After(10 * time.Second):
// we don't return an error in this case as it's not a failure condition.
// we mainly block for the benefit of sytest anyway
}
}
func (u *DeviceListUpdater) assignChannel(userID string) chan bool {
u.userIDToChanMu.Lock()
defer u.userIDToChanMu.Unlock()
if ch, ok := u.userIDToChan[userID]; ok {
return ch
}
ch := make(chan bool)
u.userIDToChan[userID] = ch
return ch
}
func (u *DeviceListUpdater) clearChannel(userID string) {
u.userIDToChanMu.Lock()
defer u.userIDToChanMu.Unlock()
if ch, ok := u.userIDToChan[userID]; ok {
close(ch)
delete(u.userIDToChan, userID)
}
} }
func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
@ -285,6 +322,8 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
if err != nil { if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it") logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
hasFailures = true hasFailures = true
} else {
u.clearChannel(userID)
} }
} }
return hasFailures return hasFailures