diff --git a/test/xpu/distributed/__init__.py b/test/xpu/distributed/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/xpu/distributed/test_c10d_nccl_xpu.py b/test/xpu/distributed/test_c10d_nccl_xpu.py
new file mode 100644
index 000000000..164000c97
--- /dev/null
+++ b/test/xpu/distributed/test_c10d_nccl_xpu.py
@@ -0,0 +1,337 @@
+# Owner(s): ["module: intel"]
+
+import os
+import math
+from unittest import mock
+
+import torch
+from torch.testing._internal.common_distributed import (
+    requires_nccl,
+    MultiProcessTestCase,
+)
+import torch.distributed as c10d
+import torch.testing._internal.common_utils as common
+from torch.testing._internal.common_utils import (
+    retry_on_connect_failures,
+    run_tests,
+    skip_but_pass_in_sandcastle_if,
+    parametrize,
+    instantiate_parametrized_tests,
+)
+from torch.testing._internal.common_cuda import TEST_MULTIGPU
+import torch.distributed as dist
+
+try:
+    from .xpu_test_utils import XPUPatchForImport, requires_xccl
+except Exception:
+    from ..xpu_test_utils import XPUPatchForImport, requires_xccl
+
+with XPUPatchForImport(False):
+    TEST_CUDA = torch.testing._internal.common_utils.TEST_CUDA
+    from test_c10d_nccl import (
+        RendezvousEnvTest,
+        TimeoutTest,
+        ProcessGroupNCCLNoGPUTest,
+        ProcessGroupNCCLInitTest,
+    )
+
+    if torch.xpu.is_available():
+        ccl_backend = "xccl"
+    else:
+        ccl_backend = "nccl"
+
+    @retry_on_connect_failures
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "No GPUs available, skipping test")
+    def _test_common_errors(self):
+        vars = {
+            "WORLD_SIZE": "1",
+            "RANK": "0",
+            "MASTER_ADDR": "127.0.0.1",
+            "MASTER_PORT": str(common.find_free_port()),
+        }
+
+        class Env:
+            def __init__(self, vars):
+                self.env_patcher = mock.patch.dict(os.environ, vars, clear=True)
+
+            def __enter__(self):
+                self.env_patcher.start()
+
+            def __exit__(self, type, value, traceback):
+                self.env_patcher.stop()
+
+        def without(d, key):
+            d = d.copy()
+            d.pop(key)
+            return d
+
+        def withouts(d, keys):
+            d = d.copy()
+            for key in keys:
+                d.pop(key)
+            return d
+
+        with Env(without(vars, "WORLD_SIZE")):
+            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
+            with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"):
+                gen = c10d.rendezvous("env://")
+                next(gen)
+
+            c10d.init_process_group(backend=ccl_backend, world_size=1)
+            self.assertEqual(c10d.get_rank(), 0)
+            self.assertEqual(c10d.get_world_size(), 1)
+            c10d.destroy_process_group()
+
+        with Env(without(vars, "RANK")):
+            self.assertEqual(None, os.environ.get("RANK"))
+            with self.assertRaisesRegex(ValueError, "RANK expected"):
+                gen = c10d.rendezvous("env://")
+                next(gen)
+            c10d.init_process_group(backend=ccl_backend, rank=0)
+            self.assertEqual(c10d.get_rank(), 0)
+            self.assertEqual(c10d.get_world_size(), 1)
+            c10d.destroy_process_group()
+
+        with Env(withouts(vars, ["RANK", "WORLD_SIZE"])):
+            self.assertEqual(None, os.environ.get("RANK"))
+            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
+            c10d.init_process_group(backend=ccl_backend, rank=0, world_size=1)
+            self.assertEqual(c10d.get_rank(), 0)
+            self.assertEqual(c10d.get_world_size(), 1)
+            c10d.destroy_process_group()
+
+        with Env(vars):
+            c10d.init_process_group(backend=ccl_backend)
+            self.assertEqual(c10d.get_rank(), 0)
+            self.assertEqual(c10d.get_world_size(), 1)
+            c10d.destroy_process_group()
+
+        with Env(without(vars, "MASTER_ADDR")):
+            self.assertEqual(None, os.environ.get("MASTER_ADDR"))
+            with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"):
+                gen = c10d.rendezvous("env://")
+                next(gen)
+
+        with Env(without(vars, "MASTER_PORT")):
+            self.assertEqual(None, os.environ.get("MASTER_PORT"))
+            with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"):
+                gen = c10d.rendezvous("env://")
+                next(gen)
+
+        with Env(without(vars, "WORLD_SIZE")):
+            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
+            gen = c10d.rendezvous(f"env://?world_size={1}")
+            _, _, size = next(gen)
+            self.assertEqual(size, 1)
+
+        with Env(without(vars, "RANK")):
+            self.assertEqual(None, os.environ.get("RANK"))
+            gen = c10d.rendezvous(f"env://?rank={0}")
+            _, rank, _ = next(gen)
+            self.assertEqual(rank, 0)
+
+        with Env(withouts(vars, ["RANK", "WORLD_SIZE"])):
+            self.assertEqual(None, os.environ.get("RANK"))
+            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
+            gen = c10d.rendezvous(f"env://?rank={0}&world_size={1}")
+            _, rank, size = next(gen)
+            self.assertEqual(rank, 0)
+            self.assertEqual(size, 1)
+
+    RendezvousEnvTest.test_common_errors = _test_common_errors
+
+    def __test_default_store_timeout_nccl(self):
+        self._test_default_store_timeout(ccl_backend)
+
+    TimeoutTest.test_default_store_timeout_nccl = __test_default_store_timeout_nccl
+
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(TEST_CUDA, "GPUs are available, skipping test")
+    def _test_init_no_gpus(self):
+        store = c10d.FileStore(self.file.name, self.world_size)
+        with self.assertRaisesRegex(
+            ValueError, "ProcessGroupXCCL is only supported with GPUs, no GPUs found!"
+        ):
+            c10d.ProcessGroupXCCL(store, self.rank, self.world_size)
+
+    ProcessGroupNCCLNoGPUTest.test_init_no_gpus = _test_init_no_gpus
+
+    def _ProcessGroupNCCLInitTest_setUp(self):
+        super(ProcessGroupNCCLInitTest, self).setUp()
+        self._spawn_processes()
+
+    ProcessGroupNCCLInitTest.device_type = "xpu"
+    ProcessGroupNCCLInitTest.setUp = _ProcessGroupNCCLInitTest_setUp
+
+
+def simple_reduce_tests(rank, world_size):
+    tests = [
+        (
+            c10d.ReduceOp.SUM,
+            torch.tensor([rank + 1.0]),
+            torch.tensor([float(world_size * (world_size + 1) / 2)]),
+        ),
+        (
+            c10d.ReduceOp.PRODUCT,
+            torch.tensor([rank + 1.0]),
+            torch.tensor([float(math.factorial(world_size))]),
+        ),
+        (
+            c10d.ReduceOp.MIN,
+            torch.tensor([rank + 1.0]),
+            torch.tensor([1.0]),
+        ),
+        (
+            c10d.ReduceOp.MAX,
+            torch.tensor([rank + 1.0]),
+            torch.tensor([world_size]),
+        ),
+    ]
+
+    return tests
+
+
+from datetime import timedelta
+import time
+
+
+class ProcessGroupXCCLTest(MultiProcessTestCase):
+    def _create_process_group_xccl(
+        self, timeout=timedelta(seconds=600), device_id=None
+    ):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            "xccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+            timeout=timeout,
+            device_id=device_id,
+        )
+        pg = c10d.distributed_c10d._get_default_group()
+        return pg
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
+
+    def tearDown(self):
+        super().tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    @property
+    def world_size(self):
+        return 2
+
+    @property
+    def rank_to_GPU(self):
+        # return rank to GPU map
+        nGPUs = torch.xpu.device_count()
+        visible_devices = range(nGPUs)
+        nGPUs_per_process = 1
+        if self.world_size > nGPUs:
+            nGPUs_per_process = nGPUs // self.world_size
+        GPUs = {
+            i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process])
+            for i in range(self.world_size)
+        }
+        return GPUs
+
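+    # The two sub-groups created in the next test are torn down in opposite
+    # orders on rank 0 and rank 1, which exercises unordered process group
+    # destruction on top of the XCCL backend.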
+    @requires_xccl()
+    @skip_but_pass_in_sandcastle_if(
+        torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs"
+    )
+    def test_close_multi_pg_unordered(self):
+        pg = self._create_process_group_xccl()
+        device = self.rank_to_GPU[self.rank][0]
+        t = torch.rand(10, 10, device=device)
+        # First allreduce to initialize default PG's communicator.
+        pg.allreduce(t).wait()
+        new_pg1 = c10d.new_group([0, 1])
+        new_pg2 = c10d.new_group([0, 1])
+        if self.rank == 0 or self.rank == 1:
+            t1 = torch.rand(10, 10, device=device)
+            t2 = torch.rand(10, 10, device=device)
+            new_pg1.allreduce(t1).wait()
+            new_pg2.allreduce(t2).wait()
+        if self.rank == 0:
+            dist.destroy_process_group(new_pg2)
+            # force destruction of pg2 first
+            del new_pg2
+            dist.destroy_process_group(new_pg1)
+            del new_pg1
+        if self.rank == 1:
+            c10d.destroy_process_group(new_pg1)
+            # force destruction of pg1 first
+            del new_pg1
+            dist.destroy_process_group(new_pg2)
+            del new_pg2
+        dist.destroy_process_group()
+
+    @requires_xccl()
+    @skip_but_pass_in_sandcastle_if(
+        torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs"
+    )
+    def test_file_store_check(self):
+        # self.file_name is created using "delete=False"
+        # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+        pg = dist.distributed_c10d._get_default_group()
+        self.assertEqual(pg.rank(), self.rank)
+        self.assertEqual(pg.size(), self.world_size)
+        # give enough time for check() to be executed multiple times
+        time.sleep(2)
+        dist.destroy_process_group()
+
+    # todo: https://github.com/pytorch/pytorch/blob/c06b5048ba866e2dd39e5da5399fe8261322c7ca/torch/distributed/distributed_c10d.py#L1862 device agnostic
+    # @requires_xccl()
+    # @skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs")
+    # def test_set_process_group_desc(self):
+    #     device = torch.device(f"xpu:{self.rank}")
+    #     pg_default = self._create_process_group_xccl(device_id=device)
+    #     self.assertEqual(pg_default.group_desc, "default_pg")
+    #     pg_1 = c10d.new_group([0, 1], group_desc="test_purpose")
+    #     self.assertEqual(pg_1.group_desc, "test_purpose")
+    #     pg_2 = c10d.new_group([0, 1])
+    #     self.assertEqual(pg_2.group_desc, "undefined")
+
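+    # _test_allreduce_basics drives each collective through the future-based c10d
+    # API: allreduce() returns a Work handle, get_future().wait() blocks until the
+    # collective finishes, and fut.value() returns the list of result tensors.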
+    def _test_allreduce_basics(self, fn):
+        pg = self._create_process_group_xccl()
+        device = torch.device("xpu:" + str(self.rank))
+        # Single input tests
+        tests = simple_reduce_tests(self.rank, self.world_size)
+        for op, input, expected in tests:
+            opts = c10d.AllreduceOptions()
+            opts.reduceOp = op
+            tensor = fn(input.to(device))
+            fut = pg.allreduce([tensor], opts).get_future()
+            fut.wait()
+            result = fut.value()
+            self.assertEqual(expected, result[0], exact_dtype=False)
+
+        x = fn(torch.tensor([self.rank + 1.0], device=device))
+        fut = pg.allreduce(x).get_future()
+        fut.wait()
+        result = fut.value()
+        self.assertEqual(
+            torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
+            result[0],
+        )
+
+    @requires_xccl()
+    def test_allreduce_basics(self):
+        self._test_allreduce_basics(lambda t: t.clone())
+
+
+# instantiate_parametrized_tests(ProcessGroupNCCLGroupTest)
+
+if __name__ == "__main__":
+    assert (
+        not torch.xpu._initialized
+    ), "test_distributed must not have initialized XPU context on main process"
+
+    run_tests()
diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py
index 780f2efd7..aec4a549a 100644
--- a/test/xpu/skip_list_common.py
+++ b/test/xpu/skip_list_common.py
@@ -3344,4 +3344,8 @@
         "test_sparse_mm_xpu_float64",  # - NotImplementedError: Could not run 'aten::addmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or wa...
         "test_sparse_sum_xpu_float64",  # - NotImplementedError: Could not run 'aten::_sparse_sum_backward' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this...
     ),
+
+    "distributed/test_c10d_nccl_xpu.py": (
+        "test_init_wo_backend_str",
+    ),
 }
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index a7e583331..af7f7eebe 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -17,7 +17,9 @@
 from torch.testing._internal.common_device_type import tol, toleranceOverride
 from torch.testing._internal.common_modules import module_db
 from torch.testing._internal.common_nn import CriterionTest, ModuleTest
-from torch.testing._internal.common_utils import set_default_dtype
+from torch.testing._internal.common_utils import (
+    set_default_dtype,
+    skip_but_pass_in_sandcastle,
+    skip_but_pass_in_sandcastle_if,
+)
 from torch.testing._internal.opinfo.core import (
     BinaryUfuncInfo,
     DecorateInfo,
@@ -27,6 +29,7 @@
     SpectralFuncInfo,
     UnaryUfuncInfo,
 )
+import torch.distributed as c10d
 
 _xpu_computation_op_list = [
     "empty",
@@ -751,6 +754,57 @@ def sample_inputs_like_fns_nofp64(self, device, dtype, requires_grad, **kwargs):
                            requires_grad=requires_grad)
             yield SampleInput(t, **kwargs)
 
+def _xccl_version():
+    """
+    Return a placeholder XCCL version tuple.
+
+    This helper is monkey-patched over ``torch.cuda.nccl.version`` by
+    ``XPUPatchForImport`` so that NCCL version checks in the reused CUDA tests
+    keep working when they run against XCCL. It reports a fixed
+    ``(major, minor, patch)`` tuple rather than a real library version.
+
+    Returns:
+        tuple: The placeholder XCCL version information, currently ``(1, 0, 0)``.
+    """
+    return (1, 0, 0)
+
+def requires_xccl():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_xccl_available(),
+        "c10d was not compiled with the XCCL backend",
+    )
+
+def requires_xccl_version(version, msg):
+    # Mirrors the signature of common_distributed.requires_nccl_version, which it
+    # replaces while the XPU patch is active; the version argument is ignored.
+    if not c10d.is_xccl_available():
+        return skip_but_pass_in_sandcastle(
+            "c10d was not compiled with the XCCL backend",
+        )
+    else:
+        return skip_but_pass_in_sandcastle_if(
+            True,
+            "Test workaround to return True for xccl version check",
+        )
+
+def _skip_if_lt_x_gpu(x):
+    import sys
+    from functools import wraps
+
+    from torch.testing._internal.common_distributed import TEST_SKIPS
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.xpu.is_available() and torch.xpu.device_count() >= x:
+                return func(*args, **kwargs)
+            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
+
+        return wrapper
+
+    return decorator
+
+
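+# While XPUPatchForImport is active, the helpers above are swapped in for their
+# NCCL counterparts: __enter__ points c10d.is_nccl_available, requires_nccl_version,
+# skip_if_lt_x_gpu and torch.cuda.nccl.version at the XCCL replacements, and
+# __exit__ restores the hooks that were saved in __init__.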
 class XPUPatchForImport:
     def __init__(self, patch_test_case=True) -> None:
         test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../test")
@@ -758,6 +812,7 @@ def __init__(self, patch_test_case=True) -> None:
             test_dir,
             os.path.join(test_dir, "nn"),
             os.path.join(test_dir, "distributions"),
+            os.path.join(test_dir, "distributed"),
             os.path.join(test_dir, "quantization/core"),
         )
         self.patch_test_case = patch_test_case
@@ -786,6 +841,10 @@ def __init__(self, patch_test_case=True) -> None:
         self.TEST_CUDNN = common_cuda.TEST_CUDNN
         self.cuda_is_available = cuda.is_available
         self.cuda_is_bf16_supported = cuda.is_bf16_supported
+        self.c10d_is_nccl_available = c10d.is_nccl_available
+        self.requires_nccl_version = torch.testing._internal.common_distributed.requires_nccl_version
+        self.TEST_MULTIGPU = torch.testing._internal.common_cuda.TEST_MULTIGPU
+        self.skip_if_lt_x_gpu = torch.testing._internal.common_distributed.skip_if_lt_x_gpu
 
     def align_db_decorators(self, db):
@@ -970,11 +1033,21 @@ def __init__(self, *args):
         common_cuda.TEST_CUDA = True
         common_cuda.TEST_CUDNN = True
         common_cuda.TEST_CUDNN_VERSION = 0
+        torch.testing._internal.common_utils.TEST_CUDA = common_cuda.TEST_CUDA
         cuda.is_available = lambda: True
         cuda.is_bf16_supported = lambda: True
+        c10d.is_nccl_available = c10d.is_xccl_available
+        torch.testing._internal.common_distributed.requires_nccl_version = requires_xccl_version
+
+        torch.testing._internal.common_cuda.TEST_MULTIGPU = common_utils.TEST_XPU and torch.xpu.device_count() >= 2
+
         sys.path.extend(self.test_package)
+        torch.testing._internal.common_distributed.skip_if_lt_x_gpu = _skip_if_lt_x_gpu
+
+        torch.cuda.nccl.version = _xccl_version
+
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
@@ -997,6 +1070,11 @@ def __exit__(self, exc_type, exc_value, traceback):
         cuda.is_available = self.cuda_is_available
         cuda.is_bf16_supported = self.cuda_is_bf16_supported
+        c10d.is_nccl_available = self.c10d_is_nccl_available
+        torch.testing._internal.common_distributed.requires_nccl_version = self.requires_nccl_version
+        torch.testing._internal.common_cuda.TEST_MULTIGPU = self.TEST_MULTIGPU
+        torch.testing._internal.common_distributed.skip_if_lt_x_gpu = self.skip_if_lt_x_gpu
+
         # Copy the test cases from generic_base_class to generic_test_class.
         # It serves to reuse test cases. Regarding some newly added hardware,