diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 4baf2f42f..c7bf8da53 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -192,26 +192,18 @@ func (u *DeviceListUpdater) Start() error { return nil } -// cleanUp removes stale device entries for users we don't share a room with anymore +// CleanUp removes stale device entries for users we don't share a room with anymore func (u *DeviceListUpdater) CleanUp() error { staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) if err != nil { return err } - maxRetries := 3 - - // In polylith mode, the roomserver api might not be up yet, so we try again res := rsapi.QueryLeftUsersResponse{} - for i := 0; i <= maxRetries; i++ { - if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { - if i == maxRetries { - return err - } - logrus.WithError(err).Warnf("unable to query left users (try %d/%d)", i+1, maxRetries) - time.Sleep(time.Second * 3) - } + if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { + return err } + if len(res.LeftUsers) == 0 { return nil } diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index d82552450..82f2371a3 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -30,7 +30,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage" + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" ) var ( @@ -243,6 +248,7 @@ func TestUpdateNoPrevID(t *testing.T) { UserID: remoteUserID, } err := updater.Update(ctx, event) + updater.CleanUp() if err != nil { t.Fatalf("Update returned an error: %s", err) } @@ -353,3 +359,73 @@ func TestDebounce(t *testing.T) { t.Errorf("user %s is marked as stale", userID) } } + +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { + t.Helper() + + base, _, _ := testrig.Base(nil) + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + if err != nil { + t.Fatal(err) + } + + return db, clearDB +} + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +func TestDeviceListUpdater_CleanUp(t *testing.T) { + processCtx := process.NewProcessContext() + + alice := test.NewUser(t) + bob := test.NewUser(t) + + // Bob is not joined to any of our rooms + rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}} + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clearDB := mustCreateKeyserverDB(t, dbType) + defer clearDB() + + // This should not get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil { + t.Error(err) + } + + // this one should get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil { + t.Error(err) + } + + updater := NewDeviceListUpdater(processCtx, db, nil, + nil, nil, + 0, rsAPI, "test") + if err := updater.CleanUp(); err != nil { + t.Error(err) + } + + // check that we still have Alice in our stale list + staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Error(err) + } + + // There should only be Alice + wantCount := 1 + if count := len(staleUsers); count != wantCount { + t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count) + } + + if staleUsers[0] != alice.ID { + t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID) + } + }) +}