diff --git a/test/test_collector.py b/test/test_collector.py index fb97d070f95..d2f1c102416 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -519,7 +519,9 @@ def _set_seed(self, seed: Optional[int]): def test_no_synchronize(self, env_device, storing_device, no_cuda_sync): """Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted.""" should_raise = not no_cuda_sync - should_raise = should_raise & ((env_device == "cpu") or (storing_device == "cpu")) + should_raise = should_raise & ( + (env_device == "cpu") or (storing_device == "cpu") + ) with patch("torch.cuda.synchronize") as mock_synchronize, pytest.raises( AssertionError, match="Expected 'synchronize' to not have been called." ) if should_raise else contextlib.nullcontext():