diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go index 151655c60..b9f1f1885 100644 --- a/userapi/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -27,6 +27,7 @@ import ( "testing" "time" + api2 "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -480,3 +481,65 @@ func Test_dedupeStateList(t *testing.T) { }) } } + +func TestDeviceListUpdaterIgnoreBlacklisted(t *testing.T) { + unreachableServer := spec.ServerName("notlocalhost") + + updater := DeviceListUpdater{ + workerChans: make([]chan spec.ServerName, 1), + isBlacklistedOrBackingOffFn: func(s spec.ServerName) (*statistics.ServerStatistics, error) { + switch s { + case unreachableServer: + return nil, &api2.FederationClientError{Blacklisted: true} + } + return nil, nil + }, + mu: &sync.Mutex{}, + userIDToChanMu: &sync.Mutex{}, + userIDToChan: make(map[string]chan bool), + userIDToMutex: make(map[string]*sync.Mutex), + } + workerCh := make(chan spec.ServerName) + updater.workerChans[0] = workerCh + + // happy case + alice := "@alice:localhost" + aliceCh := updater.assignChannel(alice) + + // failing case + bob := "@bob:" + unreachableServer + bobCh := updater.assignChannel(string(bob)) + + expectedServers := map[spec.ServerName]struct{}{ + "localhost": {}, + } + unexpectedServers := make(map[spec.ServerName]struct{}) + + go func() { + for serverName := range workerCh { + switch serverName { + case "localhost": + delete(expectedServers, serverName) + aliceCh <- true // unblock notifyWorkers + case "notlocalhost": // this should not happen as it is "filtered" away by the blacklist + unexpectedServers[serverName] = struct{}{} + bobCh <- true + default: + unexpectedServers[serverName] = struct{}{} + } + } + }() + + // alice is not blacklisted + updater.notifyWorkers(alice) + // bob is blacklisted + updater.notifyWorkers(string(bob)) + + for server := range expectedServers { + t.Errorf("Server still in expectedServers map: %s", server) + } + + for server := range unexpectedServers { + t.Errorf("unexpected server in result: %s", server) + } +}